restructered some of the internals and added basic js app hooks support

This commit is contained in:
Gani Georgiev
2023-06-08 17:59:08 +03:00
parent ff5508cb79
commit 3cf3e04866
24 changed files with 1218 additions and 422 deletions
+29 -30
View File
@@ -1,5 +1,9 @@
// Package ghupdate implements a new command to selfupdate the current
// PocketBase executable with the latest GitHub release.
//
// Example usage:
//
// ghupdate.MustRegister(app, app.RootCmd, ghupdate.Config{})
package ghupdate
import (
@@ -27,10 +31,10 @@ type HttpClient interface {
Do(req *http.Request) (*http.Response, error)
}
// Options defines optional struct to customize the default plugin behavior.
// Config defines the config options of the ghupdate plugin.
//
// NB! This plugin is considered experimental and its options may change in the future.
type Options struct {
// NB! This plugin is considered experimental and its config options may change in the future.
type Config struct {
// Owner specifies the account owner of the repository (default to "pocketbase").
Owner string
@@ -51,43 +55,38 @@ type Options struct {
// MustRegister registers the ghupdate plugin to the provided app instance
// and panic if it fails.
func MustRegister(app core.App, rootCmd *cobra.Command, options *Options) {
if err := Register(app, rootCmd, options); err != nil {
func MustRegister(app core.App, rootCmd *cobra.Command, config Config) {
if err := Register(app, rootCmd, config); err != nil {
panic(err)
}
}
// Register registers the ghupdate plugin to the provided app instance.
func Register(app core.App, rootCmd *cobra.Command, options *Options) error {
func Register(app core.App, rootCmd *cobra.Command, config Config) error {
p := &plugin{
app: app,
currentVersion: rootCmd.Version,
config: config,
}
if options != nil {
p.options = options
} else {
p.options = &Options{}
if p.config.Owner == "" {
p.config.Owner = "pocketbase"
}
if p.options.Owner == "" {
p.options.Owner = "pocketbase"
if p.config.Repo == "" {
p.config.Repo = "pocketbase"
}
if p.options.Repo == "" {
p.options.Repo = "pocketbase"
if p.config.ArchiveExecutable == "" {
p.config.ArchiveExecutable = "pocketbase"
}
if p.options.ArchiveExecutable == "" {
p.options.ArchiveExecutable = "pocketbase"
if p.config.HttpClient == nil {
p.config.HttpClient = http.DefaultClient
}
if p.options.HttpClient == nil {
p.options.HttpClient = http.DefaultClient
}
if p.options.Context == nil {
p.options.Context = context.Background()
if p.config.Context == nil {
p.config.Context = context.Background()
}
rootCmd.AddCommand(p.updateCmd())
@@ -98,7 +97,7 @@ func Register(app core.App, rootCmd *cobra.Command, options *Options) error {
type plugin struct {
app core.App
currentVersion string
options *Options
config Config
}
func (p *plugin) updateCmd() *cobra.Command {
@@ -130,10 +129,10 @@ func (p *plugin) update(withBackup bool) error {
color.Yellow("Fetching release information...")
latest, err := fetchLatestRelease(
p.options.Context,
p.options.HttpClient,
p.options.Owner,
p.options.Repo,
p.config.Context,
p.config.HttpClient,
p.config.Owner,
p.config.Repo,
)
if err != nil {
return err
@@ -161,7 +160,7 @@ func (p *plugin) update(withBackup bool) error {
// download the release asset
assetZip := filepath.Join(releaseDir, asset.Name)
if err := downloadFile(p.options.Context, p.options.HttpClient, asset.DownloadUrl, assetZip); err != nil {
if err := downloadFile(p.config.Context, p.config.HttpClient, asset.DownloadUrl, assetZip); err != nil {
return err
}
@@ -183,7 +182,7 @@ func (p *plugin) update(withBackup bool) error {
renamedOldExec := oldExec + ".old"
defer os.Remove(renamedOldExec)
newExec := filepath.Join(extractDir, p.options.ArchiveExecutable)
newExec := filepath.Join(extractDir, p.config.ArchiveExecutable)
if _, err := os.Stat(newExec); err != nil {
// try again with an .exe extension
newExec = newExec + ".exe"
@@ -213,7 +212,7 @@ func (p *plugin) update(withBackup bool) error {
color.Yellow("Creating pb_data backup...")
backupName := fmt.Sprintf("@update_%s.zip", latest.Tag)
if err := p.app.CreateBackup(p.options.Context, backupName); err != nil {
if err := p.app.CreateBackup(p.config.Context, backupName); err != nil {
tryToRevertExecChanges()
return err
}
+165
View File
@@ -0,0 +1,165 @@
package jsvm
import (
"path/filepath"
"runtime"
"time"
"github.com/dop251/goja"
"github.com/dop251/goja_nodejs/console"
"github.com/dop251/goja_nodejs/eventloop"
"github.com/dop251/goja_nodejs/process"
"github.com/dop251/goja_nodejs/require"
"github.com/fatih/color"
"github.com/fsnotify/fsnotify"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/core"
)
// HooksConfig defines the config options of the JS app hooks plugin.
type HooksConfig struct {
// Dir specifies the directory with the JS app hooks.
//
// If not set it fallbacks to a relative "pb_data/../pb_hooks" directory.
Dir string
// Watch enables auto app restarts when a JS app hook file changes.
//
// Note that currently the application cannot be automatically restarted on Windows
// because the restart process relies on execve.
Watch bool
}
// MustRegisterHooks registers the JS hooks plugin to
// the provided app instance and panics if it fails.
//
// Example usage:
//
// jsvm.MustRegisterHooks(app, jsvm.HooksConfig{})
func MustRegisterHooks(app core.App, config HooksConfig) {
if err := RegisterHooks(app, config); err != nil {
panic(err)
}
}
// RegisterHooks registers the JS hooks plugin to the provided app instance.
func RegisterHooks(app core.App, config HooksConfig) error {
p := &hooks{app: app, config: config}
if p.config.Dir == "" {
p.config.Dir = filepath.Join(app.DataDir(), "../pb_hooks")
}
// fetch all js hooks sorted by their filename
files, err := filesContent(p.config.Dir, `^.*\.pb\.js$`)
if err != nil {
return err
}
dbx.HashExp{}.Build(app.DB(), nil)
registry := new(require.Registry) // this can be shared by multiple runtimes
loop := eventloop.NewEventLoop()
loop.Run(func(vm *goja.Runtime) {
registry.Enable(vm)
console.Enable(vm)
process.Enable(vm)
baseBinds(vm)
dbxBinds(vm)
filesystemBinds(vm)
tokensBinds(vm)
securityBinds(vm)
formsBinds(vm)
apisBinds(vm)
vm.Set("$app", app)
for file, content := range files {
_, err := vm.RunString(string(content))
if err != nil {
if p.config.Watch {
color.Red("Failed to execute %s: %v", file, err)
} else {
// return err
}
}
}
})
loop.Start()
app.OnTerminate().Add(func(e *core.TerminateEvent) error {
loop.StopNoWait()
return nil
})
if p.config.Watch {
return p.watchFiles()
}
return nil
}
type hooks struct {
app core.App
config HooksConfig
}
func (h *hooks) watchFiles() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
h.app.OnTerminate().Add(func(e *core.TerminateEvent) error {
watcher.Close()
return nil
})
var debounceTimer *time.Timer
// start listening for events.
go func() {
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(100*time.Millisecond, func() {
// app restart is currently not supported on Windows
if runtime.GOOS == "windows" {
color.Yellow("File %s changed, please restart the app", event.Name)
} else {
color.Yellow("File %s changed, restarting...", event.Name)
if err := h.app.Restart(); err != nil {
color.Red("Failed to restart the app:", err)
}
}
})
case err, ok := <-watcher.Errors:
if !ok {
return
}
color.Red("Watch error:", err)
}
}
}()
// add the directory to watch
err = watcher.Add(h.config.Dir)
if err != nil {
watcher.Close()
return err
}
return nil
}
+48
View File
@@ -0,0 +1,48 @@
package jsvm
import (
"reflect"
"strings"
"unicode"
"github.com/dop251/goja"
)
var (
_ goja.FieldNameMapper = (*FieldMapper)(nil)
)
// FieldMapper provides custom mapping between Go and JavaScript property names.
//
// It is similar to the builtin "uncapFieldNameMapper" but also converts
// all uppercase identifiers to their lowercase equivalent (eg. "GET" -> "get").
type FieldMapper struct {
}
// FieldName implements the [FieldNameMapper.FieldName] interface method.
func (u FieldMapper) FieldName(_ reflect.Type, f reflect.StructField) string {
return convertGoToJSName(f.Name)
}
// MethodName implements the [FieldNameMapper.MethodName] interface method.
func (u FieldMapper) MethodName(_ reflect.Type, m reflect.Method) string {
return convertGoToJSName(m.Name)
}
func convertGoToJSName(name string) string {
allUppercase := true
for _, c := range name {
if c != '_' && !unicode.IsUpper(c) {
allUppercase = false
break
}
}
// eg. "JSON" -> "json"
if allUppercase {
return strings.ToLower(name)
}
// eg. "GetField" -> "getField"
return strings.ToLower(name[0:1]) + name[1:]
}
+40
View File
@@ -0,0 +1,40 @@
package jsvm_test
import (
"reflect"
"testing"
"github.com/pocketbase/pocketbase/plugins/jsvm"
)
func TestFieldMapper(t *testing.T) {
mapper := jsvm.FieldMapper{}
scenarios := []struct {
name string
expected string
}{
{"", ""},
{"test", "test"},
{"Test", "test"},
{"miXeD", "miXeD"},
{"MiXeD", "miXeD"},
{"ResolveRequestAsJSON", "resolveRequestAsJSON"},
{"Variable_with_underscore", "variable_with_underscore"},
{"ALLCAPS", "allcaps"},
{"NOTALLCAPs", "nOTALLCAPs"},
{"ALL_CAPS_WITH_UNDERSCORE", "all_caps_with_underscore"},
}
for i, s := range scenarios {
field := reflect.StructField{Name: s.name}
if v := mapper.FieldName(nil, field); v != s.expected {
t.Fatalf("[%d] Expected FieldName %q, got %q", i, s.expected, v)
}
method := reflect.Method{Name: s.name}
if v := mapper.MethodName(nil, method); v != s.expected {
t.Fatalf("[%d] Expected MethodName %q, got %q", i, s.expected, v)
}
}
}
+22 -59
View File
@@ -2,10 +2,9 @@ package jsvm
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/dop251/goja"
"github.com/dop251/goja_nodejs/console"
"github.com/dop251/goja_nodejs/process"
"github.com/dop251/goja_nodejs/require"
@@ -14,52 +13,36 @@ import (
m "github.com/pocketbase/pocketbase/migrations"
)
// MigrationsOptions defines optional struct to customize the default migrations loader behavior.
type MigrationsOptions struct {
// MigrationsConfig defines the config options of the JS migrations loader plugin.
type MigrationsConfig struct {
// Dir specifies the directory with the JS migrations.
//
// If not set it fallbacks to a relative "pb_data/../pb_migrations" directory.
Dir string
}
// migrations is the migrations loader plugin definition.
// Usually it is instantiated via RegisterMigrations or MustRegisterMigrations.
type migrations struct {
app core.App
options *MigrationsOptions
}
// MustRegisterMigrations registers the migrations loader plugin to
// MustRegisterMigrations registers the JS migrations loader plugin to
// the provided app instance and panics if it fails.
//
// Internally it calls RegisterMigrations(app, options).
// Example usage:
//
// If options is nil, by default the js files from pb_data/migrations are loaded.
// Set custom options.Dir if you want to change it to some other directory.
func MustRegisterMigrations(app core.App, options *MigrationsOptions) {
if err := RegisterMigrations(app, options); err != nil {
// jsvm.MustRegisterMigrations(app, jsvm.MigrationsConfig{})
func MustRegisterMigrations(app core.App, config MigrationsConfig) {
if err := RegisterMigrations(app, config); err != nil {
panic(err)
}
}
// RegisterMigrations registers the plugin to the provided app instance.
//
// If options is nil, by default the js files from pb_data/migrations are loaded.
// Set custom options.Dir if you want to change it to some other directory.
func RegisterMigrations(app core.App, options *MigrationsOptions) error {
l := &migrations{app: app}
// RegisterMigrations registers the JS migrations loader hooks plugin
// to the provided app instance.
func RegisterMigrations(app core.App, config MigrationsConfig) error {
l := &migrations{app: app, config: config}
if options != nil {
l.options = options
} else {
l.options = &MigrationsOptions{}
if l.config.Dir == "" {
l.config.Dir = filepath.Join(app.DataDir(), "../pb_migrations")
}
if l.options.Dir == "" {
l.options.Dir = filepath.Join(app.DataDir(), "../pb_migrations")
}
files, err := readDirFiles(l.options.Dir)
files, err := filesContent(l.config.Dir, `^.*\.js$`)
if err != nil {
return err
}
@@ -67,10 +50,13 @@ func RegisterMigrations(app core.App, options *MigrationsOptions) error {
registry := new(require.Registry) // this can be shared by multiple runtimes
for file, content := range files {
vm := NewBaseVM()
vm := goja.New()
registry.Enable(vm)
console.Enable(vm)
process.Enable(vm)
dbxBinds(vm)
tokensBinds(vm)
securityBinds(vm)
vm.Set("migrate", func(up, down func(db dbx.Builder) error) {
m.AppMigrations.Register(up, down, file)
@@ -85,30 +71,7 @@ func RegisterMigrations(app core.App, options *MigrationsOptions) error {
return nil
}
// readDirFiles returns a map with all directory files and their content.
//
// If directory with dirPath is missing, it returns an empty map and no error.
func readDirFiles(dirPath string) (map[string][]byte, error) {
files, err := os.ReadDir(dirPath)
if err != nil {
if os.IsNotExist(err) {
return map[string][]byte{}, nil
}
return nil, err
}
result := map[string][]byte{}
for _, f := range files {
if f.IsDir() || !strings.HasSuffix(f.Name(), ".js") {
continue // not a .js file
}
raw, err := os.ReadFile(filepath.Join(dirPath, f.Name()))
if err != nil {
return nil, err
}
result[f.Name()] = raw
}
return result, nil
type migrations struct {
app core.App
config MigrationsConfig
}
+208 -69
View File
@@ -5,36 +5,42 @@
//
// 1. JS Migrations loader:
//
// jsvm.MustRegisterMigrations(app, &jsvm.MigrationsOptions{
// Dir: "custom_js_migrations_dir_path", // default to "pb_data/../pb_migrations"
// jsvm.MustRegisterMigrations(app, jsvm.MigrationsConfig{
// Dir: "/custom/js/migrations/dir", // default to "pb_data/../pb_migrations"
// })
//
// 2. JS app hooks:
//
// jsvm.MustRegisterHooks(app, jsvm.HooksConfig{
// Dir: "/custom/js/hooks/dir", // default to "pb_data/../pb_hooks"
// })
package jsvm
import (
"encoding/json"
"os"
"path/filepath"
"reflect"
"strings"
"unicode"
"regexp"
"github.com/dop251/goja"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/tokens"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/mailer"
"github.com/pocketbase/pocketbase/tools/security"
)
func NewBaseVM() *goja.Runtime {
vm := goja.New()
func baseBinds(vm *goja.Runtime) {
vm.SetFieldNameMapper(FieldMapper{})
baseBinds(vm)
dbxBinds(vm)
return vm
}
func baseBinds(vm *goja.Runtime) {
vm.Set("unmarshal", func(src map[string]any, dest any) (any, error) {
raw, err := json.Marshal(src)
if err != nil {
@@ -72,49 +78,57 @@ func baseBinds(vm *goja.Runtime) {
vm.Set("Collection", func(call goja.ConstructorCall) *goja.Object {
instance := &models.Collection{}
return defaultConstructor(vm, call, instance)
return structConstructor(vm, call, instance)
})
vm.Set("Admin", func(call goja.ConstructorCall) *goja.Object {
instance := &models.Admin{}
return defaultConstructor(vm, call, instance)
return structConstructor(vm, call, instance)
})
vm.Set("Schema", func(call goja.ConstructorCall) *goja.Object {
instance := &schema.Schema{}
return defaultConstructor(vm, call, instance)
return structConstructor(vm, call, instance)
})
vm.Set("SchemaField", func(call goja.ConstructorCall) *goja.Object {
instance := &schema.SchemaField{}
return defaultConstructor(vm, call, instance)
return structConstructor(vm, call, instance)
})
vm.Set("Dao", func(call goja.ConstructorCall) *goja.Object {
db, ok := call.Argument(0).Export().(dbx.Builder)
if !ok || db == nil {
panic("missing required Dao(db) argument")
}
vm.Set("Mail", func(call goja.ConstructorCall) *goja.Object {
instance := &mailer.Message{}
return structConstructor(vm, call, instance)
})
instance := daos.New(db)
vm.Set("ValidationError", func(call goja.ConstructorCall) *goja.Object {
code, _ := call.Argument(0).Export().(string)
message, _ := call.Argument(1).Export().(string)
instance := validation.NewError(code, message)
instanceValue := vm.ToValue(instance).(*goja.Object)
instanceValue.SetPrototype(call.This.Prototype())
return instanceValue
})
}
func defaultConstructor(vm *goja.Runtime, call goja.ConstructorCall, instance any) *goja.Object {
if data := call.Argument(0).Export(); data != nil {
if raw, err := json.Marshal(data); err == nil {
json.Unmarshal(raw, instance)
vm.Set("Dao", func(call goja.ConstructorCall) *goja.Object {
concurrentDB, _ := call.Argument(0).Export().(dbx.Builder)
if concurrentDB == nil {
panic("missing required Dao(concurrentDB, [nonconcurrentDB]) argument")
}
}
instanceValue := vm.ToValue(instance).(*goja.Object)
instanceValue.SetPrototype(call.This.Prototype())
nonConcurrentDB, _ := call.Argument(1).Export().(dbx.Builder)
if nonConcurrentDB == nil {
nonConcurrentDB = concurrentDB
}
return instanceValue
instance := daos.NewMultiDB(concurrentDB, nonConcurrentDB)
instanceValue := vm.ToValue(instance).(*goja.Object)
instanceValue.SetPrototype(call.This.Prototype())
return instanceValue
})
}
func dbxBinds(vm *goja.Runtime) {
@@ -123,11 +137,7 @@ func dbxBinds(vm *goja.Runtime) {
obj.Set("exp", dbx.NewExp)
obj.Set("hashExp", func(data map[string]any) dbx.HashExp {
exp := dbx.HashExp{}
for k, v := range data {
exp[k] = v
}
return exp
return dbx.HashExp(data)
})
obj.Set("not", dbx.Not)
obj.Set("and", dbx.And)
@@ -144,8 +154,79 @@ func dbxBinds(vm *goja.Runtime) {
obj.Set("notBetween", dbx.NotBetween)
}
func apisBind(vm *goja.Runtime) {
func tokensBinds(vm *goja.Runtime) {
obj := vm.NewObject()
vm.Set("$tokens", obj)
// admin
obj.Set("adminAuthToken", tokens.NewAdminAuthToken)
obj.Set("adminResetPasswordToken", tokens.NewAdminResetPasswordToken)
obj.Set("adminFileToken", tokens.NewAdminFileToken)
// record
obj.Set("recordAuthToken", tokens.NewRecordAuthToken)
obj.Set("recordVerifyToken", tokens.NewRecordVerifyToken)
obj.Set("recordResetPasswordToken", tokens.NewRecordResetPasswordToken)
obj.Set("recordChangeEmailToken", tokens.NewRecordChangeEmailToken)
obj.Set("recordFileToken", tokens.NewRecordFileToken)
}
func securityBinds(vm *goja.Runtime) {
obj := vm.NewObject()
vm.Set("$security", obj)
// random
obj.Set("randomString", security.RandomString)
obj.Set("randomStringWithAlphabet", security.RandomStringWithAlphabet)
obj.Set("pseudorandomString", security.PseudorandomString)
obj.Set("pseudorandomStringWithAlphabet", security.PseudorandomStringWithAlphabet)
// jwt
obj.Set("parseUnverifiedToken", security.ParseUnverifiedJWT)
obj.Set("parseToken", security.ParseJWT)
obj.Set("createToken", security.NewToken)
}
func filesystemBinds(vm *goja.Runtime) {
obj := vm.NewObject()
vm.Set("$filesystem", obj)
obj.Set("fileFromPath", filesystem.NewFileFromPath)
obj.Set("fileFromBytes", filesystem.NewFileFromBytes)
obj.Set("fileFromMultipart", filesystem.NewFileFromMultipart)
}
func formsBinds(vm *goja.Runtime) {
registerFactoryAsConstructor(vm, "AdminLoginForm", forms.NewAdminLogin)
registerFactoryAsConstructor(vm, "AdminPasswordResetConfirmForm", forms.NewAdminPasswordResetConfirm)
registerFactoryAsConstructor(vm, "AdminPasswordResetRequestForm", forms.NewAdminPasswordResetRequest)
registerFactoryAsConstructor(vm, "AdminUpsertForm", forms.NewAdminUpsert)
registerFactoryAsConstructor(vm, "AppleClientSecretCreateForm", forms.NewAppleClientSecretCreate)
registerFactoryAsConstructor(vm, "CollectionUpsertForm", forms.NewCollectionUpsert)
registerFactoryAsConstructor(vm, "CollectionsImportForm", forms.NewCollectionsImport)
registerFactoryAsConstructor(vm, "RealtimeSubscribeForm", forms.NewRealtimeSubscribe)
registerFactoryAsConstructor(vm, "RecordEmailChangeConfirmForm", forms.NewRecordEmailChangeConfirm)
registerFactoryAsConstructor(vm, "RecordEmailChangeRequestForm", forms.NewRecordEmailChangeRequest)
registerFactoryAsConstructor(vm, "RecordOAuth2LoginForm", forms.NewRecordOAuth2Login)
registerFactoryAsConstructor(vm, "RecordPasswordLoginForm", forms.NewRecordPasswordLogin)
registerFactoryAsConstructor(vm, "RecordPasswordResetConfirmForm", forms.NewRecordPasswordResetConfirm)
registerFactoryAsConstructor(vm, "RecordPasswordResetRequestForm", forms.NewRecordPasswordResetRequest)
registerFactoryAsConstructor(vm, "RecordUpsertForm", forms.NewRecordUpsert)
registerFactoryAsConstructor(vm, "RecordVerificationConfirmForm", forms.NewRecordVerificationConfirm)
registerFactoryAsConstructor(vm, "RecordVerificationRequestForm", forms.NewRecordVerificationRequest)
registerFactoryAsConstructor(vm, "SettingsUpsertForm", forms.NewSettingsUpsert)
registerFactoryAsConstructor(vm, "TestEmailSendForm", forms.NewTestEmailSend)
registerFactoryAsConstructor(vm, "TestS3FilesystemForm", forms.NewTestS3Filesystem)
}
func apisBinds(vm *goja.Runtime) {
obj := vm.NewObject()
vm.Set("Route", func(call goja.ConstructorCall) *goja.Object {
instance := echo.Route{}
return structConstructor(vm, call, &instance)
})
vm.Set("$apis", obj)
// middlewares
@@ -158,49 +239,107 @@ func apisBind(vm *goja.Runtime) {
obj.Set("requireAdminOrOwnerAuth", apis.RequireAdminOrOwnerAuth)
obj.Set("activityLogger", apis.ActivityLogger)
// record helpers
obj.Set("requestData", apis.RequestData)
obj.Set("recordAuthResponse", apis.RecordAuthResponse)
obj.Set("enrichRecord", apis.EnrichRecord)
obj.Set("enrichRecords", apis.EnrichRecords)
// api errors
vm.Set("ApiError", func(call goja.ConstructorCall) *goja.Object {
status, _ := call.Argument(0).Export().(int64)
message, _ := call.Argument(1).Export().(string)
data := call.Argument(2).Export()
instance := apis.NewApiError(int(status), message, data)
instanceValue := vm.ToValue(instance).(*goja.Object)
instanceValue.SetPrototype(call.This.Prototype())
return instanceValue
})
obj.Set("notFoundError", apis.NewNotFoundError)
obj.Set("badRequestError", apis.NewBadRequestError)
obj.Set("forbiddenError", apis.NewForbiddenError)
obj.Set("unauthorizedError", apis.NewUnauthorizedError)
// record helpers
obj.Set("requestData", apis.RequestData)
obj.Set("enrichRecord", apis.EnrichRecord)
obj.Set("enrichRecords", apis.EnrichRecords)
}
// FieldMapper provides custom mapping between Go and JavaScript property names.
//
// It is similar to the builtin "uncapFieldNameMapper" but also converts
// all uppercase identifiers to their lowercase equivalent (eg. "GET" -> "get").
type FieldMapper struct {
// -------------------------------------------------------------------
// registerFactoryAsConstructor registers the factory function as native JS constructor.
func registerFactoryAsConstructor(vm *goja.Runtime, constructorName string, factoryFunc any) {
vm.Set(constructorName, func(call goja.ConstructorCall) *goja.Object {
f := reflect.ValueOf(factoryFunc)
args := []reflect.Value{}
for _, v := range call.Arguments {
args = append(args, reflect.ValueOf(v.Export()))
}
result := f.Call(args)
if len(result) != 1 {
panic("the factory function should return only 1 item")
}
value := vm.ToValue(result[0].Interface()).(*goja.Object)
value.SetPrototype(call.This.Prototype())
return value
})
}
// FieldName implements the [FieldNameMapper.FieldName] interface method.
func (u FieldMapper) FieldName(_ reflect.Type, f reflect.StructField) string {
return convertGoToJSName(f.Name)
}
// MethodName implements the [FieldNameMapper.MethodName] interface method.
func (u FieldMapper) MethodName(_ reflect.Type, m reflect.Method) string {
return convertGoToJSName(m.Name)
}
func convertGoToJSName(name string) string {
allUppercase := true
for _, c := range name {
if c != '_' && !unicode.IsUpper(c) {
allUppercase = false
break
// structConstructor wraps the provided struct with a native JS constructor.
func structConstructor(vm *goja.Runtime, call goja.ConstructorCall, instance any) *goja.Object {
if data := call.Argument(0).Export(); data != nil {
if raw, err := json.Marshal(data); err == nil {
json.Unmarshal(raw, instance)
}
}
// eg. "JSON" -> "json"
if allUppercase {
return strings.ToLower(name)
instanceValue := vm.ToValue(instance).(*goja.Object)
instanceValue.SetPrototype(call.This.Prototype())
return instanceValue
}
// filesContent returns a map with all direct files within the specified dir and their content.
//
// If directory with dirPath is missing or no files matching the pattern were found,
// it returns an empty map and no error.
//
// If pattern is empty string it matches all root files.
func filesContent(dirPath string, pattern string) (map[string][]byte, error) {
files, err := os.ReadDir(dirPath)
if err != nil {
if os.IsNotExist(err) {
return map[string][]byte{}, nil
}
return nil, err
}
// eg. "GetField" -> "getField"
return strings.ToLower(name[0:1]) + name[1:]
var exp *regexp.Regexp
if pattern != "" {
var err error
if exp, err = regexp.Compile(pattern); err != nil {
return nil, err
}
}
result := map[string][]byte{}
for _, f := range files {
if f.IsDir() || (exp != nil && !exp.MatchString(f.Name())) {
continue
}
raw, err := os.ReadFile(filepath.Join(dirPath, f.Name()))
if err != nil {
return nil, err
}
result[f.Name()] = raw
}
return result, nil
}
+473 -113
View File
@@ -1,18 +1,35 @@
package jsvm_test
package jsvm
import (
"reflect"
"encoding/json"
"mime/multipart"
"path/filepath"
"testing"
"github.com/dop251/goja"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/plugins/jsvm"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/mailer"
"github.com/pocketbase/pocketbase/tools/security"
)
func TestBaseVMUnmarshal(t *testing.T) {
vm := jsvm.NewBaseVM()
// note: this test is useful as a reminder to update the tests in case
// a new base binding is added.
func TestBaseBindsCount(t *testing.T) {
vm := goja.New()
baseBinds(vm)
testBindsCount(vm, "this", 9, t)
}
func TestBaseBindsUnmarshal(t *testing.T) {
vm := goja.New()
baseBinds(vm)
v, err := vm.RunString(`unmarshal({ name: "test" }, new Collection())`)
if err != nil {
@@ -29,7 +46,7 @@ func TestBaseVMUnmarshal(t *testing.T) {
}
}
func TestBaseVMRecordBind(t *testing.T) {
func TestBaseBindsRecord(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
@@ -38,7 +55,8 @@ func TestBaseVMRecordBind(t *testing.T) {
t.Fatal(err)
}
vm := jsvm.NewBaseVM()
vm := goja.New()
baseBinds(vm)
vm.Set("collection", collection)
// without record data
@@ -74,75 +92,9 @@ func TestBaseVMRecordBind(t *testing.T) {
}
}
// @todo enable after https://github.com/dop251/goja/issues/426
// func TestBaseVMRecordGetAndSetBind(t *testing.T) {
// app, _ := tests.NewTestApp()
// defer app.Cleanup()
// collection, err := app.Dao().FindCollectionByNameOrId("users")
// if err != nil {
// t.Fatal(err)
// }
// vm := jsvm.NewBaseVM()
// vm.Set("collection", collection)
// vm.Set("getRecord", func() *models.Record {
// return models.NewRecord(collection)
// })
// _, runErr := vm.RunString(`
// const jsRecord = new Record(collection);
// jsRecord.email = "test@example.com"; // test js record setter
// const email = jsRecord.email; // test js record getter
// const goRecord = getRecord()
// goRecord.name = "test" // test go record setter
// const name = goRecord.name; // test go record getter
// `)
// if runErr != nil {
// t.Fatal(runErr)
// }
// expectedEmail := "test@example.com"
// expectedName := "test"
// jsRecord, ok := vm.Get("jsRecord").Export().(*models.Record)
// if !ok {
// t.Fatalf("Failed to export jsRecord")
// }
// if v := jsRecord.Email(); v != expectedEmail {
// t.Fatalf("Expected the js created record to have email %q, got %q", expectedEmail, v)
// }
// email := vm.Get("email").Export().(string)
// if email != expectedEmail {
// t.Fatalf("Expected exported email %q, got %q", expectedEmail, email)
// }
// goRecord, ok := vm.Get("goRecord").Export().(*models.Record)
// if !ok {
// t.Fatalf("Failed to export goRecord")
// }
// if v := goRecord.GetString("name"); v != expectedName {
// t.Fatalf("Expected the go created record to have name %q, got %q", expectedName, v)
// }
// name := vm.Get("name").Export().(string)
// if name != expectedName {
// t.Fatalf("Expected exported name %q, got %q", expectedName, name)
// }
// // ensure that the two record instances are not mixed
// if v := goRecord.Email(); v != "" {
// t.Fatalf("Expected the go created record to not have an email, got %q", v)
// }
// if v := jsRecord.GetString("name"); v != "" {
// t.Fatalf("Expected the js created record to not have a name, got %q", v)
// }
// }
func TestBaseVMCollectionBind(t *testing.T) {
vm := jsvm.NewBaseVM()
func TestBaseBindsCollection(t *testing.T) {
vm := goja.New()
baseBinds(vm)
v, err := vm.RunString(`new Collection({ name: "test", schema: [{name: "title", "type": "text"}] })`)
if err != nil {
@@ -164,7 +116,8 @@ func TestBaseVMCollectionBind(t *testing.T) {
}
func TestBaseVMAdminBind(t *testing.T) {
vm := jsvm.NewBaseVM()
vm := goja.New()
baseBinds(vm)
v, err := vm.RunString(`new Admin({ email: "test@example.com" })`)
if err != nil {
@@ -177,8 +130,9 @@ func TestBaseVMAdminBind(t *testing.T) {
}
}
func TestBaseVMSchemaBind(t *testing.T) {
vm := jsvm.NewBaseVM()
func TestBaseBindsSchema(t *testing.T) {
vm := goja.New()
baseBinds(vm)
v, err := vm.RunString(`new Schema([{name: "title", "type": "text"}])`)
if err != nil {
@@ -195,8 +149,9 @@ func TestBaseVMSchemaBind(t *testing.T) {
}
}
func TestBaseVMSchemaFieldBind(t *testing.T) {
vm := jsvm.NewBaseVM()
func TestBaseBindsSchemaField(t *testing.T) {
vm := goja.New()
baseBinds(vm)
v, err := vm.RunString(`new SchemaField({name: "title", "type": "text"})`)
if err != nil {
@@ -213,56 +168,461 @@ func TestBaseVMSchemaFieldBind(t *testing.T) {
}
}
func TestBaseVMDaoBind(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
func TestBaseBindsMail(t *testing.T) {
vm := goja.New()
baseBinds(vm)
vm := jsvm.NewBaseVM()
vm.Set("db", app.DB())
v, err := vm.RunString(`new Dao(db)`)
v, err := vm.RunString(`new Mail({
from: {name: "test_from", address: "test_from@example.com"},
to: [
{name: "test_to1", address: "test_to1@example.com"},
{name: "test_to2", address: "test_to2@example.com"},
],
bcc: [
{name: "test_bcc1", address: "test_bcc1@example.com"},
{name: "test_bcc2", address: "test_bcc2@example.com"},
],
cc: [
{name: "test_cc1", address: "test_cc1@example.com"},
{name: "test_cc2", address: "test_cc2@example.com"},
],
subject: "test_subject",
html: "test_html",
text: "test_text",
headers: {
header1: "a",
header2: "b",
}
})`)
if err != nil {
t.Fatal(err)
}
d, ok := v.Export().(*daos.Dao)
m, ok := v.Export().(*mailer.Message)
if !ok {
t.Fatalf("Expected daos.Dao, got %v", d)
t.Fatalf("Expected mailer.Message, got %v", m)
}
if d.DB() != app.DB() {
t.Fatalf("The db instances doesn't match")
raw, err := json.Marshal(m)
expected := `{"from":{"Name":"test_from","Address":"test_from@example.com"},"to":[{"Name":"test_to1","Address":"test_to1@example.com"},{"Name":"test_to2","Address":"test_to2@example.com"}],"bcc":[{"Name":"test_bcc1","Address":"test_bcc1@example.com"},{"Name":"test_bcc2","Address":"test_bcc2@example.com"}],"cc":[{"Name":"test_cc1","Address":"test_cc1@example.com"},{"Name":"test_cc2","Address":"test_cc2@example.com"}],"subject":"test_subject","html":"test_html","text":"test_text","headers":{"header1":"a","header2":"b"},"attachments":null}`
if string(raw) != expected {
t.Fatalf("Expected \n%s, \ngot \n%s", expected, raw)
}
}
func TestFieldMapper(t *testing.T) {
mapper := jsvm.FieldMapper{}
func TestBaseBindsValidationError(t *testing.T) {
vm := goja.New()
baseBinds(vm)
scenarios := []struct {
name string
expected string
js string
expectCode string
expectMessage string
}{
{"", ""},
{"test", "test"},
{"Test", "test"},
{"miXeD", "miXeD"},
{"MiXeD", "miXeD"},
{"ResolveRequestAsJSON", "resolveRequestAsJSON"},
{"Variable_with_underscore", "variable_with_underscore"},
{"ALLCAPS", "allcaps"},
{"NOTALLCAPs", "nOTALLCAPs"},
{"ALL_CAPS_WITH_UNDERSCORE", "all_caps_with_underscore"},
{
`new ValidationError()`,
"",
"",
},
{
`new ValidationError("test_code")`,
"test_code",
"",
},
{
`new ValidationError("test_code", "test_message")`,
"test_code",
"test_message",
},
}
for i, s := range scenarios {
field := reflect.StructField{Name: s.name}
if v := mapper.FieldName(nil, field); v != s.expected {
t.Fatalf("[%d] Expected FieldName %q, got %q", i, s.expected, v)
for _, s := range scenarios {
v, err := vm.RunString(s.js)
if err != nil {
t.Fatal(err)
}
method := reflect.Method{Name: s.name}
if v := mapper.MethodName(nil, method); v != s.expected {
t.Fatalf("[%d] Expected MethodName %q, got %q", i, s.expected, v)
m, ok := v.Export().(validation.Error)
if !ok {
t.Fatalf("[%s] Expected validation.Error, got %v", s.js, m)
}
if m.Code() != s.expectCode {
t.Fatalf("[%s] Expected code %q, got %q", s.js, s.expectCode, m.Code())
}
if m.Message() != s.expectMessage {
t.Fatalf("[%s] Expected message %q, got %q", s.js, s.expectMessage, m.Message())
}
}
}
func TestBaseBindsDao(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
vm := goja.New()
baseBinds(vm)
vm.Set("db", app.Dao().ConcurrentDB())
vm.Set("db2", app.Dao().NonconcurrentDB())
scenarios := []struct {
js string
concurrentDB dbx.Builder
nonconcurrentDB dbx.Builder
}{
{
js: "new Dao(db)",
concurrentDB: app.Dao().ConcurrentDB(),
nonconcurrentDB: app.Dao().ConcurrentDB(),
},
{
js: "new Dao(db, db2)",
concurrentDB: app.Dao().ConcurrentDB(),
nonconcurrentDB: app.Dao().NonconcurrentDB(),
},
}
for _, s := range scenarios {
v, err := vm.RunString(s.js)
if err != nil {
t.Fatalf("[%s] Failed to execute js script, got %v", s.js, err)
}
d, ok := v.Export().(*daos.Dao)
if !ok {
t.Fatalf("[%s] Expected daos.Dao, got %v", s.js, d)
}
if d.ConcurrentDB() != s.concurrentDB {
t.Fatalf("[%s] The ConcurrentDB instances doesn't match", s.js)
}
if d.NonconcurrentDB() != s.nonconcurrentDB {
t.Fatalf("[%s] The NonconcurrentDB instances doesn't match", s.js)
}
}
}
func TestDbxBinds(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
vm := goja.New()
vm.Set("db", app.Dao().DB())
baseBinds(vm)
dbxBinds(vm)
testBindsCount(vm, "$dbx", 15, t)
sceneraios := []struct {
js string
expected string
}{
{
`$dbx.exp("a = 1").build(db, {})`,
"a = 1",
},
{
`$dbx.hashExp({
"a": 1,
b: null,
c: [1, 2, 3],
}).build(db, {})`,
"`a`={:p0} AND `b` IS NULL AND `c` IN ({:p1}, {:p2}, {:p3})",
},
{
`$dbx.not($dbx.exp("a = 1")).build(db, {})`,
"NOT (a = 1)",
},
{
`$dbx.and($dbx.exp("a = 1"), $dbx.exp("b = 2")).build(db, {})`,
"(a = 1) AND (b = 2)",
},
{
`$dbx.or($dbx.exp("a = 1"), $dbx.exp("b = 2")).build(db, {})`,
"(a = 1) OR (b = 2)",
},
{
`$dbx.in("a", 1, 2, 3).build(db, {})`,
"`a` IN ({:p0}, {:p1}, {:p2})",
},
{
`$dbx.notIn("a", 1, 2, 3).build(db, {})`,
"`a` NOT IN ({:p0}, {:p1}, {:p2})",
},
{
`$dbx.like("a", "test1", "test2").match(true, false).build(db, {})`,
"`a` LIKE {:p0} AND `a` LIKE {:p1}",
},
{
`$dbx.orLike("a", "test1", "test2").match(false, true).build(db, {})`,
"`a` LIKE {:p0} OR `a` LIKE {:p1}",
},
{
`$dbx.notLike("a", "test1", "test2").match(true, false).build(db, {})`,
"`a` NOT LIKE {:p0} AND `a` NOT LIKE {:p1}",
},
{
`$dbx.orNotLike("a", "test1", "test2").match(false, false).build(db, {})`,
"`a` NOT LIKE {:p0} OR `a` NOT LIKE {:p1}",
},
{
`$dbx.exists($dbx.exp("a = 1")).build(db, {})`,
"EXISTS (a = 1)",
},
{
`$dbx.notExists($dbx.exp("a = 1")).build(db, {})`,
"NOT EXISTS (a = 1)",
},
{
`$dbx.between("a", 1, 2).build(db, {})`,
"`a` BETWEEN {:p0} AND {:p1}",
},
{
`$dbx.notBetween("a", 1, 2).build(db, {})`,
"`a` NOT BETWEEN {:p0} AND {:p1}",
},
}
for _, s := range sceneraios {
result, err := vm.RunString(s.js)
if err != nil {
t.Fatalf("[%s] Failed to execute js script, got %v", s.js, err)
}
v, _ := result.Export().(string)
if v != s.expected {
t.Fatalf("[%s] Expected \n%s, \ngot \n%s", s.js, s.expected, v)
}
}
}
func TestTokensBinds(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
admin, err := app.Dao().FindAdminByEmail("test@example.com")
if err != nil {
t.Fatal(err)
}
record, err := app.Dao().FindAuthRecordByEmail("users", "test@example.com")
if err != nil {
t.Fatal(err)
}
vm := goja.New()
vm.Set("$app", app)
vm.Set("admin", admin)
vm.Set("record", record)
baseBinds(vm)
tokensBinds(vm)
testBindsCount(vm, "$tokens", 8, t)
sceneraios := []struct {
js string
key string
}{
{
`$tokens.adminAuthToken($app, admin)`,
admin.TokenKey + app.Settings().AdminAuthToken.Secret,
},
{
`$tokens.adminResetPasswordToken($app, admin)`,
admin.TokenKey + app.Settings().AdminPasswordResetToken.Secret,
},
{
`$tokens.adminFileToken($app, admin)`,
admin.TokenKey + app.Settings().AdminFileToken.Secret,
},
{
`$tokens.recordAuthToken($app, record)`,
record.TokenKey() + app.Settings().RecordAuthToken.Secret,
},
{
`$tokens.recordVerifyToken($app, record)`,
record.TokenKey() + app.Settings().RecordVerificationToken.Secret,
},
{
`$tokens.recordResetPasswordToken($app, record)`,
record.TokenKey() + app.Settings().RecordPasswordResetToken.Secret,
},
{
`$tokens.recordChangeEmailToken($app, record)`,
record.TokenKey() + app.Settings().RecordEmailChangeToken.Secret,
},
{
`$tokens.recordFileToken($app, record)`,
record.TokenKey() + app.Settings().RecordFileToken.Secret,
},
}
for _, s := range sceneraios {
result, err := vm.RunString(s.js)
if err != nil {
t.Fatalf("[%s] Failed to execute js script, got %v", s.js, err)
}
v, _ := result.Export().(string)
if _, err := security.ParseJWT(v, s.key); err != nil {
t.Fatalf("[%s] Failed to parse JWT %v, got %v", s.js, v, err)
}
}
}
func TestSecurityRandomStringBinds(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
vm := goja.New()
baseBinds(vm)
securityBinds(vm)
testBindsCount(vm, "$security", 7, t)
sceneraios := []struct {
js string
length int
}{
{`$security.randomString(6)`, 6},
{`$security.randomStringWithAlphabet(7, "abc")`, 7},
{`$security.pseudorandomString(8)`, 8},
{`$security.pseudorandomStringWithAlphabet(9, "abc")`, 9},
}
for _, s := range sceneraios {
result, err := vm.RunString(s.js)
if err != nil {
t.Fatalf("[%s] Failed to execute js script, got %v", s.js, err)
}
v, _ := result.Export().(string)
if len(v) != s.length {
t.Fatalf("[%s] Expected %d length string, \ngot \n%v", s.js, s.length, v)
}
}
}
func TestSecurityTokenBinds(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
vm := goja.New()
baseBinds(vm)
securityBinds(vm)
testBindsCount(vm, "$security", 7, t)
sceneraios := []struct {
js string
expected string
}{
{
`$security.parseUnverifiedToken("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.aXzC7q7z1lX_hxk5P0R368xEU7H1xRwnBQQcLAmG0EY")`,
`{"name":"John Doe","sub":"1234567890"}`,
},
{
`$security.parseToken("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.aXzC7q7z1lX_hxk5P0R368xEU7H1xRwnBQQcLAmG0EY", "test")`,
`{"name":"John Doe","sub":"1234567890"}`,
},
{
`$security.createToken({"exp": 123}, "test", 0)`, // overwrite the exp claim for static token
`"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjEyM30.7gbv7w672gApdBRASI6OniCtKwkKjhieSxsr6vxSrtw"`,
},
}
for _, s := range sceneraios {
result, err := vm.RunString(s.js)
if err != nil {
t.Fatalf("[%s] Failed to execute js script, got %v", s.js, err)
}
raw, _ := json.Marshal(result.Export())
if string(raw) != s.expected {
t.Fatalf("[%s] Expected \n%s, \ngot \n%s", s.js, s.expected, raw)
}
}
}
func TestFilesystemBinds(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
vm := goja.New()
vm.Set("mh", &multipart.FileHeader{Filename: "test"})
vm.Set("testFile", filepath.Join(app.DataDir(), "data.db"))
baseBinds(vm)
filesystemBinds(vm)
testBindsCount(vm, "$filesystem", 3, t)
// fileFromPath
{
v, err := vm.RunString(`$filesystem.fileFromPath(testFile)`)
if err != nil {
t.Fatal(err)
}
file, _ := v.Export().(*filesystem.File)
if file == nil || file.OriginalName != "data.db" {
t.Fatalf("[fileFromPath] Expected file with name %q, got %v", file.OriginalName, file)
}
}
// fileFromBytes
{
v, err := vm.RunString(`$filesystem.fileFromBytes([1, 2, 3], "test")`)
if err != nil {
t.Fatal(err)
}
file, _ := v.Export().(*filesystem.File)
if file == nil || file.OriginalName != "test" {
t.Fatalf("[fileFromBytes] Expected file with name %q, got %v", file.OriginalName, file)
}
}
// fileFromMultipart
{
v, err := vm.RunString(`$filesystem.fileFromMultipart(mh)`)
if err != nil {
t.Fatal(err)
}
file, _ := v.Export().(*filesystem.File)
if file == nil || file.OriginalName != "test" {
t.Fatalf("[fileFromMultipart] Expected file with name %q, got %v", file.OriginalName, file)
}
}
}
func TestFormsBinds(t *testing.T) {
vm := goja.New()
formsBinds(vm)
testBindsCount(vm, "this", 20, t)
}
func testBindsCount(vm *goja.Runtime, namespace string, count int, t *testing.T) {
v, err := vm.RunString(`Object.keys(` + namespace + `).length`)
if err != nil {
t.Fatal(err)
}
total, _ := v.Export().(int64)
if int(total) != count {
t.Fatalf("Expected %d %s binds, got %d", count, namespace, total)
}
}
+5 -5
View File
@@ -40,7 +40,7 @@ func (p *plugin) afterCollectionChange() func(*core.ModelEvent) error {
var template string
var templateErr error
if p.options.TemplateLang == TemplateLangJS {
if p.config.TemplateLang == TemplateLangJS {
template, templateErr = p.jsDiffTemplate(new, old)
} else {
template, templateErr = p.goDiffTemplate(new, old)
@@ -63,8 +63,8 @@ func (p *plugin) afterCollectionChange() func(*core.ModelEvent) error {
}
appliedTime := time.Now().Unix()
name := fmt.Sprintf("%d_%s.%s", appliedTime, action, p.options.TemplateLang)
filePath := filepath.Join(p.options.Dir, name)
name := fmt.Sprintf("%d_%s.%s", appliedTime, action, p.config.TemplateLang)
filePath := filepath.Join(p.config.Dir, name)
return p.app.Dao().RunInTransaction(func(txDao *daos.Dao) error {
// insert the migration entry
@@ -77,7 +77,7 @@ func (p *plugin) afterCollectionChange() func(*core.ModelEvent) error {
}
// ensure that the local migrations dir exist
if err := os.MkdirAll(p.options.Dir, os.ModePerm); err != nil {
if err := os.MkdirAll(p.config.Dir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create migration dir: %w", err)
}
@@ -138,7 +138,7 @@ func (p *plugin) getCachedCollections() (map[string]*models.Collection, error) {
}
func (p *plugin) hasCustomMigrations() bool {
files, err := os.ReadDir(p.options.Dir)
files, err := os.ReadDir(p.config.Dir)
if err != nil {
return false
}
+31 -30
View File
@@ -5,10 +5,10 @@
//
// Example usage:
//
// migratecmd.MustRegister(app, app.RootCmd, &migratecmd.Options{
// migratecmd.MustRegister(app, app.RootCmd, migratecmd.Config{
// TemplateLang: migratecmd.TemplateLangJS, // default to migratecmd.TemplateLangGo
// Automigrate: true,
// Dir: "migrations_dir_path", // optional template migrations path; default to "pb_migrations" (for JS) and "migrations" (for Go)
// Dir: "/custom/migrations/dir", // optional template migrations path; default to "pb_migrations" (for JS) and "migrations" (for Go)
// })
//
// Note: To allow running JS migrations you'll need to enable first
@@ -32,8 +32,8 @@ import (
"github.com/spf13/cobra"
)
// Options defines optional struct to customize the default plugin behavior.
type Options struct {
// Config defines the config options of the migratecmd plugin.
type Config struct {
// Dir specifies the directory with the user defined migrations.
//
// If not set it fallbacks to a relative "pb_data/../pb_migrations" (for js)
@@ -48,35 +48,31 @@ type Options struct {
TemplateLang string
}
type plugin struct {
app core.App
options *Options
}
func MustRegister(app core.App, rootCmd *cobra.Command, options *Options) {
if err := Register(app, rootCmd, options); err != nil {
// MustRegister registers the migratecmd plugin to the provided app instance
// and panic if it fails.
//
// Example usage:
//
// migratecmd.MustRegister(app, app.RootCmd, migratecmd.Config{})
func MustRegister(app core.App, rootCmd *cobra.Command, config Config) {
if err := Register(app, rootCmd, config); err != nil {
panic(err)
}
}
func Register(app core.App, rootCmd *cobra.Command, options *Options) error {
p := &plugin{app: app}
// Register registers the migratecmd plugin to the provided app instance.
func Register(app core.App, rootCmd *cobra.Command, config Config) error {
p := &plugin{app: app, config: config}
if options != nil {
p.options = options
} else {
p.options = &Options{}
if p.config.TemplateLang == "" {
p.config.TemplateLang = TemplateLangGo
}
if p.options.TemplateLang == "" {
p.options.TemplateLang = TemplateLangGo
}
if p.options.Dir == "" {
if p.options.TemplateLang == TemplateLangJS {
p.options.Dir = filepath.Join(p.app.DataDir(), "../pb_migrations")
if p.config.Dir == "" {
if p.config.TemplateLang == TemplateLangJS {
p.config.Dir = filepath.Join(p.app.DataDir(), "../pb_migrations")
} else {
p.options.Dir = filepath.Join(p.app.DataDir(), "../migrations")
p.config.Dir = filepath.Join(p.app.DataDir(), "../migrations")
}
}
@@ -86,7 +82,7 @@ func Register(app core.App, rootCmd *cobra.Command, options *Options) error {
}
// watch for collection changes
if p.options.Automigrate {
if p.config.Automigrate {
// refresh the cache right after app bootstap
p.app.OnAfterBootstrap().Add(func(e *core.BootstrapEvent) error {
p.refreshCachedCollections()
@@ -129,6 +125,11 @@ func Register(app core.App, rootCmd *cobra.Command, options *Options) error {
return nil
}
type plugin struct {
app core.App
config Config
}
func (p *plugin) createCommand() *cobra.Command {
const cmdDesc = `Supported arguments are:
- up - runs all available migrations
@@ -185,9 +186,9 @@ func (p *plugin) migrateCreateHandler(template string, args []string, interactiv
}
name := args[0]
dir := p.options.Dir
dir := p.config.Dir
filename := fmt.Sprintf("%d_%s.%s", time.Now().Unix(), inflector.Snakecase(name), p.options.TemplateLang)
filename := fmt.Sprintf("%d_%s.%s", time.Now().Unix(), inflector.Snakecase(name), p.config.TemplateLang)
resultFilePath := path.Join(dir, filename)
@@ -206,7 +207,7 @@ func (p *plugin) migrateCreateHandler(template string, args []string, interactiv
// get default create template
if template == "" {
var templateErr error
if p.options.TemplateLang == TemplateLangJS {
if p.config.TemplateLang == TemplateLangJS {
template, templateErr = p.jsBlankTemplate()
} else {
template, templateErr = p.goBlankTemplate()
@@ -244,7 +245,7 @@ func (p *plugin) migrateCollectionsHandler(args []string, interactive bool) (str
var template string
var templateErr error
if p.options.TemplateLang == TemplateLangJS {
if p.config.TemplateLang == TemplateLangJS {
template, templateErr = p.jsSnapshotTemplate(collections)
} else {
template, templateErr = p.goSnapshotTemplate(collections)
+5 -5
View File
@@ -343,7 +343,7 @@ func init() {
}
`
return fmt.Sprintf(template, filepath.Base(p.options.Dir)), nil
return fmt.Sprintf(template, filepath.Base(p.config.Dir)), nil
}
func (p *plugin) goSnapshotTemplate(collections []*models.Collection) (string, error) {
@@ -380,7 +380,7 @@ func init() {
`
return fmt.Sprintf(
template,
filepath.Base(p.options.Dir),
filepath.Base(p.config.Dir),
escapeBacktick(string(jsonData)),
), nil
}
@@ -427,7 +427,7 @@ func init() {
return fmt.Sprintf(
template,
filepath.Base(p.options.Dir),
filepath.Base(p.config.Dir),
escapeBacktick(string(jsonData)),
collection.Id,
), nil
@@ -475,7 +475,7 @@ func init() {
return fmt.Sprintf(
template,
filepath.Base(p.options.Dir),
filepath.Base(p.config.Dir),
collection.Id,
escapeBacktick(string(jsonData)),
), nil
@@ -745,7 +745,7 @@ func init() {
return fmt.Sprintf(
template,
filepath.Base(p.options.Dir),
filepath.Base(p.config.Dir),
imports,
old.Id, strings.TrimSpace(up),
new.Id, strings.TrimSpace(down),