initial public commit
This commit is contained in:
@@ -0,0 +1,59 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type migration struct {
|
||||
file string
|
||||
up func(db dbx.Builder) error
|
||||
down func(db dbx.Builder) error
|
||||
}
|
||||
|
||||
// MigrationsList defines a list with migration definitions
|
||||
type MigrationsList struct {
|
||||
list []*migration
|
||||
}
|
||||
|
||||
// Item returns a single migration from the list by its index.
|
||||
func (l *MigrationsList) Item(index int) *migration {
|
||||
return l.list[index]
|
||||
}
|
||||
|
||||
// Items returns the internal migrations list slice.
|
||||
func (l *MigrationsList) Items() []*migration {
|
||||
return l.list
|
||||
}
|
||||
|
||||
// Register adds new migration definition to the list.
|
||||
//
|
||||
// If `optFilename` is not provided, it will try to get the name from its .go file.
|
||||
//
|
||||
// The list will be sorted automatically based on the migrations file name.
|
||||
func (l *MigrationsList) Register(
|
||||
up func(db dbx.Builder) error,
|
||||
down func(db dbx.Builder) error,
|
||||
optFilename ...string,
|
||||
) {
|
||||
var file string
|
||||
if len(optFilename) > 0 {
|
||||
file = optFilename[0]
|
||||
} else {
|
||||
_, path, _, _ := runtime.Caller(1)
|
||||
file = filepath.Base(path)
|
||||
}
|
||||
|
||||
l.list = append(l.list, &migration{
|
||||
file: file,
|
||||
up: up,
|
||||
down: down,
|
||||
})
|
||||
|
||||
sort.Slice(l.list, func(i int, j int) bool {
|
||||
return l.list[i].file < l.list[j].file
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMigrationsList(t *testing.T) {
|
||||
l := MigrationsList{}
|
||||
|
||||
l.Register(nil, nil, "3_test.go")
|
||||
l.Register(nil, nil, "1_test.go")
|
||||
l.Register(nil, nil, "2_test.go")
|
||||
l.Register(nil, nil /* auto detect file name */)
|
||||
|
||||
expected := []string{
|
||||
"1_test.go",
|
||||
"2_test.go",
|
||||
"3_test.go",
|
||||
"list_test.go",
|
||||
}
|
||||
|
||||
items := l.Items()
|
||||
if len(items) != len(expected) {
|
||||
t.Fatalf("Expected %d items, got %d: \n%#v", len(expected), len(items), items)
|
||||
}
|
||||
|
||||
for i, name := range expected {
|
||||
item := l.Item(i)
|
||||
if item.file != name {
|
||||
t.Fatalf("Expected name %s for index %d, got %s", name, i, item.file)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/fatih/color"
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
const migrationsTable = "_migrations"
|
||||
|
||||
// Runner defines a simple struct for managing the execution of db migrations.
|
||||
type Runner struct {
|
||||
db *dbx.DB
|
||||
migrationsList MigrationsList
|
||||
tableName string
|
||||
}
|
||||
|
||||
// NewRunner creates and initializes a new db migrations Runner instance.
|
||||
func NewRunner(db *dbx.DB, migrationsList MigrationsList) (*Runner, error) {
|
||||
runner := &Runner{
|
||||
db: db,
|
||||
migrationsList: migrationsList,
|
||||
tableName: migrationsTable,
|
||||
}
|
||||
|
||||
if err := runner.createMigrationsTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
// Run interactively executes the current runner with the provided args.
|
||||
//
|
||||
// The following commands are supported:
|
||||
// - up - applies all migrations
|
||||
// - down [n] - reverts the last n applied migrations
|
||||
// - create NEW_MIGRATION_NAME - create NEW_MIGRATION_NAME.go file from a migration template
|
||||
func (r *Runner) Run(args ...string) error {
|
||||
cmd := "up"
|
||||
if len(args) > 0 {
|
||||
cmd = args[0]
|
||||
}
|
||||
|
||||
switch cmd {
|
||||
case "up":
|
||||
applied, err := r.Up()
|
||||
if err != nil {
|
||||
color.Red(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if len(applied) == 0 {
|
||||
color.Green("No new migrations to apply.")
|
||||
} else {
|
||||
for _, file := range applied {
|
||||
color.Green("Applied %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "down":
|
||||
toRevertCount := 1
|
||||
if len(args) > 1 {
|
||||
toRevertCount = cast.ToInt(args[1])
|
||||
if toRevertCount < 0 {
|
||||
// revert all applied migrations
|
||||
toRevertCount = len(r.migrationsList.Items())
|
||||
}
|
||||
}
|
||||
|
||||
confirm := false
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf("Do you really want to revert the last %d applied migration(s)?", toRevertCount),
|
||||
}
|
||||
survey.AskOne(prompt, &confirm)
|
||||
if !confirm {
|
||||
fmt.Println("The command has been cancelled")
|
||||
return nil
|
||||
}
|
||||
|
||||
reverted, err := r.Down(toRevertCount)
|
||||
if err != nil {
|
||||
color.Red(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if len(reverted) == 0 {
|
||||
color.Green("No migrations to revert.")
|
||||
} else {
|
||||
for _, file := range reverted {
|
||||
color.Green("Reverted %s", file)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case "create":
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("Missing migration file name")
|
||||
}
|
||||
|
||||
name := args[1]
|
||||
|
||||
var dir string
|
||||
if len(args) == 3 {
|
||||
dir = args[2]
|
||||
}
|
||||
if dir == "" {
|
||||
// If not specified, auto point to the default migrations folder.
|
||||
//
|
||||
// NB!
|
||||
// Since the create command makes sense only during development,
|
||||
// it is expected the user to be in the app working directory
|
||||
// and to be using `go run ...`
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dir = path.Join(wd, "migrations")
|
||||
}
|
||||
|
||||
resultFilePath := path.Join(
|
||||
dir,
|
||||
fmt.Sprintf("%d_%s.go", time.Now().Unix(), inflector.Snakecase(name)),
|
||||
)
|
||||
|
||||
confirm := false
|
||||
prompt := &survey.Confirm{
|
||||
Message: fmt.Sprintf("Do you really want to create migration %q?", resultFilePath),
|
||||
}
|
||||
survey.AskOne(prompt, &confirm)
|
||||
if !confirm {
|
||||
fmt.Println("The command has been cancelled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensure that migrations dir exist
|
||||
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resultFilePath, []byte(createTemplateContent), 0644); err != nil {
|
||||
return fmt.Errorf("Failed to save migration file %q\n", resultFilePath)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully created file %q\n", resultFilePath)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Unsupported command: %q\n", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Up executes all unapplied migrations for the provided runner.
|
||||
//
|
||||
// On success returns list with the applied migrations file names.
|
||||
func (r *Runner) Up() ([]string, error) {
|
||||
applied := []string{}
|
||||
|
||||
err := r.db.Transactional(func(tx *dbx.Tx) error {
|
||||
for _, m := range r.migrationsList.Items() {
|
||||
// skip applied
|
||||
if r.isMigrationApplied(tx, m.file) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.up(tx); err != nil {
|
||||
return fmt.Errorf("Failed to apply migration %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
if err := r.saveAppliedMigration(tx, m.file); err != nil {
|
||||
return fmt.Errorf("Failed to save applied migration info for %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
applied = append(applied, m.file)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
// Down reverts the last `toRevertCount` applied migrations.
|
||||
//
|
||||
// On success returns list with the reverted migrations file names.
|
||||
func (r *Runner) Down(toRevertCount int) ([]string, error) {
|
||||
applied := []string{}
|
||||
|
||||
err := r.db.Transactional(func(tx *dbx.Tx) error {
|
||||
totalReverted := 0
|
||||
|
||||
for i := len(r.migrationsList.Items()) - 1; i >= 0; i-- {
|
||||
m := r.migrationsList.Item(i)
|
||||
|
||||
// skip unapplied
|
||||
if !r.isMigrationApplied(tx, m.file) {
|
||||
continue
|
||||
}
|
||||
|
||||
// revert limit reached
|
||||
if toRevertCount-totalReverted <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if err := m.down(tx); err != nil {
|
||||
return fmt.Errorf("Failed to revert migration %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
if err := r.saveRevertedMigration(tx, m.file); err != nil {
|
||||
return fmt.Errorf("Failed to save reverted migration info for %s: %w", m.file, err)
|
||||
}
|
||||
|
||||
applied = append(applied, m.file)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func (r *Runner) createMigrationsTable() error {
|
||||
rawQuery := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %v (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
|
||||
r.db.QuoteTableName(r.tableName),
|
||||
)
|
||||
|
||||
_, err := r.db.NewQuery(rawQuery).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) isMigrationApplied(tx dbx.Builder, file string) bool {
|
||||
var exists bool
|
||||
|
||||
err := tx.Select("count(*)").
|
||||
From(r.tableName).
|
||||
Where(dbx.HashExp{"file": file}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && exists
|
||||
}
|
||||
|
||||
func (r *Runner) saveAppliedMigration(tx dbx.Builder, file string) error {
|
||||
_, err := tx.Insert(r.tableName, dbx.Params{
|
||||
"file": file,
|
||||
"applied": time.Now().Unix(),
|
||||
}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Runner) saveRevertedMigration(tx dbx.Builder, file string) error {
|
||||
_, err := tx.Delete(r.tableName, dbx.HashExp{"file": file}).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestNewRunner(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
l := MigrationsList{}
|
||||
l.Register(nil, nil, "1_test.go")
|
||||
l.Register(nil, nil, "2_test.go")
|
||||
l.Register(nil, nil, "3_test.go")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(r.migrationsList.Items()) != len(l.Items()) {
|
||||
t.Fatalf("Expected the same migrations list to be assigned, got \n%#v", r.migrationsList)
|
||||
}
|
||||
|
||||
expectedQueries := []string{
|
||||
"CREATE TABLE IF NOT EXISTS `_migrations` (file VARCHAR(255) PRIMARY KEY NOT NULL, applied INTEGER NOT NULL)",
|
||||
}
|
||||
if len(expectedQueries) != len(testDB.CalledQueries) {
|
||||
t.Fatalf("Expected %d queries, got %d: \n%v", len(expectedQueries), len(testDB.CalledQueries), testDB.CalledQueries)
|
||||
}
|
||||
for _, q := range expectedQueries {
|
||||
if !list.ExistInSlice(q, testDB.CalledQueries) {
|
||||
t.Fatalf("Query %s was not found in \n%v", q, testDB.CalledQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunnerUpAndDown(t *testing.T) {
|
||||
testDB, err := createTestDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
var test1UpCalled bool
|
||||
var test1DownCalled bool
|
||||
var test2UpCalled bool
|
||||
var test2DownCalled bool
|
||||
|
||||
l := MigrationsList{}
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
test1UpCalled = true
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
test1DownCalled = true
|
||||
return nil
|
||||
}, "1_test")
|
||||
l.Register(func(db dbx.Builder) error {
|
||||
test2UpCalled = true
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
test2DownCalled = true
|
||||
return nil
|
||||
}, "2_test")
|
||||
|
||||
r, err := NewRunner(testDB.DB, l)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// simulate partially run migration
|
||||
r.saveAppliedMigration(testDB, r.migrationsList.Item(0).file)
|
||||
|
||||
// Up()
|
||||
// ---
|
||||
if _, err := r.Up(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test1UpCalled {
|
||||
t.Fatalf("Didn't expect 1_test to be called")
|
||||
}
|
||||
|
||||
if !test2UpCalled {
|
||||
t.Fatalf("Expected 2_test to be called")
|
||||
}
|
||||
|
||||
// simulate unrun migration
|
||||
var test3DownCalled bool
|
||||
r.migrationsList.Register(nil, func(db dbx.Builder) error {
|
||||
test3DownCalled = true
|
||||
return nil
|
||||
}, "3_test")
|
||||
|
||||
// Down()
|
||||
// ---
|
||||
if _, err := r.Down(2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if test3DownCalled {
|
||||
t.Fatal("Didn't expect 3_test to be reverted.")
|
||||
}
|
||||
|
||||
if !test1DownCalled || !test2DownCalled {
|
||||
t.Fatalf("Expected 1_test and 2_test to be reverted, got %v and %v", test1DownCalled, test2DownCalled)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Helpers
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
type testDB struct {
|
||||
*dbx.DB
|
||||
CalledQueries []string
|
||||
}
|
||||
|
||||
// NB! Don't forget to call `db.Close()` at the end of the test.
|
||||
func createTestDB() (*testDB, error) {
|
||||
sqlDB, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
|
||||
return &db, nil
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package migrate
|
||||
|
||||
const createTemplateContent = `package migrations
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/dbx"
|
||||
m "github.com/pocketbase/pocketbase/migrations"
|
||||
)
|
||||
|
||||
func init() {
|
||||
m.Register(func(db dbx.Builder) error {
|
||||
// add up queries...
|
||||
|
||||
return nil
|
||||
}, func(db dbx.Builder) error {
|
||||
// add down queries...
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
Reference in New Issue
Block a user