From 7ac3a744401006a8bbea2719721871e181e5f302 Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Sun, 27 Nov 2022 23:00:58 +0200 Subject: [PATCH] refactored automigrate to be more granular --- plugins/jsvm/vm.go | 17 + plugins/migratecmd/automigrate.go | 77 ++-- plugins/migratecmd/migratecmd.go | 15 +- plugins/migratecmd/templates.go | 583 +++++++++++++++++++++++++++++- 4 files changed, 649 insertions(+), 43 deletions(-) diff --git a/plugins/jsvm/vm.go b/plugins/jsvm/vm.go index 2d8df2ac..b2e994b4 100644 --- a/plugins/jsvm/vm.go +++ b/plugins/jsvm/vm.go @@ -9,6 +9,7 @@ import ( "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/models" + "github.com/pocketbase/pocketbase/models/schema" ) func NewBaseVM(app core.App) *goja.Runtime { @@ -33,6 +34,7 @@ func NewBaseVM(app core.App) *goja.Runtime { collectionConstructor(vm) recordConstructor(vm) adminConstructor(vm) + schemaConstructor(vm) daoConstructor(vm) dbxBinds(vm) @@ -66,6 +68,21 @@ func adminConstructor(vm *goja.Runtime) { }) } +func schemaConstructor(vm *goja.Runtime) { + vm.Set("Schema", func(call goja.ConstructorCall) *goja.Object { + instance := &schema.Schema{} + instanceValue := vm.ToValue(instance).(*goja.Object) + instanceValue.SetPrototype(call.This.Prototype()) + return instanceValue + }) + vm.Set("SchemaField", func(call goja.ConstructorCall) *goja.Object { + instance := &schema.SchemaField{} + instanceValue := vm.ToValue(instance).(*goja.Object) + instanceValue.SetPrototype(call.This.Prototype()) + return instanceValue + }) +} + func daoConstructor(vm *goja.Runtime) { vm.Set("Dao", func(call goja.ConstructorCall) *goja.Object { db, ok := call.Argument(0).Export().(dbx.Builder) diff --git a/plugins/migratecmd/automigrate.go b/plugins/migratecmd/automigrate.go index 34dbbe3c..38bb95b4 100644 --- a/plugins/migratecmd/automigrate.go +++ b/plugins/migratecmd/automigrate.go @@ -1,16 +1,14 @@ package migratecmd import ( + "database/sql" "errors" "fmt" "os" - "os/exec" "path/filepath" "sort" - "strings" "time" - "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/core" m "github.com/pocketbase/pocketbase/migrations" "github.com/pocketbase/pocketbase/models" @@ -19,34 +17,34 @@ import ( const migrationsTable = "_migrations" const automigrateSuffix = "_automigrate" +const collectionsCacheKey = "_automigrate_collections" // onCollectionChange handles the automigration snapshot generation on // collection change event (create/update/delete). -func (p *plugin) onCollectionChange() func(*core.ModelEvent) error { +func (p *plugin) afterCollectionChange() func(*core.ModelEvent) error { return func(e *core.ModelEvent) error { if e.Model.TableName() != "_collections" { return nil // not a collection } - collections := []*models.Collection{} - if err := p.app.Dao().CollectionQuery().OrderBy("created ASC").All(&collections); err != nil { - return fmt.Errorf("failed to fetch collections list: %v", err) - } - if len(collections) == 0 { - return errors.New("missing collections to automigrate") + oldCollections, err := p.getCachedCollections() + if err != nil { + return err } - oldFiles, err := p.getAllMigrationNames() - if err != nil { - return fmt.Errorf("failed to fetch migration files list: %v", err) + old, _ := oldCollections[e.Model.GetId()] + + new, err := p.app.Dao().FindCollectionByNameOrId(e.Model.GetId()) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err } var template string var templateErr error if p.options.TemplateLang == TemplateLangJS { - template, templateErr = p.jsSnapshotTemplate(collections) + template, templateErr = p.jsDiffTemplate(new, old) } else { - template, templateErr = p.goSnapshotTemplate(collections) + template, templateErr = p.goDiffTemplate(new, old) } if templateErr != nil { return fmt.Errorf("failed to resolve template: %v", templateErr) @@ -64,31 +62,40 @@ func (p *plugin) onCollectionChange() func(*core.ModelEvent) error { return fmt.Errorf("failed to save automigrate file: %v", err) } - // remove the old untracked automigrate file - // (only if the last one was automigrate!) - if len(oldFiles) > 0 && strings.HasSuffix(oldFiles[len(oldFiles)-1], automigrateSuffix+"."+p.options.TemplateLang) { - olfName := oldFiles[len(oldFiles)-1] - oldPath := filepath.Join(p.options.Dir, olfName) - - isUntracked := exec.Command(p.options.GitPath, "ls-files", "--error-unmatch", oldPath).Run() != nil - if isUntracked { - // delete the old automigrate from the db if it was already applied - _, err := p.app.Dao().DB().Delete(migrationsTable, dbx.HashExp{"file": olfName}).Execute() - if err != nil { - return fmt.Errorf("failed to delete last applied automigrate from the migration db: %v", err) - } - - // delete the old automigrate file from the filesystem - if err := os.Remove(oldPath); err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to delete last automigrates from the filesystem: %v", err) - } - } - } + p.refreshCachedCollections() return nil } } +func (p *plugin) refreshCachedCollections() error { + var collections []*models.Collection + if err := p.app.Dao().CollectionQuery().All(&collections); err != nil { + return err + } + + mapped := map[string]*models.Collection{} + for _, c := range collections { + mapped[c.Id] = c + } + + p.app.Cache().Set(collectionsCacheKey, mapped) + + return nil +} + +func (p *plugin) getCachedCollections() (map[string]*models.Collection, error) { + if !p.app.Cache().Has(collectionsCacheKey) { + if err := p.refreshCachedCollections(); err != nil { + return nil, err + } + } + + result, _ := p.app.Cache().Get(collectionsCacheKey).(map[string]*models.Collection) + + return result, nil +} + // getAllMigrationNames return sorted slice with both applied and new // local migration file names. func (p *plugin) getAllMigrationNames() ([]string, error) { diff --git a/plugins/migratecmd/migratecmd.go b/plugins/migratecmd/migratecmd.go index 8a479e18..7ae84eea 100644 --- a/plugins/migratecmd/migratecmd.go +++ b/plugins/migratecmd/migratecmd.go @@ -81,12 +81,17 @@ func Register(app core.App, rootCmd *cobra.Command, options *Options) error { // watch for collection changes if p.options.Automigrate { + p.app.OnAfterBootstrap().Add(func(e *core.BootstrapEvent) error { + p.refreshCachedCollections() + return nil + }) + if _, err := exec.LookPath(p.options.GitPath); err != nil { color.Yellow("WARNING: Automigrate cannot be enabled because %s is not installed or accessible.", p.options.GitPath) } else { - p.app.OnModelAfterCreate().Add(p.onCollectionChange()) - p.app.OnModelAfterUpdate().Add(p.onCollectionChange()) - p.app.OnModelAfterDelete().Add(p.onCollectionChange()) + p.app.OnModelAfterCreate().Add(p.afterCollectionChange()) + p.app.OnModelAfterUpdate().Add(p.afterCollectionChange()) + p.app.OnModelAfterDelete().Add(p.afterCollectionChange()) } } @@ -171,9 +176,9 @@ func (p *plugin) migrateCreateHandler(template string, args []string) error { if template == "" { var templateErr error if p.options.TemplateLang == TemplateLangJS { - template, templateErr = p.jsCreateTemplate() + template, templateErr = p.jsBlankTemplate() } else { - template, templateErr = p.goCreateTemplate() + template, templateErr = p.goBlankTemplate() } if templateErr != nil { return fmt.Errorf("Failed to resolve create template: %v\n", templateErr) diff --git a/plugins/migratecmd/templates.go b/plugins/migratecmd/templates.go index 5e6cfea8..61e31506 100644 --- a/plugins/migratecmd/templates.go +++ b/plugins/migratecmd/templates.go @@ -1,9 +1,13 @@ package migratecmd import ( + "bytes" "encoding/json" + "errors" "fmt" "path/filepath" + "strconv" + "strings" "github.com/pocketbase/pocketbase/models" ) @@ -17,7 +21,7 @@ const ( // JavaScript templates // ------------------------------------------------------------------- -func (p *plugin) jsCreateTemplate() (string, error) { +func (p *plugin) jsBlankTemplate() (string, error) { const template = `migrate((db) => { // add up queries... }, (db) => { @@ -48,11 +52,230 @@ func (p *plugin) jsSnapshotTemplate(collections []*models.Collection) (string, e return fmt.Sprintf(template, string(jsonData)), nil } +func (p *plugin) jsCreateTemplate(collection *models.Collection) (string, error) { + jsonData, err := json.MarshalIndent(collection, " ", " ") + if err != nil { + return "", fmt.Errorf("failed to serialize collections list: %v", err) + } + + const template = `migrate((db) => { + const collection = unmarshal(%s, new Collection()); + + return Dao(db).saveCollection(collection); +}, (db) => { + const dao = new Dao(db); + const collection = dao.findCollectionByNameOrId(%q); + + return dao.deleteCollection(collection); +}) +` + + return fmt.Sprintf(template, string(jsonData), collection.Id), nil +} + +func (p *plugin) jsDeleteTemplate(collection *models.Collection) (string, error) { + jsonData, err := json.MarshalIndent(collection, " ", " ") + if err != nil { + return "", fmt.Errorf("failed to serialize collections list: %v", err) + } + + const template = `migrate((db) => { + const dao = new Dao(db); + const collection = dao.findCollectionByNameOrId(%q); + + return dao.deleteCollection(collection); +}, (db) => { + const collection = unmarshal(%s, new Collection()); + + return Dao(db).saveCollection(collection); +}) +` + + return fmt.Sprintf(template, collection.Id, string(jsonData)), nil +} + +func (p *plugin) jsDiffTemplate(new *models.Collection, old *models.Collection) (string, error) { + if new == nil && old == nil { + return "", errors.New("the diff template require at least one of the collection to be non-nil") + } + + if new == nil { + return p.jsDeleteTemplate(old) + } + + if old == nil { + return p.jsCreateTemplate(new) + } + + upParts := []string{} + downParts := []string{} + varName := "collection" + + if old.Name != new.Name { + upParts = append(upParts, fmt.Sprintf("%s.name = %q", varName, new.Name)) + downParts = append(downParts, fmt.Sprintf("%s.name = %q", varName, old.Name)) + } + + if old.Type != new.Type { + upParts = append(upParts, fmt.Sprintf("%s.type = %q", varName, new.Type)) + downParts = append(downParts, fmt.Sprintf("%s.type = %q", varName, old.Type)) + } + + if old.System != new.System { + upParts = append(upParts, fmt.Sprintf("%s.system = %t", varName, new.System)) + downParts = append(downParts, fmt.Sprintf("%s.system = %t", varName, old.System)) + } + + // --- + // note: strconv.Quote is used because %q converts the rule operators in unicode char codes + // --- + + if old.ListRule != new.ListRule { + if old.ListRule != nil && new.ListRule == nil { + upParts = append(upParts, fmt.Sprintf("%s.listRule = null", varName)) + downParts = append(downParts, fmt.Sprintf("%s.listRule = %s", varName, strconv.Quote(*old.ListRule))) + } else if old.ListRule == nil && new.ListRule != nil || *old.ListRule != *new.ListRule { + upParts = append(upParts, fmt.Sprintf("%s.listRule = %s", varName, strconv.Quote(*new.ListRule))) + downParts = append(downParts, fmt.Sprintf("%s.listRule = null", varName)) + } + } + + if old.ViewRule != new.ViewRule { + if old.ViewRule != nil && new.ViewRule == nil { + upParts = append(upParts, fmt.Sprintf("%s.viewRule = null", varName)) + downParts = append(downParts, fmt.Sprintf("%s.viewRule = %s", varName, strconv.Quote(*old.ViewRule))) + } else if old.ViewRule == nil && new.ViewRule != nil || *old.ViewRule != *new.ViewRule { + upParts = append(upParts, fmt.Sprintf("%s.viewRule = %s", varName, strconv.Quote(*new.ViewRule))) + downParts = append(downParts, fmt.Sprintf("%s.viewRule = null", varName)) + } + } + + if old.CreateRule != new.CreateRule { + if old.CreateRule != nil && new.CreateRule == nil { + upParts = append(upParts, fmt.Sprintf("%s.createRule = null", varName)) + downParts = append(downParts, fmt.Sprintf("%s.createRule = %s", varName, strconv.Quote(*old.CreateRule))) + } else if old.CreateRule == nil && new.CreateRule != nil || *old.CreateRule != *new.CreateRule { + upParts = append(upParts, fmt.Sprintf("%s.createRule = %s", varName, strconv.Quote(*new.CreateRule))) + downParts = append(downParts, fmt.Sprintf("%s.createRule = null", varName)) + } + } + + if old.UpdateRule != new.UpdateRule { + if old.UpdateRule != nil && new.UpdateRule == nil { + upParts = append(upParts, fmt.Sprintf("%s.updateRule = null", varName)) + downParts = append(downParts, fmt.Sprintf("%s.updateRule = %s", varName, strconv.Quote(*old.UpdateRule))) + } else if old.UpdateRule == nil && new.UpdateRule != nil || *old.UpdateRule != *new.UpdateRule { + upParts = append(upParts, fmt.Sprintf("%s.updateRule = %s", varName, strconv.Quote(*new.UpdateRule))) + downParts = append(downParts, fmt.Sprintf("%s.updateRule = null", varName)) + } + } + + if old.DeleteRule != new.DeleteRule { + if old.DeleteRule != nil && new.DeleteRule == nil { + upParts = append(upParts, fmt.Sprintf("%s.deleteRule = null", varName)) + downParts = append(downParts, fmt.Sprintf("%s.deleteRule = %s", varName, strconv.Quote(*old.DeleteRule))) + } else if old.DeleteRule == nil && new.DeleteRule != nil || *old.DeleteRule != *new.DeleteRule { + upParts = append(upParts, fmt.Sprintf("%s.deleteRule = %s", varName, strconv.Quote(*new.DeleteRule))) + downParts = append(downParts, fmt.Sprintf("%s.deleteRule = null", varName)) + } + } + + // Options + rawNewOptions, err := json.MarshalIndent(new.Options, " ", " ") + if err != nil { + return "", err + } + rawOldOptions, err := json.MarshalIndent(old.Options, " ", " ") + if err != nil { + return "", err + } + if !bytes.Equal(rawNewOptions, rawOldOptions) { + upParts = append(upParts, fmt.Sprintf("%s.options = %s", varName, rawNewOptions)) + downParts = append(downParts, fmt.Sprintf("%s.options = %s", varName, rawOldOptions)) + } + + // Schema + // --- + // deleted fields + for _, oldField := range old.Schema.Fields() { + if new.Schema.GetFieldById(oldField.Id) != nil { + continue // exist + } + rawOldField, err := json.MarshalIndent(oldField, " ", " ") + if err != nil { + return "", err + } + upParts = append(upParts, fmt.Sprintf("%s.schema.removeField(%q)", varName, oldField.Id)) + downParts = append(downParts, fmt.Sprintf("%s.schema.addField(unmarshal(%s, new SchemaField()))", varName, rawOldField)) + } + // created fields + for _, newField := range new.Schema.Fields() { + if old.Schema.GetFieldById(newField.Id) != nil { + continue // exist + } + rawNewField, err := json.MarshalIndent(newField, " ", " ") + if err != nil { + return "", err + } + upParts = append(upParts, fmt.Sprintf("%s.schema.addField(unmarshal(%s, new SchemaField()))", varName, rawNewField)) + downParts = append(downParts, fmt.Sprintf("%s.schema.removeField(%q)", varName, newField.Id)) + } + // modified fields + for _, newField := range new.Schema.Fields() { + oldField := old.Schema.GetFieldById(newField.Id) + if oldField == nil { + continue + } + + rawNewField, err := json.MarshalIndent(newField, " ", " ") + if err != nil { + return "", err + } + + rawOldField, err := json.MarshalIndent(oldField, " ", " ") + if err != nil { + return "", err + } + + if bytes.Equal(rawNewField, rawOldField) { + continue // no change + } + + upParts = append(upParts, "// upsert") + upParts = append(upParts, fmt.Sprintf("%s.schema.addField(unmarshal(%s, new SchemaField()))", varName, rawNewField)) + downParts = append(downParts, "// upsert") + downParts = append(downParts, fmt.Sprintf("%s.schema.addField(unmarshal(%s, new SchemaField()))", varName, rawOldField)) + } + // --- + + up := strings.Join(upParts, "\n ") + down := strings.Join(downParts, "\n ") + + const template = `migrate((db) => { + const dao = new Dao(db) + const collection = dao.findCollectionByNameOrId(%q) + + %s; + + return dao.saveCollection(collection) +}, (db) => { + const dao = new Dao(db) + const collection = dao.findCollectionByNameOrId(%q) + + %s; + + return dao.saveCollection(collection) +}) +` + + return fmt.Sprintf(template, old.Id, up, new.Id, down), nil +} + // ------------------------------------------------------------------- // Go templates // ------------------------------------------------------------------- -func (p *plugin) goCreateTemplate() (string, error) { +func (p *plugin) goBlankTemplate() (string, error) { const template = `package %s import ( @@ -89,8 +312,8 @@ import ( "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/daos" - m "github.com/pocketbase/pocketbase/migrations" "github.com/pocketbase/pocketbase/models" + m "github.com/pocketbase/pocketbase/migrations" ) func init() { @@ -110,3 +333,357 @@ func init() { ` return fmt.Sprintf(template, filepath.Base(p.options.Dir), string(jsonData)), nil } + +func (p *plugin) goCreateTemplate(collection *models.Collection) (string, error) { + jsonData, err := json.MarshalIndent(collection, "\t", "\t\t") + if err != nil { + return "", fmt.Errorf("failed to serialize collections list: %v", err) + } + + const template = `package %s + +import ( + "encoding/json" + + "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase/daos" + "github.com/pocketbase/pocketbase/models" + m "github.com/pocketbase/pocketbase/migrations" +) + +func init() { + m.Register(func(db dbx.Builder) error { + jsonData := ` + "`%s`" + ` + + collection := *models.Collection{} + if err := json.Unmarshal([]byte(jsonData), &collection); err != nil { + return err + } + + return daos.New(db).SaveCollection(collection) + }, func(db dbx.Builder) error { + dao := daos.New(db); + + collection, err := dao.FindCollectionByNameOrId(%q) + if err != nil { + return err + } + + return dao.DeleteCollection(collection) + }) +} +` + + return fmt.Sprintf( + template, + filepath.Base(p.options.Dir), + string(jsonData), + collection.Id, + ), nil +} + +func (p *plugin) goDeleteTemplate(collection *models.Collection) (string, error) { + jsonData, err := json.MarshalIndent(collection, "\t", "\t\t") + if err != nil { + return "", fmt.Errorf("failed to serialize collections list: %v", err) + } + + const template = `package %s + +import ( + "encoding/json" + + "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase/daos" + "github.com/pocketbase/pocketbase/models" + m "github.com/pocketbase/pocketbase/migrations" +) + +func init() { + m.Register(func(db dbx.Builder) error { + dao := daos.New(db); + + collection, err := dao.FindCollectionByNameOrId(%q) + if err != nil { + return err + } + + return dao.DeleteCollection(collection) + }, func(db dbx.Builder) error { + jsonData := ` + "`%s`" + ` + + collection := *models.Collection{} + if err := json.Unmarshal([]byte(jsonData), &collection); err != nil { + return err + } + + return daos.New(db).SaveCollection(collection) + }) +} +` + + return fmt.Sprintf( + template, + filepath.Base(p.options.Dir), + collection.Id, + string(jsonData), + ), nil +} + +func (p *plugin) goDiffTemplate(new *models.Collection, old *models.Collection) (string, error) { + if new == nil && old == nil { + return "", errors.New("the diff template require at least one of the collection to be non-nil") + } + + if new == nil { + return p.goDeleteTemplate(old) + } + + if old == nil { + return p.goCreateTemplate(new) + } + + upParts := []string{} + downParts := []string{} + varName := "collection" + var importSchema bool + var importTypes bool + + if old.Name != new.Name { + upParts = append(upParts, fmt.Sprintf("%s.Name = %q\n", varName, new.Name)) + downParts = append(downParts, fmt.Sprintf("%s.Name = %q\n", varName, old.Name)) + } + + if old.Type != new.Type { + upParts = append(upParts, fmt.Sprintf("%s.Type = %q\n", varName, new.Type)) + downParts = append(downParts, fmt.Sprintf("%s.Type = %q\n", varName, old.Type)) + } + + if old.System != new.System { + upParts = append(upParts, fmt.Sprintf("%s.System = %t\n", varName, new.System)) + downParts = append(downParts, fmt.Sprintf("%s.System = %t\n", varName, old.System)) + } + + // --- + // note: strconv.Quote is used because %q converts the rule operators in unicode char codes + // --- + + if old.ListRule != new.ListRule { + if old.ListRule != nil && new.ListRule == nil { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.ListRule = nil\n", varName)) + downParts = append(downParts, fmt.Sprintf("%s.ListRule = types.Pointer(%s)\n", varName, strconv.Quote(*old.ListRule))) + } else if old.ListRule == nil && new.ListRule != nil || *old.ListRule != *new.ListRule { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.ListRule = types.Pointer(%s)\n", varName, strconv.Quote(*new.ListRule))) + downParts = append(downParts, fmt.Sprintf("%s.ListRule = nil\n", varName)) + } + } + + if old.ViewRule != new.ViewRule { + if old.ViewRule != nil && new.ViewRule == nil { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.ViewRule = nil\n", varName)) + downParts = append(downParts, fmt.Sprintf("%s.ViewRule = types.Pointer(%s)\n", varName, strconv.Quote(*old.ViewRule))) + } else if old.ViewRule == nil && new.ViewRule != nil || *old.ViewRule != *new.ViewRule { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.ViewRule = types.Pointer(%s)\n", varName, strconv.Quote(*new.ViewRule))) + downParts = append(downParts, fmt.Sprintf("%s.ViewRule = nil\n", varName)) + } + } + + if old.CreateRule != new.CreateRule { + if old.CreateRule != nil && new.CreateRule == nil { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.CreateRule = nil\n", varName)) + downParts = append(downParts, fmt.Sprintf("%s.CreateRule = types.Pointer(%s)\n", varName, strconv.Quote(*old.CreateRule))) + } else if old.CreateRule == nil && new.CreateRule != nil || *old.CreateRule != *new.CreateRule { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.CreateRule = types.Pointer(%s)\n", varName, strconv.Quote(*new.CreateRule))) + downParts = append(downParts, fmt.Sprintf("%s.CreateRule = nil\n", varName)) + } + } + + if old.UpdateRule != new.UpdateRule { + if old.UpdateRule != nil && new.UpdateRule == nil { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.UpdateRule = nil\n", varName)) + downParts = append(downParts, fmt.Sprintf("%s.UpdateRule = types.Pointer(%s)\n", varName, strconv.Quote(*old.UpdateRule))) + } else if old.UpdateRule == nil && new.UpdateRule != nil || *old.UpdateRule != *new.UpdateRule { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.UpdateRule = types.Pointer(%s)\n", varName, strconv.Quote(*new.UpdateRule))) + downParts = append(downParts, fmt.Sprintf("%s.UpdateRule = nil\n", varName)) + } + } + + if old.DeleteRule != new.DeleteRule { + if old.DeleteRule != nil && new.DeleteRule == nil { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.DeleteRule = nil\n", varName)) + downParts = append(downParts, fmt.Sprintf("%s.DeleteRule = types.Pointer(%s)\n", varName, strconv.Quote(*old.DeleteRule))) + } else if old.DeleteRule == nil && new.DeleteRule != nil || *old.DeleteRule != *new.DeleteRule { + importTypes = true + upParts = append(upParts, fmt.Sprintf("%s.DeleteRule = types.Pointer(%s)\n", varName, strconv.Quote(*new.DeleteRule))) + downParts = append(downParts, fmt.Sprintf("%s.DeleteRule = nil\n", varName)) + } + } + + // Options + rawNewOptions, err := json.MarshalIndent(new.Options, "\t\t", "\t") + if err != nil { + return "", err + } + rawOldOptions, err := json.MarshalIndent(old.Options, "\t\t", "\t") + if err != nil { + return "", err + } + if !bytes.Equal(rawNewOptions, rawOldOptions) { + upParts = append(upParts, "options := map[string]any{}") + upParts = append(upParts, fmt.Sprintf("json.Unmarshal([]byte(`%s`), &options)", rawNewOptions)) + upParts = append(upParts, fmt.Sprintf("%s.SetOptions(options)\n", varName)) + // --- + downParts = append(downParts, "options := map[string]any{}") + downParts = append(downParts, fmt.Sprintf("json.Unmarshal([]byte(`%s`), &options)", rawOldOptions)) + downParts = append(downParts, fmt.Sprintf("%s.SetOptions(options)\n", varName)) + } + + // Schema + // --------------------------------------------------------------- + // deleted fields + for _, oldField := range old.Schema.Fields() { + if new.Schema.GetFieldById(oldField.Id) != nil { + continue // exist + } + + rawOldField, err := json.MarshalIndent(oldField, "\t\t", "\t") + if err != nil { + return "", err + } + + importSchema = true + fieldVar := fmt.Sprintf("del_%s", oldField.Name) + + upParts = append(upParts, "// remove") + upParts = append(upParts, fmt.Sprintf("%s.Schema.RemoveField(%q)\n", varName, oldField.Id)) + + downParts = append(downParts, "// add") + downParts = append(downParts, fmt.Sprintf("%s := &schema.SchemaField{}", fieldVar)) + downParts = append(downParts, fmt.Sprintf("json.Unmarshal([]byte(`%s`), %s)", rawOldField, fieldVar)) + downParts = append(downParts, fmt.Sprintf("%s.Schema.AddField(%s)\n", varName, fieldVar)) + } + + // created fields + for _, newField := range new.Schema.Fields() { + if old.Schema.GetFieldById(newField.Id) != nil { + continue // exist + } + + rawNewField, err := json.MarshalIndent(newField, "\t\t", "\t") + if err != nil { + return "", err + } + + importSchema = true + fieldVar := fmt.Sprintf("new_%s", newField.Name) + + upParts = append(upParts, "// add") + upParts = append(upParts, fmt.Sprintf("%s := &schema.SchemaField{}", fieldVar)) + upParts = append(upParts, fmt.Sprintf("json.Unmarshal([]byte(`%s`), %s)", rawNewField, fieldVar)) + upParts = append(upParts, fmt.Sprintf("%s.Schema.AddField(%s)\n", varName, fieldVar)) + + downParts = append(downParts, "// remove") + downParts = append(downParts, fmt.Sprintf("%s.Schema.RemoveField(%q)\n", varName, newField.Id)) + } + + // modified fields + for _, newField := range new.Schema.Fields() { + oldField := old.Schema.GetFieldById(newField.Id) + if oldField == nil { + continue + } + + rawNewField, err := json.MarshalIndent(newField, "\t\t", "\t") + if err != nil { + return "", err + } + + rawOldField, err := json.MarshalIndent(oldField, "\t\t", "\t") + if err != nil { + return "", err + } + + if bytes.Equal(rawNewField, rawOldField) { + continue // no change + } + + importSchema = true + fieldVar := fmt.Sprintf("edit_%s", newField.Name) + + upParts = append(upParts, "// upsert") + upParts = append(upParts, fmt.Sprintf("%s := &schema.SchemaField{}", fieldVar)) + upParts = append(upParts, fmt.Sprintf("json.Unmarshal([]byte(`%s`), %s)", rawNewField, fieldVar)) + upParts = append(upParts, fmt.Sprintf("%s.Schema.AddField(%s)\n", varName, fieldVar)) + + downParts = append(downParts, "// upsert") + downParts = append(downParts, fmt.Sprintf("%s := &schema.SchemaField{}", fieldVar)) + downParts = append(downParts, fmt.Sprintf("json.Unmarshal([]byte(`%s`), %s)", rawOldField, fieldVar)) + downParts = append(downParts, fmt.Sprintf("%s.Schema.AddField(%s)\n", varName, fieldVar)) + } + // --------------------------------------------------------------- + + up := strings.Join(upParts, "\n\t\t") + down := strings.Join(downParts, "\n\t\t") + + const template = `package %s + +import ( + "encoding/json" + + "github.com/pocketbase/dbx" + "github.com/pocketbase/pocketbase/daos" + m "github.com/pocketbase/pocketbase/migrations"%s +) + +func init() { + m.Register(func(db dbx.Builder) error { + dao := daos.New(db); + + collection, err := dao.FindCollectionByNameOrId(%q) + if err != nil { + return err + } + + %s + + return dao.SaveCollection(collection) + }, func(db dbx.Builder) error { + dao := daos.New(db); + + collection, err := dao.FindCollectionByNameOrId(%q) + if err != nil { + return err + } + + %s + + return dao.SaveCollection(collection) + }) +} +` + + var optImports string + if importSchema { + optImports += "\n\t\"github.com/pocketbase/pocketbase/models/schema\"" + } + if importTypes { + optImports += "\n\t\"github.com/pocketbase/pocketbase/tools/types\"" + } + + return fmt.Sprintf( + template, + filepath.Base(p.options.Dir), + optImports, + old.Id, strings.TrimSpace(up), + new.Id, strings.TrimSpace(down), + ), nil +}