initial public commit
This commit is contained in:
+124
@@ -0,0 +1,124 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// AdminQuery returns a new Admin select query.
|
||||
func (dao *Dao) AdminQuery() *dbx.SelectQuery {
|
||||
return dao.ModelQuery(&models.Admin{})
|
||||
}
|
||||
|
||||
// FindAdminById finds the admin with the provided id.
|
||||
func (dao *Dao) FindAdminById(id string) (*models.Admin, error) {
|
||||
model := &models.Admin{}
|
||||
|
||||
err := dao.AdminQuery().
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// FindAdminByEmail finds the admin with the provided email address.
|
||||
func (dao *Dao) FindAdminByEmail(email string) (*models.Admin, error) {
|
||||
model := &models.Admin{}
|
||||
|
||||
err := dao.AdminQuery().
|
||||
AndWhere(dbx.HashExp{"email": email}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// FindAdminByEmail finds the admin associated with the provided JWT token.
|
||||
//
|
||||
// Returns an error if the JWT token is invalid or expired.
|
||||
func (dao *Dao) FindAdminByToken(token string, baseTokenKey string) (*models.Admin, error) {
|
||||
unverifiedClaims, err := security.ParseUnverifiedJWT(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check required claims
|
||||
id, _ := unverifiedClaims["id"].(string)
|
||||
if id == "" {
|
||||
return nil, errors.New("Missing or invalid token claims.")
|
||||
}
|
||||
|
||||
admin, err := dao.FindAdminById(id)
|
||||
if err != nil || admin == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verificationKey := admin.TokenKey + baseTokenKey
|
||||
|
||||
// verify token signature
|
||||
if _, err := security.ParseJWT(token, verificationKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return admin, nil
|
||||
}
|
||||
|
||||
// TotalAdmins returns the number of existing admin records.
|
||||
func (dao *Dao) TotalAdmins() (int, error) {
|
||||
var total int
|
||||
|
||||
err := dao.AdminQuery().Select("count(*)").Row(&total)
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
// IsAdminEmailUnique checks if the provided email address is not
|
||||
// already in use by other admins.
|
||||
func (dao *Dao) IsAdminEmailUnique(email string, excludeId string) bool {
|
||||
if email == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var exists bool
|
||||
err := dao.AdminQuery().
|
||||
Select("count(*)").
|
||||
AndWhere(dbx.Not(dbx.HashExp{"id": excludeId})).
|
||||
AndWhere(dbx.HashExp{"email": email}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && !exists
|
||||
}
|
||||
|
||||
// DeleteAdmin deletes the provided Admin model.
|
||||
//
|
||||
// Returns an error if there is only 1 admin.
|
||||
func (dao *Dao) DeleteAdmin(admin *models.Admin) error {
|
||||
total, err := dao.TotalAdmins()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if total == 1 {
|
||||
return errors.New("You cannot delete the only existing admin.")
|
||||
}
|
||||
|
||||
return dao.Delete(admin)
|
||||
}
|
||||
|
||||
// SaveAdmin upserts the provided Admin model.
|
||||
func (dao *Dao) SaveAdmin(admin *models.Admin) error {
|
||||
return dao.Save(admin)
|
||||
}
|
||||
@@ -0,0 +1,238 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestAdminQuery(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
expected := "SELECT {{_admins}}.* FROM `_admins`"
|
||||
|
||||
sql := app.Dao().AdminQuery().Build().SQL()
|
||||
if sql != expected {
|
||||
t.Errorf("Expected sql %s, got %s", expected, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAdminById(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"00000000-2b4a-a26b-4d01-42d3c3d77bc8", true},
|
||||
{"3f8397cc-2b4a-a26b-4d01-42d3c3d77bc8", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
admin, err := app.Dao().FindAdminById(scenario.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if admin != nil && admin.Id != scenario.id {
|
||||
t.Errorf("(%d) Expected admin with id %s, got %s", i, scenario.id, admin.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAdminByEmail(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
email string
|
||||
expectError bool
|
||||
}{
|
||||
{"invalid", true},
|
||||
{"missing@example.com", true},
|
||||
{"test@example.com", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
admin, err := app.Dao().FindAdminByEmail(scenario.email)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !scenario.expectError && admin.Email != scenario.email {
|
||||
t.Errorf("(%d) Expected admin with email %s, got %s", i, scenario.email, admin.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAdminByToken(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
token string
|
||||
baseKey string
|
||||
expectedEmail string
|
||||
expectError bool
|
||||
}{
|
||||
// invalid base key (password reset key for auth token)
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjJiNGE5N2NjLTNmODMtNGQwMS1hMjZiLTNkNzdiYzg0MmQzYyIsInR5cGUiOiJhZG1pbiIsImV4cCI6MTg3MzQ2Mjc5Mn0.AtRtXR6FHBrCUGkj5OffhmxLbSZaQ4L_Qgw4gfoHyfo",
|
||||
app.Settings().AdminPasswordResetToken.Secret,
|
||||
"",
|
||||
true,
|
||||
},
|
||||
// expired token
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjJiNGE5N2NjLTNmODMtNGQwMS1hMjZiLTNkNzdiYzg0MmQzYyIsInR5cGUiOiJhZG1pbiIsImV4cCI6MTY0MDk5MTY2MX0.uXZ_ywsZeRFSvDNQ9zBoYUXKXw7VEr48Fzx-E06OkS8",
|
||||
app.Settings().AdminAuthToken.Secret,
|
||||
"",
|
||||
true,
|
||||
},
|
||||
// valid token
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjJiNGE5N2NjLTNmODMtNGQwMS1hMjZiLTNkNzdiYzg0MmQzYyIsInR5cGUiOiJhZG1pbiIsImV4cCI6MTg3MzQ2Mjc5Mn0.AtRtXR6FHBrCUGkj5OffhmxLbSZaQ4L_Qgw4gfoHyfo",
|
||||
app.Settings().AdminAuthToken.Secret,
|
||||
"test@example.com",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
admin, err := app.Dao().FindAdminByToken(scenario.token, scenario.baseKey)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !scenario.expectError && admin.Email != scenario.expectedEmail {
|
||||
t.Errorf("(%d) Expected admin model %s, got %s", i, scenario.expectedEmail, admin.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTotalAdmins(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
result1, err := app.Dao().TotalAdmins()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result1 != 2 {
|
||||
t.Fatalf("Expected 2 admins, got %d", result1)
|
||||
}
|
||||
|
||||
// delete all
|
||||
app.Dao().DB().NewQuery("delete from {{_admins}}").Execute()
|
||||
|
||||
result2, err := app.Dao().TotalAdmins()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result2 != 0 {
|
||||
t.Fatalf("Expected 0 admins, got %d", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAdminEmailUnique(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
email string
|
||||
excludeId string
|
||||
expected bool
|
||||
}{
|
||||
{"", "", false},
|
||||
{"test@example.com", "", false},
|
||||
{"new@example.com", "", true},
|
||||
{"test@example.com", "2b4a97cc-3f83-4d01-a26b-3d77bc842d3c", true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := app.Dao().IsAdminEmailUnique(scenario.email, scenario.excludeId)
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAdmin(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// try to delete unsaved admin model
|
||||
deleteErr0 := app.Dao().DeleteAdmin(&models.Admin{})
|
||||
if deleteErr0 == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
admin1, err := app.Dao().FindAdminByEmail("test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
admin2, err := app.Dao().FindAdminByEmail("test2@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
deleteErr1 := app.Dao().DeleteAdmin(admin1)
|
||||
if deleteErr1 != nil {
|
||||
t.Fatal(deleteErr1)
|
||||
}
|
||||
|
||||
// cannot delete the only remaining admin
|
||||
deleteErr2 := app.Dao().DeleteAdmin(admin2)
|
||||
if deleteErr2 == nil {
|
||||
t.Fatal("Expected delete error, got nil")
|
||||
}
|
||||
|
||||
total, _ := app.Dao().TotalAdmins()
|
||||
if total != 1 {
|
||||
t.Fatalf("Expected only 1 admin, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAdmin(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// create
|
||||
newAdmin := &models.Admin{}
|
||||
newAdmin.Email = "new@example.com"
|
||||
newAdmin.SetPassword("123456")
|
||||
saveErr1 := app.Dao().SaveAdmin(newAdmin)
|
||||
if saveErr1 != nil {
|
||||
t.Fatal(saveErr1)
|
||||
}
|
||||
if newAdmin.Id == "" {
|
||||
t.Fatal("Expected admin id to be set")
|
||||
}
|
||||
|
||||
// update
|
||||
existingAdmin, err := app.Dao().FindAdminByEmail("test@example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updatedEmail := "test_update@example.com"
|
||||
existingAdmin.Email = updatedEmail
|
||||
saveErr2 := app.Dao().SaveAdmin(existingAdmin)
|
||||
if saveErr2 != nil {
|
||||
t.Fatal(saveErr2)
|
||||
}
|
||||
existingAdmin, _ = app.Dao().FindAdminById(existingAdmin.Id)
|
||||
if existingAdmin.Email != updatedEmail {
|
||||
t.Fatalf("Expected admin email to be %s, got %s", updatedEmail, existingAdmin.Email)
|
||||
}
|
||||
}
|
||||
+217
@@ -0,0 +1,217 @@
|
||||
// Package daos handles common PocketBase DB model manipulations.
|
||||
//
|
||||
// Think of daos as DB repository and service layer in one.
|
||||
package daos
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
)
|
||||
|
||||
// New creates a new Dao instance with the provided db builder.
|
||||
func New(db dbx.Builder) *Dao {
|
||||
return &Dao{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Dao handles various db operations.
|
||||
// Think of Dao as a repository and service layer in one.
|
||||
type Dao struct {
|
||||
db dbx.Builder
|
||||
|
||||
BeforeCreateFunc func(eventDao *Dao, m models.Model) error
|
||||
AfterCreateFunc func(eventDao *Dao, m models.Model)
|
||||
BeforeUpdateFunc func(eventDao *Dao, m models.Model) error
|
||||
AfterUpdateFunc func(eventDao *Dao, m models.Model)
|
||||
BeforeDeleteFunc func(eventDao *Dao, m models.Model) error
|
||||
AfterDeleteFunc func(eventDao *Dao, m models.Model)
|
||||
}
|
||||
|
||||
// DB returns the internal db builder (*dbx.DB or *dbx.TX).
|
||||
func (dao *Dao) DB() dbx.Builder {
|
||||
return dao.db
|
||||
}
|
||||
|
||||
// ModelQuery creates a new query with preset Select and From fields
|
||||
// based on the provided model argument.
|
||||
func (dao *Dao) ModelQuery(m models.Model) *dbx.SelectQuery {
|
||||
tableName := m.TableName()
|
||||
return dao.db.Select(fmt.Sprintf("{{%s}}.*", tableName)).From(tableName)
|
||||
}
|
||||
|
||||
// FindById finds a single db record with the specified id and
|
||||
// scans the result into m.
|
||||
func (dao *Dao) FindById(m models.Model, id string) error {
|
||||
return dao.ModelQuery(m).Where(dbx.HashExp{"id": id}).Limit(1).One(m)
|
||||
}
|
||||
|
||||
// RunInTransaction wraps fn into a transaction.
|
||||
//
|
||||
// It is safe to nest RunInTransaction calls.
|
||||
func (dao *Dao) RunInTransaction(fn func(txDao *Dao) error) error {
|
||||
switch txOrDB := dao.db.(type) {
|
||||
case *dbx.Tx:
|
||||
// nested transactions are not supported by default
|
||||
// so execute the function within the current transaction
|
||||
return fn(dao)
|
||||
case *dbx.DB:
|
||||
return txOrDB.Transactional(func(tx *dbx.Tx) error {
|
||||
txDao := New(tx)
|
||||
|
||||
txDao.BeforeCreateFunc = func(eventDao *Dao, m models.Model) error {
|
||||
if dao.BeforeCreateFunc != nil {
|
||||
return dao.BeforeCreateFunc(eventDao, m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
txDao.AfterCreateFunc = func(eventDao *Dao, m models.Model) {
|
||||
if dao.AfterCreateFunc != nil {
|
||||
dao.AfterCreateFunc(eventDao, m)
|
||||
}
|
||||
}
|
||||
txDao.BeforeUpdateFunc = func(eventDao *Dao, m models.Model) error {
|
||||
if dao.BeforeUpdateFunc != nil {
|
||||
return dao.BeforeUpdateFunc(eventDao, m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
txDao.AfterUpdateFunc = func(eventDao *Dao, m models.Model) {
|
||||
if dao.AfterUpdateFunc != nil {
|
||||
dao.AfterUpdateFunc(eventDao, m)
|
||||
}
|
||||
}
|
||||
txDao.BeforeDeleteFunc = func(eventDao *Dao, m models.Model) error {
|
||||
if dao.BeforeDeleteFunc != nil {
|
||||
return dao.BeforeDeleteFunc(eventDao, m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
txDao.AfterDeleteFunc = func(eventDao *Dao, m models.Model) {
|
||||
if dao.AfterDeleteFunc != nil {
|
||||
dao.AfterDeleteFunc(eventDao, m)
|
||||
}
|
||||
}
|
||||
|
||||
return fn(txDao)
|
||||
})
|
||||
}
|
||||
|
||||
return errors.New("Failed to start transaction (unknown dao.db)")
|
||||
}
|
||||
|
||||
// Delete deletes the provided model.
|
||||
func (dao *Dao) Delete(m models.Model) error {
|
||||
if !m.HasId() {
|
||||
return errors.New("ID is not set")
|
||||
}
|
||||
|
||||
if dao.BeforeDeleteFunc != nil {
|
||||
if err := dao.BeforeDeleteFunc(dao, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
deleteErr := dao.db.Model(m).Delete()
|
||||
if deleteErr != nil {
|
||||
return deleteErr
|
||||
}
|
||||
|
||||
if dao.AfterDeleteFunc != nil {
|
||||
dao.AfterDeleteFunc(dao, m)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save upserts (update or create if primary key is not set) the provided model.
|
||||
func (dao *Dao) Save(m models.Model) error {
|
||||
if m.HasId() {
|
||||
return dao.update(m)
|
||||
}
|
||||
|
||||
return dao.create(m)
|
||||
}
|
||||
|
||||
func (dao *Dao) update(m models.Model) error {
|
||||
if !m.HasId() {
|
||||
return errors.New("ID is not set")
|
||||
}
|
||||
|
||||
m.RefreshUpdated()
|
||||
|
||||
if dao.BeforeUpdateFunc != nil {
|
||||
if err := dao.BeforeUpdateFunc(dao, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := any(m).(models.ColumnValueMapper); ok {
|
||||
dataMap := v.ColumnValueMap()
|
||||
|
||||
_, err := dao.db.Update(
|
||||
m.TableName(),
|
||||
dataMap,
|
||||
dbx.HashExp{"id": m.GetId()},
|
||||
).Execute()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := dao.db.Model(m).Update()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if dao.AfterUpdateFunc != nil {
|
||||
dao.AfterUpdateFunc(dao, m)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dao *Dao) create(m models.Model) error {
|
||||
if !m.HasId() {
|
||||
// auto generate id
|
||||
m.RefreshId()
|
||||
}
|
||||
|
||||
if m.GetCreated().IsZero() {
|
||||
m.RefreshCreated()
|
||||
}
|
||||
|
||||
if m.GetUpdated().IsZero() {
|
||||
m.RefreshUpdated()
|
||||
}
|
||||
|
||||
if dao.BeforeCreateFunc != nil {
|
||||
if err := dao.BeforeCreateFunc(dao, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := any(m).(models.ColumnValueMapper); ok {
|
||||
dataMap := v.ColumnValueMap()
|
||||
|
||||
_, err := dao.db.Insert(m.TableName(), dataMap).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := dao.db.Model(m).Insert()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if dao.AfterCreateFunc != nil {
|
||||
dao.AfterCreateFunc(dao, m)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
dao := daos.New(testApp.DB())
|
||||
|
||||
if dao.DB() != testApp.DB() {
|
||||
t.Fatal("The 2 db instances are different")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaoModelQuery(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
dao := daos.New(testApp.DB())
|
||||
|
||||
scenarios := []struct {
|
||||
model models.Model
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
&models.Collection{},
|
||||
"SELECT {{_collections}}.* FROM `_collections`",
|
||||
},
|
||||
{
|
||||
&models.User{},
|
||||
"SELECT {{_users}}.* FROM `_users`",
|
||||
},
|
||||
{
|
||||
&models.Request{},
|
||||
"SELECT {{_requests}}.* FROM `_requests`",
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
sql := dao.ModelQuery(scenario.model).Build().SQL()
|
||||
if sql != scenario.expected {
|
||||
t.Errorf("(%d) Expected select %s, got %s", i, scenario.expected, sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaoFindById(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
model models.Model
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
// missing id
|
||||
{
|
||||
&models.Collection{},
|
||||
"00000000-075d-49fe-9d09-ea7e951000dc",
|
||||
true,
|
||||
},
|
||||
// existing collection id
|
||||
{
|
||||
&models.Collection{},
|
||||
"3f2888f8-075d-49fe-9d09-ea7e951000dc",
|
||||
false,
|
||||
},
|
||||
// existing user id
|
||||
{
|
||||
&models.User{},
|
||||
"97cc3d3d-6ba2-383f-b42a-7bc84d27410c",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
err := testApp.Dao().FindById(scenario.model, scenario.id)
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expectError, err)
|
||||
}
|
||||
|
||||
if !scenario.expectError && scenario.id != scenario.model.GetId() {
|
||||
t.Errorf("(%d) Expected model with id %v, got %v", i, scenario.id, scenario.model.GetId())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaoRunInTransaction(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
// failed nested transaction
|
||||
testApp.Dao().RunInTransaction(func(txDao *daos.Dao) error {
|
||||
admin, _ := txDao.FindAdminByEmail("test@example.com")
|
||||
|
||||
return txDao.RunInTransaction(func(tx2Dao *daos.Dao) error {
|
||||
if err := tx2Dao.DeleteAdmin(admin); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return errors.New("test error")
|
||||
})
|
||||
})
|
||||
|
||||
// admin should still exist
|
||||
admin1, _ := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
if admin1 == nil {
|
||||
t.Fatal("Expected admin test@example.com to not be deleted")
|
||||
}
|
||||
|
||||
// successful nested transaction
|
||||
testApp.Dao().RunInTransaction(func(txDao *daos.Dao) error {
|
||||
admin, _ := txDao.FindAdminByEmail("test@example.com")
|
||||
|
||||
return txDao.RunInTransaction(func(tx2Dao *daos.Dao) error {
|
||||
return tx2Dao.DeleteAdmin(admin)
|
||||
})
|
||||
})
|
||||
|
||||
// admin should have been deleted
|
||||
admin2, _ := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
if admin2 != nil {
|
||||
t.Fatalf("Expected admin test@example.com to be deleted, found %v", admin2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaoSaveCreate(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
model := &models.Admin{}
|
||||
model.Email = "test_new@example.com"
|
||||
model.Avatar = 8
|
||||
if err := testApp.Dao().Save(model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// refresh
|
||||
model, _ = testApp.Dao().FindAdminByEmail("test_new@example.com")
|
||||
|
||||
if model.Avatar != 8 {
|
||||
t.Fatalf("Expected model avatar field to be 8, got %v", model.Avatar)
|
||||
}
|
||||
|
||||
expectedHooks := []string{"OnModelBeforeCreate", "OnModelAfterCreate"}
|
||||
for _, h := range expectedHooks {
|
||||
if v, ok := testApp.EventCalls[h]; !ok || v != 1 {
|
||||
t.Fatalf("Expected event %s to be called exactly one time, got %d", h, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaoSaveUpdate(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
model, _ := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
|
||||
model.Avatar = 8
|
||||
if err := testApp.Dao().Save(model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// refresh
|
||||
model, _ = testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
|
||||
if model.Avatar != 8 {
|
||||
t.Fatalf("Expected model avatar field to be updated to 8, got %v", model.Avatar)
|
||||
}
|
||||
|
||||
expectedHooks := []string{"OnModelBeforeUpdate", "OnModelAfterUpdate"}
|
||||
for _, h := range expectedHooks {
|
||||
if v, ok := testApp.EventCalls[h]; !ok || v != 1 {
|
||||
t.Fatalf("Expected event %s to be called exactly one time, got %d", h, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaoDelete(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
model, _ := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
|
||||
if err := testApp.Dao().Delete(model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
model, _ = testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
if model != nil {
|
||||
t.Fatalf("Expected model to be deleted, found %v", model)
|
||||
}
|
||||
|
||||
expectedHooks := []string{"OnModelBeforeDelete", "OnModelAfterDelete"}
|
||||
for _, h := range expectedHooks {
|
||||
if v, ok := testApp.EventCalls[h]; !ok || v != 1 {
|
||||
t.Fatalf("Expected event %s to be called exactly one time, got %d", h, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDaoBeforeHooksError(t *testing.T) {
|
||||
testApp, _ := tests.NewTestApp()
|
||||
defer testApp.Cleanup()
|
||||
|
||||
testApp.Dao().BeforeCreateFunc = func(eventDao *daos.Dao, m models.Model) error {
|
||||
return errors.New("before_create")
|
||||
}
|
||||
testApp.Dao().BeforeUpdateFunc = func(eventDao *daos.Dao, m models.Model) error {
|
||||
return errors.New("before_update")
|
||||
}
|
||||
testApp.Dao().BeforeDeleteFunc = func(eventDao *daos.Dao, m models.Model) error {
|
||||
return errors.New("before_delete")
|
||||
}
|
||||
|
||||
existingModel, _ := testApp.Dao().FindAdminByEmail("test@example.com")
|
||||
|
||||
// try to create
|
||||
// ---
|
||||
newModel := &models.Admin{}
|
||||
newModel.Email = "test_new@example.com"
|
||||
if err := testApp.Dao().Save(newModel); err.Error() != "before_create" {
|
||||
t.Fatalf("Expected before_create error, got %v", err)
|
||||
}
|
||||
|
||||
// try to update
|
||||
// ---
|
||||
if err := testApp.Dao().Save(existingModel); err.Error() != "before_update" {
|
||||
t.Fatalf("Expected before_update error, got %v", err)
|
||||
}
|
||||
|
||||
// try to delete
|
||||
// ---
|
||||
if err := testApp.Dao().Delete(existingModel); err.Error() != "before_delete" {
|
||||
t.Fatalf("Expected before_delete error, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
)
|
||||
|
||||
// CollectionQuery returns a new Collection select query.
|
||||
func (dao *Dao) CollectionQuery() *dbx.SelectQuery {
|
||||
return dao.ModelQuery(&models.Collection{})
|
||||
}
|
||||
|
||||
// FindCollectionByNameOrId finds the first collection by its name or id.
|
||||
func (dao *Dao) FindCollectionByNameOrId(nameOrId string) (*models.Collection, error) {
|
||||
model := &models.Collection{}
|
||||
|
||||
err := dao.CollectionQuery().
|
||||
AndWhere(dbx.Or(
|
||||
dbx.HashExp{"id": nameOrId},
|
||||
dbx.HashExp{"name": nameOrId},
|
||||
)).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// IsCollectionNameUnique checks that there is no existing collection
|
||||
// with the provided name (case insensitive!).
|
||||
//
|
||||
// Note: case sensitive check because the name is used also as a table name for the records.
|
||||
func (dao *Dao) IsCollectionNameUnique(name string, excludeId string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var exists bool
|
||||
err := dao.CollectionQuery().
|
||||
Select("count(*)").
|
||||
AndWhere(dbx.Not(dbx.HashExp{"id": excludeId})).
|
||||
AndWhere(dbx.NewExp("LOWER([[name]])={:name}", dbx.Params{"name": strings.ToLower(name)})).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && !exists
|
||||
}
|
||||
|
||||
// FindCollectionsWithUserFields finds all collections that has
|
||||
// at least one user schema field.
|
||||
func (dao *Dao) FindCollectionsWithUserFields() ([]*models.Collection, error) {
|
||||
result := []*models.Collection{}
|
||||
|
||||
err := dao.CollectionQuery().
|
||||
InnerJoin(
|
||||
"json_each(schema) as jsonField",
|
||||
dbx.NewExp(
|
||||
"json_extract(jsonField.value, '$.type') = {:type}",
|
||||
dbx.Params{"type": schema.FieldTypeUser},
|
||||
),
|
||||
).
|
||||
All(&result)
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// FindCollectionReferences returns information for all
|
||||
// relation schema fields referencing the provided collection.
|
||||
//
|
||||
// If the provided collection has reference to itself then it will be
|
||||
// also included in the result. To exlude it, pass the collection id
|
||||
// as the excludeId argument.
|
||||
func (dao *Dao) FindCollectionReferences(collection *models.Collection, excludeId string) (map[*models.Collection][]*schema.SchemaField, error) {
|
||||
collections := []*models.Collection{}
|
||||
|
||||
err := dao.CollectionQuery().
|
||||
AndWhere(dbx.Not(dbx.HashExp{"id": excludeId})).
|
||||
All(&collections)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := map[*models.Collection][]*schema.SchemaField{}
|
||||
for _, c := range collections {
|
||||
for _, f := range c.Schema.Fields() {
|
||||
if f.Type != schema.FieldTypeRelation {
|
||||
continue
|
||||
}
|
||||
f.InitOptions()
|
||||
options, _ := f.Options.(*schema.RelationOptions)
|
||||
if options != nil && options.CollectionId == collection.Id {
|
||||
result[c] = append(result[c], f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteCollection deletes the provided Collection model.
|
||||
// This method automatically deletes the related collection records table.
|
||||
//
|
||||
// NB! The collection cannot be deleted, if:
|
||||
// - is system collection (aka. collection.System is true)
|
||||
// - is referenced as part of a relation field in another collection
|
||||
func (dao *Dao) DeleteCollection(collection *models.Collection) error {
|
||||
if collection.System {
|
||||
return errors.New("System collections cannot be deleted.")
|
||||
}
|
||||
|
||||
// ensure that there aren't any existing references.
|
||||
// note: the select is outside of the transaction to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
|
||||
result, err := dao.FindCollectionReferences(collection, collection.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if total := len(result); total > 0 {
|
||||
return fmt.Errorf("The collection has external relation field references (%d).", total)
|
||||
}
|
||||
|
||||
return dao.RunInTransaction(func(txDao *Dao) error {
|
||||
// delete the related records table
|
||||
if err := txDao.DeleteTable(collection.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return txDao.Delete(collection)
|
||||
})
|
||||
}
|
||||
|
||||
// SaveCollection upserts the provided Collection model and updates
|
||||
// its related records table schema.
|
||||
func (dao *Dao) SaveCollection(collection *models.Collection) error {
|
||||
var oldCollection *models.Collection
|
||||
|
||||
if collection.HasId() {
|
||||
// get the existing collection state to compare with the new one
|
||||
// note: the select is outside of the transaction to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
|
||||
var findErr error
|
||||
oldCollection, findErr = dao.FindCollectionByNameOrId(collection.Id)
|
||||
if findErr != nil {
|
||||
return findErr
|
||||
}
|
||||
}
|
||||
|
||||
return dao.RunInTransaction(func(txDao *Dao) error {
|
||||
// persist the collection model
|
||||
if err := txDao.Save(collection); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sync the changes with the related records table
|
||||
return txDao.SyncRecordTableSchema(collection, oldCollection)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func TestCollectionQuery(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
expected := "SELECT {{_collections}}.* FROM `_collections`"
|
||||
|
||||
sql := app.Dao().CollectionQuery().Build().SQL()
|
||||
if sql != expected {
|
||||
t.Errorf("Expected sql %s, got %s", expected, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCollectionByNameOrId(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
nameOrId string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"missing", true},
|
||||
{"00000000-075d-49fe-9d09-ea7e951000dc", true},
|
||||
{"3f2888f8-075d-49fe-9d09-ea7e951000dc", false},
|
||||
{"demo", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
model, err := app.Dao().FindCollectionByNameOrId(scenario.nameOrId)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if model != nil && model.Id != scenario.nameOrId && model.Name != scenario.nameOrId {
|
||||
t.Errorf("(%d) Expected model with identifier %s, got %v", i, scenario.nameOrId, model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCollectionNameUnique(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
excludeId string
|
||||
expected bool
|
||||
}{
|
||||
{"", "", false},
|
||||
{"demo", "", false},
|
||||
{"new", "", true},
|
||||
{"demo", "3f2888f8-075d-49fe-9d09-ea7e951000dc", true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := app.Dao().IsCollectionNameUnique(scenario.name, scenario.excludeId)
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCollectionsWithUserFields(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
result, err := app.Dao().FindCollectionsWithUserFields()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedNames := []string{"demo2", models.ProfileCollectionName}
|
||||
|
||||
if len(result) != len(expectedNames) {
|
||||
t.Fatalf("Expected collections %v, got %v", expectedNames, result)
|
||||
}
|
||||
|
||||
for i, col := range result {
|
||||
if !list.ExistInSlice(col.Name, expectedNames) {
|
||||
t.Errorf("(%d) Couldn't find %s in %v", i, col.Name, expectedNames)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCollectionReferences(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := app.Dao().FindCollectionByNameOrId("demo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result, err := app.Dao().FindCollectionReferences(collection, collection.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 collection, got %d: %v", len(result), result)
|
||||
}
|
||||
|
||||
expectedFields := []string{"onerel", "manyrels", "rel_cascade"}
|
||||
|
||||
for col, fields := range result {
|
||||
if col.Name != "demo2" {
|
||||
t.Fatalf("Expected collection demo2, got %s", col.Name)
|
||||
}
|
||||
if len(fields) != len(expectedFields) {
|
||||
t.Fatalf("Expected fields %v, got %v", expectedFields, fields)
|
||||
}
|
||||
for i, f := range fields {
|
||||
if !list.ExistInSlice(f.Name, expectedFields) {
|
||||
t.Fatalf("(%d) Didn't expect field %v", i, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteCollection(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
c0 := &models.Collection{}
|
||||
c1, err := app.Dao().FindCollectionByNameOrId("demo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c2, err := app.Dao().FindCollectionByNameOrId("demo2")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c3, err := app.Dao().FindCollectionByNameOrId(models.ProfileCollectionName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
scenarios := []struct {
|
||||
model *models.Collection
|
||||
expectError bool
|
||||
}{
|
||||
{c0, true},
|
||||
{c1, true}, // is part of a reference
|
||||
{c2, false},
|
||||
{c3, true}, // system
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
err := app.Dao().DeleteCollection(scenario.model)
|
||||
hasErr := err != nil
|
||||
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v", i, scenario.expectError, hasErr)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSaveCollectionCreate(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection := &models.Collection{
|
||||
Name: "new_test",
|
||||
Schema: schema.NewSchema(
|
||||
&schema.SchemaField{
|
||||
Type: schema.FieldTypeText,
|
||||
Name: "test",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
err := app.Dao().SaveCollection(collection)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if collection.Id == "" {
|
||||
t.Fatal("Expected collection id to be set")
|
||||
}
|
||||
|
||||
// check if the records table was created
|
||||
hasTable := app.Dao().HasTable(collection.Name)
|
||||
if !hasTable {
|
||||
t.Fatalf("Expected records table %s to be created", collection.Name)
|
||||
}
|
||||
|
||||
// check if the records table has the schema fields
|
||||
columns, err := app.Dao().GetTableColumns(collection.Name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectedColumns := []string{"id", "created", "updated", "test"}
|
||||
if len(columns) != len(expectedColumns) {
|
||||
t.Fatalf("Expected columns %v, got %v", expectedColumns, columns)
|
||||
}
|
||||
for i, c := range columns {
|
||||
if !list.ExistInSlice(c, expectedColumns) {
|
||||
t.Fatalf("(%d) Didn't expect record column %s", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveCollectionUpdate(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, err := app.Dao().FindCollectionByNameOrId("demo3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// rename an existing schema field and add a new one
|
||||
oldField := collection.Schema.GetFieldByName("title")
|
||||
oldField.Name = "title_update"
|
||||
collection.Schema.AddField(&schema.SchemaField{
|
||||
Type: schema.FieldTypeText,
|
||||
Name: "test",
|
||||
})
|
||||
|
||||
saveErr := app.Dao().SaveCollection(collection)
|
||||
if saveErr != nil {
|
||||
t.Fatal(saveErr)
|
||||
}
|
||||
|
||||
// check if the records table has the schema fields
|
||||
expectedColumns := []string{"id", "created", "updated", "title_update", "test"}
|
||||
columns, err := app.Dao().GetTableColumns(collection.Name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(columns) != len(expectedColumns) {
|
||||
t.Fatalf("Expected columns %v, got %v", expectedColumns, columns)
|
||||
}
|
||||
for i, c := range columns {
|
||||
if !list.ExistInSlice(c, expectedColumns) {
|
||||
t.Fatalf("(%d) Didn't expect record column %s", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
// ParamQuery returns a new Param select query.
|
||||
func (dao *Dao) ParamQuery() *dbx.SelectQuery {
|
||||
return dao.ModelQuery(&models.Param{})
|
||||
}
|
||||
|
||||
// FindParamByKey finds the first Param model with the provided key.
|
||||
func (dao *Dao) FindParamByKey(key string) (*models.Param, error) {
|
||||
param := &models.Param{}
|
||||
|
||||
err := dao.ParamQuery().
|
||||
AndWhere(dbx.HashExp{"key": key}).
|
||||
Limit(1).
|
||||
One(param)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return param, nil
|
||||
}
|
||||
|
||||
// SaveParam creates or updates a Param model by the provided key-value pair.
|
||||
// The value argument will be encoded as json string.
|
||||
//
|
||||
// If `optEncryptionKey` is provided it will encrypt the value before storing it.
|
||||
func (dao *Dao) SaveParam(key string, value any, optEncryptionKey ...string) error {
|
||||
param, _ := dao.FindParamByKey(key)
|
||||
if param == nil {
|
||||
param = &models.Param{Key: key}
|
||||
}
|
||||
|
||||
var normalizedValue any
|
||||
|
||||
// encrypt if optEncryptionKey is set
|
||||
if len(optEncryptionKey) > 0 && optEncryptionKey[0] != "" {
|
||||
encoded, encodingErr := json.Marshal(value)
|
||||
if encodingErr != nil {
|
||||
return encodingErr
|
||||
}
|
||||
|
||||
encryptVal, encryptErr := security.Encrypt(encoded, optEncryptionKey[0])
|
||||
if encryptErr != nil {
|
||||
return encryptErr
|
||||
}
|
||||
|
||||
normalizedValue = encryptVal
|
||||
} else {
|
||||
normalizedValue = value
|
||||
}
|
||||
|
||||
encodedValue := types.JsonRaw{}
|
||||
if err := encodedValue.Scan(normalizedValue); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
param.Value = encodedValue
|
||||
|
||||
return dao.Save(param)
|
||||
}
|
||||
|
||||
// DeleteParam deletes the provided Param model.
|
||||
func (dao *Dao) DeleteParam(param *models.Param) error {
|
||||
return dao.Delete(param)
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestParamQuery(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
expected := "SELECT {{_params}}.* FROM `_params`"
|
||||
|
||||
sql := app.Dao().ParamQuery().Build().SQL()
|
||||
if sql != expected {
|
||||
t.Errorf("Expected sql %s, got %s", expected, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindParamByKey(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"missing", true},
|
||||
{models.ParamAppSettings, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
param, err := app.Dao().FindParamByKey(scenario.key)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if param != nil && param.Key != scenario.key {
|
||||
t.Errorf("(%d) Expected param with identifier %s, got %v", i, scenario.key, param.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveParam(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
value any
|
||||
}{
|
||||
{"", "demo"},
|
||||
{"test", nil},
|
||||
{"test", ""},
|
||||
{"test", 1},
|
||||
{"test", 123},
|
||||
{models.ParamAppSettings, map[string]any{"test": 123}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
err := app.Dao().SaveParam(scenario.key, scenario.value)
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
}
|
||||
|
||||
jsonRaw := types.JsonRaw{}
|
||||
jsonRaw.Scan(scenario.value)
|
||||
encodedScenarioValue, err := jsonRaw.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) Encoded error %v", i, err)
|
||||
}
|
||||
|
||||
// check if the param was really saved
|
||||
param, _ := app.Dao().FindParamByKey(scenario.key)
|
||||
encodedParamValue, err := param.Value.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Errorf("(%d) Encoded error %v", i, err)
|
||||
}
|
||||
|
||||
if string(encodedParamValue) != string(encodedScenarioValue) {
|
||||
t.Errorf("(%d) Expected the two values to be equal, got %v vs %v", i, string(encodedParamValue), string(encodedScenarioValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveParamEncrypted(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
encryptionKey := security.RandomString(32)
|
||||
data := map[string]int{"test": 123}
|
||||
expected := map[string]int{}
|
||||
|
||||
err := app.Dao().SaveParam("test", data, encryptionKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check if the param was really saved
|
||||
param, _ := app.Dao().FindParamByKey("test")
|
||||
|
||||
// decrypt
|
||||
decrypted, decryptErr := security.Decrypt(string(param.Value), encryptionKey)
|
||||
if decryptErr != nil {
|
||||
t.Fatal(decryptErr)
|
||||
}
|
||||
|
||||
// decode
|
||||
decryptedDecodeErr := json.Unmarshal(decrypted, &expected)
|
||||
if decryptedDecodeErr != nil {
|
||||
t.Fatal(decryptedDecodeErr)
|
||||
}
|
||||
|
||||
// check if the decoded value is correct
|
||||
if len(expected) != len(data) || expected["test"] != data["test"] {
|
||||
t.Fatalf("Expected %v, got %v", expected, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteParam(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// unsaved param
|
||||
err1 := app.Dao().DeleteParam(&models.Param{})
|
||||
if err1 == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
// existing param
|
||||
param, _ := app.Dao().FindParamByKey(models.ParamAppSettings)
|
||||
err2 := app.Dao().DeleteParam(param)
|
||||
if err2 != nil {
|
||||
t.Fatalf("Expected nil, got error %v", err2)
|
||||
}
|
||||
|
||||
// check if it was really deleted
|
||||
paramCheck, _ := app.Dao().FindParamByKey(models.ParamAppSettings)
|
||||
if paramCheck != nil {
|
||||
t.Fatalf("Expected param to be deleted, got %v", paramCheck)
|
||||
}
|
||||
}
|
||||
+351
@@ -0,0 +1,351 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
// RecordQuery returns a new Record select query.
|
||||
func (dao *Dao) RecordQuery(collection *models.Collection) *dbx.SelectQuery {
|
||||
tableName := collection.Name
|
||||
selectCols := fmt.Sprintf("%s.*", dao.DB().QuoteSimpleColumnName(tableName))
|
||||
|
||||
return dao.DB().Select(selectCols).From(tableName)
|
||||
}
|
||||
|
||||
// FindRecordById finds the Record model by its id.
|
||||
func (dao *Dao) FindRecordById(
|
||||
collection *models.Collection,
|
||||
recordId string,
|
||||
filter func(q *dbx.SelectQuery) error,
|
||||
) (*models.Record, error) {
|
||||
tableName := collection.Name
|
||||
|
||||
query := dao.RecordQuery(collection).
|
||||
AndWhere(dbx.HashExp{tableName + ".id": recordId})
|
||||
|
||||
if filter != nil {
|
||||
if err := filter(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
row := dbx.NullStringMap{}
|
||||
if err := query.Limit(1).One(row); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return models.NewRecordFromNullStringMap(collection, row), nil
|
||||
}
|
||||
|
||||
// FindRecordsByIds finds all Record models by the provided ids.
|
||||
// If no records are found, returns an empty slice.
|
||||
func (dao *Dao) FindRecordsByIds(
|
||||
collection *models.Collection,
|
||||
recordIds []string,
|
||||
filter func(q *dbx.SelectQuery) error,
|
||||
) ([]*models.Record, error) {
|
||||
tableName := collection.Name
|
||||
|
||||
query := dao.RecordQuery(collection).
|
||||
AndWhere(dbx.In(tableName+".id", list.ToInterfaceSlice(recordIds)...))
|
||||
|
||||
if filter != nil {
|
||||
if err := filter(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rows := []dbx.NullStringMap{}
|
||||
if err := query.All(&rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return models.NewRecordsFromNullStringMaps(collection, rows), nil
|
||||
}
|
||||
|
||||
// FindRecordsByExpr finds all records by the provided db expression.
|
||||
// If no records are found, returns an empty slice.
|
||||
//
|
||||
// Example:
|
||||
// expr := dbx.HashExp{"email": "test@example.com"}
|
||||
// dao.FindRecordsByExpr(collection, expr)
|
||||
func (dao *Dao) FindRecordsByExpr(collection *models.Collection, expr dbx.Expression) ([]*models.Record, error) {
|
||||
if expr == nil {
|
||||
return nil, errors.New("Missing filter expression")
|
||||
}
|
||||
|
||||
rows := []dbx.NullStringMap{}
|
||||
|
||||
err := dao.RecordQuery(collection).
|
||||
AndWhere(expr).
|
||||
All(&rows)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return models.NewRecordsFromNullStringMaps(collection, rows), nil
|
||||
}
|
||||
|
||||
// FindFirstRecordByData returns the first found record matching
|
||||
// the provided key-value pair.
|
||||
func (dao *Dao) FindFirstRecordByData(collection *models.Collection, key string, value any) (*models.Record, error) {
|
||||
row := dbx.NullStringMap{}
|
||||
|
||||
err := dao.RecordQuery(collection).
|
||||
AndWhere(dbx.HashExp{key: value}).
|
||||
Limit(1).
|
||||
One(row)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return models.NewRecordFromNullStringMap(collection, row), nil
|
||||
}
|
||||
|
||||
// IsRecordValueUnique checks if the provided key-value pair is a unique Record value.
|
||||
//
|
||||
// NB! Array values (eg. from multiple select fields) are matched
|
||||
// as a serialized json strings (eg. `["a","b"]`), so the value uniqueness
|
||||
// depends on the elements order. Or in other words the following values
|
||||
// are considered different: `[]string{"a","b"}` and `[]string{"b","a"}`
|
||||
func (dao *Dao) IsRecordValueUnique(
|
||||
collection *models.Collection,
|
||||
key string,
|
||||
value any,
|
||||
excludeId string,
|
||||
) bool {
|
||||
var exists bool
|
||||
|
||||
var normalizedVal any
|
||||
switch val := value.(type) {
|
||||
case []string:
|
||||
normalizedVal = append(types.JsonArray{}, list.ToInterfaceSlice(val)...)
|
||||
case []any:
|
||||
normalizedVal = append(types.JsonArray{}, val...)
|
||||
default:
|
||||
normalizedVal = val
|
||||
}
|
||||
|
||||
err := dao.RecordQuery(collection).
|
||||
Select("count(*)").
|
||||
AndWhere(dbx.Not(dbx.HashExp{"id": excludeId})).
|
||||
AndWhere(dbx.HashExp{key: normalizedVal}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && !exists
|
||||
}
|
||||
|
||||
// FindUserRelatedRecords returns all records that has a reference
|
||||
// to the provided User model (via the user shema field).
|
||||
func (dao *Dao) FindUserRelatedRecords(user *models.User) ([]*models.Record, error) {
|
||||
collections, err := dao.FindCollectionsWithUserFields()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := []*models.Record{}
|
||||
for _, collection := range collections {
|
||||
userFields := []*schema.SchemaField{}
|
||||
|
||||
// prepare fields options
|
||||
if err := collection.Schema.InitFieldsOptions(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// extract user fields
|
||||
for _, field := range collection.Schema.Fields() {
|
||||
if field.Type == schema.FieldTypeUser {
|
||||
userFields = append(userFields, field)
|
||||
}
|
||||
}
|
||||
|
||||
// fetch records associated to the user
|
||||
exprs := []dbx.Expression{}
|
||||
for _, field := range userFields {
|
||||
exprs = append(exprs, dbx.HashExp{field.Name: user.Id})
|
||||
}
|
||||
rows := []dbx.NullStringMap{}
|
||||
if err := dao.RecordQuery(collection).AndWhere(dbx.Or(exprs...)).All(&rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records := models.NewRecordsFromNullStringMaps(collection, rows)
|
||||
|
||||
result = append(result, records...)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SaveRecord upserts the provided Record model.
|
||||
func (dao *Dao) SaveRecord(record *models.Record) error {
|
||||
return dao.Save(record)
|
||||
}
|
||||
|
||||
// DeleteRecord deletes the provided Record model.
|
||||
//
|
||||
// This method will also cascade the delete operation to all linked
|
||||
// relational records (delete or set to NULL, depending on the rel settings).
|
||||
//
|
||||
// The delete operation may fail if the record is part of a required
|
||||
// reference in another record (aka. cannot be deleted or set to NULL).
|
||||
func (dao *Dao) DeleteRecord(record *models.Record) error {
|
||||
// check for references
|
||||
// note: the select is outside of the transaction to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
|
||||
refs, err := dao.FindCollectionReferences(record.Collection(), "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check if related records has to be deleted (if `CascadeDelete` is set)
|
||||
// OR
|
||||
// just unset the record id from any relation field values (if they are not required)
|
||||
// -----------------------------------------------------------
|
||||
return dao.RunInTransaction(func(txDao *Dao) error {
|
||||
for refCollection, fields := range refs {
|
||||
for _, field := range fields {
|
||||
options, _ := field.Options.(*schema.RelationOptions)
|
||||
|
||||
rows := []dbx.NullStringMap{}
|
||||
|
||||
// note: the select is not using the transaction dao to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
|
||||
err := dao.RecordQuery(refCollection).
|
||||
AndWhere(dbx.Not(dbx.HashExp{"id": record.Id})).
|
||||
AndWhere(dbx.Like(field.Name, record.Id).Match(true, true)).
|
||||
All(&rows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
refRecords := models.NewRecordsFromNullStringMaps(refCollection, rows)
|
||||
for _, refRecord := range refRecords {
|
||||
ids := refRecord.GetStringSliceDataValue(field.Name)
|
||||
|
||||
// unset the record id
|
||||
for i := len(ids) - 1; i >= 0; i-- {
|
||||
if ids[i] == record.Id {
|
||||
ids = append(ids[:i], ids[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// cascade delete the reference
|
||||
// (only if there are no other active references in case of multiple select)
|
||||
if options.CascadeDelete && len(ids) == 0 {
|
||||
if err := txDao.DeleteRecord(refRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
// no further action are needed (the reference is deleted)
|
||||
continue
|
||||
}
|
||||
|
||||
if field.Required && len(ids) == 0 {
|
||||
return fmt.Errorf("The record cannot be deleted because it is part of a required reference in record %s (%s collection).", refRecord.Id, refCollection.Name)
|
||||
}
|
||||
|
||||
// save the reference changes
|
||||
refRecord.SetDataValue(field.Name, field.PrepareValue(ids))
|
||||
if err := txDao.SaveRecord(refRecord); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return txDao.Delete(record)
|
||||
})
|
||||
}
|
||||
|
||||
// SyncRecordTableSchema compares the two provided collections
|
||||
// and applies the necessary related record table changes.
|
||||
//
|
||||
// If `oldCollection` is null, then only `newCollection` is used to create the record table.
|
||||
func (dao *Dao) SyncRecordTableSchema(newCollection *models.Collection, oldCollection *models.Collection) error {
|
||||
// create
|
||||
if oldCollection == nil {
|
||||
cols := map[string]string{
|
||||
schema.ReservedFieldNameId: "TEXT PRIMARY KEY",
|
||||
schema.ReservedFieldNameCreated: `TEXT DEFAULT "" NOT NULL`,
|
||||
schema.ReservedFieldNameUpdated: `TEXT DEFAULT "" NOT NULL`,
|
||||
}
|
||||
|
||||
tableName := newCollection.Name
|
||||
|
||||
// add schema field definitions
|
||||
for _, field := range newCollection.Schema.Fields() {
|
||||
cols[field.Name] = field.ColDefinition()
|
||||
}
|
||||
|
||||
// create table
|
||||
_, tableErr := dao.DB().CreateTable(tableName, cols).Execute()
|
||||
if tableErr != nil {
|
||||
return tableErr
|
||||
}
|
||||
|
||||
// add index on the base `created` column
|
||||
_, indexErr := dao.DB().CreateIndex(tableName, tableName+"_created_idx", "created").Execute()
|
||||
if indexErr != nil {
|
||||
return indexErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// update
|
||||
return dao.RunInTransaction(func(txDao *Dao) error {
|
||||
oldTableName := oldCollection.Name
|
||||
newTableName := newCollection.Name
|
||||
oldSchema := oldCollection.Schema
|
||||
newSchema := newCollection.Schema
|
||||
|
||||
// check for renamed table
|
||||
if strings.ToLower(oldTableName) != strings.ToLower(newTableName) {
|
||||
_, err := dao.DB().RenameTable(oldTableName, newTableName).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for deleted columns
|
||||
for _, oldField := range oldSchema.Fields() {
|
||||
if f := newSchema.GetFieldById(oldField.Id); f != nil {
|
||||
continue // exist
|
||||
}
|
||||
|
||||
_, err := txDao.DB().DropColumn(newTableName, oldField.Name).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// check for new or renamed columns
|
||||
for _, field := range newSchema.Fields() {
|
||||
oldField := oldSchema.GetFieldById(field.Id)
|
||||
if oldField != nil {
|
||||
// rename
|
||||
_, err := txDao.DB().RenameColumn(newTableName, oldField.Name, field.Name).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// add
|
||||
_, err := txDao.DB().AddColumn(newTableName, field.Name, field.ColDefinition()).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
// MaxExpandDepth specifies the max allowed nested expand depth path.
|
||||
const MaxExpandDepth = 6
|
||||
|
||||
// ExpandFetchFunc defines the function that is used to fetch the expanded relation records.
|
||||
type ExpandFetchFunc func(relCollection *models.Collection, relIds []string) ([]*models.Record, error)
|
||||
|
||||
// ExpandRecord expands the relations of a single Record model.
|
||||
func (dao *Dao) ExpandRecord(record *models.Record, expands []string, fetchFunc ExpandFetchFunc) error {
|
||||
return dao.ExpandRecords([]*models.Record{record}, expands, fetchFunc)
|
||||
}
|
||||
|
||||
// ExpandRecords expands the relations of the provided Record models list.
|
||||
func (dao *Dao) ExpandRecords(records []*models.Record, expands []string, fetchFunc ExpandFetchFunc) error {
|
||||
normalized := normalizeExpands(expands)
|
||||
|
||||
for _, expand := range normalized {
|
||||
if err := dao.expandRecords(records, expand, fetchFunc, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// notes:
|
||||
// - fetchFunc must be non-nil func
|
||||
// - all records are expected to be from the same collection
|
||||
// - if MaxExpandDepth is reached, the function returns nil ignoring the remaining expand path
|
||||
func (dao *Dao) expandRecords(records []*models.Record, expandPath string, fetchFunc ExpandFetchFunc, recursionLevel int) error {
|
||||
if fetchFunc == nil {
|
||||
return errors.New("Relation records fetchFunc is not set.")
|
||||
}
|
||||
|
||||
if expandPath == "" || recursionLevel > MaxExpandDepth || len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.SplitN(expandPath, ".", 2)
|
||||
|
||||
// extract the relation field (if exist)
|
||||
mainCollection := records[0].Collection()
|
||||
relField := mainCollection.Schema.GetFieldByName(parts[0])
|
||||
if relField == nil {
|
||||
return fmt.Errorf("Couldn't find field %q in collection %q.", parts[0], mainCollection.Name)
|
||||
}
|
||||
relField.InitOptions()
|
||||
relFieldOptions, _ := relField.Options.(*schema.RelationOptions)
|
||||
|
||||
relCollection, err := dao.FindCollectionByNameOrId(relFieldOptions.CollectionId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't find collection %q.", relFieldOptions.CollectionId)
|
||||
}
|
||||
|
||||
// extract the id of the relations to expand
|
||||
relIds := []string{}
|
||||
for _, record := range records {
|
||||
relIds = append(relIds, record.GetStringSliceDataValue(relField.Name)...)
|
||||
}
|
||||
|
||||
// fetch rels
|
||||
rels, relsErr := fetchFunc(relCollection, relIds)
|
||||
if relsErr != nil {
|
||||
return relsErr
|
||||
}
|
||||
|
||||
// expand nested fields
|
||||
if len(parts) > 1 {
|
||||
err := dao.expandRecords(rels, parts[1], fetchFunc, recursionLevel+1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// reindex with the rel id
|
||||
indexedRels := map[string]*models.Record{}
|
||||
for _, rel := range rels {
|
||||
indexedRels[rel.GetId()] = rel
|
||||
}
|
||||
|
||||
for _, model := range records {
|
||||
relIds := model.GetStringSliceDataValue(relField.Name)
|
||||
|
||||
validRels := []*models.Record{}
|
||||
for _, id := range relIds {
|
||||
if rel, ok := indexedRels[id]; ok {
|
||||
validRels = append(validRels, rel)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validRels) == 0 {
|
||||
continue // no valid relations
|
||||
}
|
||||
|
||||
expandData := model.GetExpand()
|
||||
|
||||
// normalize and set the expanded relations
|
||||
if relFieldOptions.MaxSelect == 1 {
|
||||
expandData[relField.Name] = validRels[0]
|
||||
} else {
|
||||
expandData[relField.Name] = validRels
|
||||
}
|
||||
model.SetExpand(expandData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeExpands normalizes expand strings and merges self containing paths
|
||||
// (eg. ["a.b.c", "a.b", " test ", " ", "test"] -> ["a.b.c", "test"]).
|
||||
func normalizeExpands(paths []string) []string {
|
||||
result := []string{}
|
||||
|
||||
// normalize paths
|
||||
normalized := []string{}
|
||||
for _, p := range paths {
|
||||
p := strings.ReplaceAll(p, " ", "") // replace spaces
|
||||
p = strings.Trim(p, ".") // trim incomplete paths
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
normalized = append(normalized, p)
|
||||
}
|
||||
|
||||
// merge containing paths
|
||||
for i, p1 := range normalized {
|
||||
var skip bool
|
||||
for j, p2 := range normalized {
|
||||
if i == j {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(p2, p1+".") {
|
||||
// skip because there is more detailed expand path
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !skip {
|
||||
result = append(result, p1)
|
||||
}
|
||||
}
|
||||
|
||||
return list.ToUniqueStringSlice(result)
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/daos"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func TestExpandRecords(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
col, _ := app.Dao().FindCollectionByNameOrId("demo4")
|
||||
|
||||
scenarios := []struct {
|
||||
recordIds []string
|
||||
expands []string
|
||||
fetchFunc daos.ExpandFetchFunc
|
||||
expectExpandProps int
|
||||
expectError bool
|
||||
}{
|
||||
// empty records
|
||||
{
|
||||
[]string{},
|
||||
[]string{"onerel", "manyrels.onerel.manyrels"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
0,
|
||||
false,
|
||||
},
|
||||
// empty expand
|
||||
{
|
||||
[]string{"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b", "df55c8ff-45ef-4c82-8aed-6e2183fe1125"},
|
||||
[]string{},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
0,
|
||||
false,
|
||||
},
|
||||
// empty fetchFunc
|
||||
{
|
||||
[]string{"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b", "df55c8ff-45ef-4c82-8aed-6e2183fe1125"},
|
||||
[]string{"onerel", "manyrels.onerel.manyrels"},
|
||||
nil,
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// fetchFunc with error
|
||||
{
|
||||
[]string{"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b", "df55c8ff-45ef-4c82-8aed-6e2183fe1125"},
|
||||
[]string{"onerel", "manyrels.onerel.manyrels"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return nil, errors.New("test error")
|
||||
},
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// invalid missing first level expand
|
||||
{
|
||||
[]string{"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b", "df55c8ff-45ef-4c82-8aed-6e2183fe1125"},
|
||||
[]string{"invalid"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// invalid missing second level expand
|
||||
{
|
||||
[]string{"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b", "df55c8ff-45ef-4c82-8aed-6e2183fe1125"},
|
||||
[]string{"manyrels.invalid"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// expand normalizations
|
||||
{
|
||||
[]string{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
"df55c8ff-45ef-4c82-8aed-6e2183fe1125",
|
||||
"b84cd893-7119-43c9-8505-3c4e22da28a9",
|
||||
"054f9f24-0a0a-4e09-87b1-bc7ff2b336a2",
|
||||
},
|
||||
[]string{"manyrels.onerel.manyrels.onerel", "manyrels.onerel", "onerel", "onerel.", " onerel ", ""},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
9,
|
||||
false,
|
||||
},
|
||||
// single expand
|
||||
{
|
||||
[]string{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
"df55c8ff-45ef-4c82-8aed-6e2183fe1125",
|
||||
"b84cd893-7119-43c9-8505-3c4e22da28a9", // no manyrels
|
||||
"054f9f24-0a0a-4e09-87b1-bc7ff2b336a2", // no manyrels
|
||||
},
|
||||
[]string{"manyrels"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
2,
|
||||
false,
|
||||
},
|
||||
// maxExpandDepth reached
|
||||
{
|
||||
[]string{"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b"},
|
||||
[]string{"manyrels.onerel.manyrels.onerel.manyrels.onerel.manyrels.onerel.manyrels"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
6,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
ids := list.ToUniqueStringSlice(s.recordIds)
|
||||
records, _ := app.Dao().FindRecordsByIds(col, ids, nil)
|
||||
err := app.Dao().ExpandRecords(records, s.expands, s.fetchFunc)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(records)
|
||||
encodedStr := string(encoded)
|
||||
totalExpandProps := strings.Count(encodedStr, "@expand")
|
||||
|
||||
if s.expectExpandProps != totalExpandProps {
|
||||
t.Errorf("(%d) Expected %d @expand props in %v, got %d", i, s.expectExpandProps, encodedStr, totalExpandProps)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandRecord(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
col, _ := app.Dao().FindCollectionByNameOrId("demo4")
|
||||
|
||||
scenarios := []struct {
|
||||
recordId string
|
||||
expands []string
|
||||
fetchFunc daos.ExpandFetchFunc
|
||||
expectExpandProps int
|
||||
expectError bool
|
||||
}{
|
||||
// empty expand
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
0,
|
||||
false,
|
||||
},
|
||||
// empty fetchFunc
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{"onerel", "manyrels.onerel.manyrels"},
|
||||
nil,
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// fetchFunc with error
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{"onerel", "manyrels.onerel.manyrels"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return nil, errors.New("test error")
|
||||
},
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// invalid missing first level expand
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{"invalid"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// invalid missing second level expand
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{"manyrels.invalid"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
0,
|
||||
true,
|
||||
},
|
||||
// expand normalizations
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{"manyrels.onerel.manyrels", "manyrels.onerel", "onerel", " onerel "},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
3,
|
||||
false,
|
||||
},
|
||||
// single expand
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{"manyrels"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
1,
|
||||
false,
|
||||
},
|
||||
// maxExpandDepth reached
|
||||
{
|
||||
"b8ba58f9-e2d7-42a0-b0e7-a11efd98236b",
|
||||
[]string{"manyrels.onerel.manyrels.onerel.manyrels.onerel.manyrels.onerel.manyrels"},
|
||||
func(c *models.Collection, ids []string) ([]*models.Record, error) {
|
||||
return app.Dao().FindRecordsByIds(c, ids, nil)
|
||||
},
|
||||
6,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
record, _ := app.Dao().FindFirstRecordByData(col, "id", s.recordId)
|
||||
err := app.Dao().ExpandRecord(record, s.expands, s.fetchFunc)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != s.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, s.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(record)
|
||||
encodedStr := string(encoded)
|
||||
totalExpandProps := strings.Count(encodedStr, "@expand")
|
||||
|
||||
if s.expectExpandProps != totalExpandProps {
|
||||
t.Errorf("(%d) Expected %d @expand props in %v, got %d", i, s.expectExpandProps, encodedStr, totalExpandProps)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,473 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func TestRecordQuery(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, _ := app.Dao().FindCollectionByNameOrId("demo")
|
||||
|
||||
expected := fmt.Sprintf("SELECT `%s`.* FROM `%s`", collection.Name, collection.Name)
|
||||
|
||||
sql := app.Dao().RecordQuery(collection).Build().SQL()
|
||||
if sql != expected {
|
||||
t.Errorf("Expected sql %s, got %s", expected, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRecordById(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, _ := app.Dao().FindCollectionByNameOrId("demo")
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
filter func(q *dbx.SelectQuery) error
|
||||
expectError bool
|
||||
}{
|
||||
{"00000000-bafd-48f7-b8b7-090638afe209", nil, true},
|
||||
{"b5c2ffc2-bafd-48f7-b8b7-090638afe209", nil, false},
|
||||
{"b5c2ffc2-bafd-48f7-b8b7-090638afe209", func(q *dbx.SelectQuery) error {
|
||||
q.AndWhere(dbx.HashExp{"title": "missing"})
|
||||
return nil
|
||||
}, true},
|
||||
{"b5c2ffc2-bafd-48f7-b8b7-090638afe209", func(q *dbx.SelectQuery) error {
|
||||
return errors.New("test error")
|
||||
}, true},
|
||||
{"b5c2ffc2-bafd-48f7-b8b7-090638afe209", func(q *dbx.SelectQuery) error {
|
||||
q.AndWhere(dbx.HashExp{"title": "lorem"})
|
||||
return nil
|
||||
}, false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
record, err := app.Dao().FindRecordById(collection, scenario.id, scenario.filter)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if record != nil && record.Id != scenario.id {
|
||||
t.Errorf("(%d) Expected record with id %s, got %s", i, scenario.id, record.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRecordsByIds(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, _ := app.Dao().FindCollectionByNameOrId("demo")
|
||||
|
||||
scenarios := []struct {
|
||||
ids []string
|
||||
filter func(q *dbx.SelectQuery) error
|
||||
expectTotal int
|
||||
expectError bool
|
||||
}{
|
||||
{[]string{}, nil, 0, false},
|
||||
{[]string{"00000000-bafd-48f7-b8b7-090638afe209"}, nil, 0, false},
|
||||
{[]string{"b5c2ffc2-bafd-48f7-b8b7-090638afe209"}, nil, 1, false},
|
||||
{
|
||||
[]string{"b5c2ffc2-bafd-48f7-b8b7-090638afe209", "848a1dea-5ddd-42d6-a00d-030547bffcfe"},
|
||||
nil,
|
||||
2,
|
||||
false,
|
||||
},
|
||||
{
|
||||
[]string{"b5c2ffc2-bafd-48f7-b8b7-090638afe209", "848a1dea-5ddd-42d6-a00d-030547bffcfe"},
|
||||
func(q *dbx.SelectQuery) error {
|
||||
return errors.New("test error")
|
||||
},
|
||||
0,
|
||||
true,
|
||||
},
|
||||
{
|
||||
[]string{"b5c2ffc2-bafd-48f7-b8b7-090638afe209", "848a1dea-5ddd-42d6-a00d-030547bffcfe"},
|
||||
func(q *dbx.SelectQuery) error {
|
||||
q.AndWhere(dbx.Like("title", "test").Match(true, true))
|
||||
return nil
|
||||
},
|
||||
1,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
records, err := app.Dao().FindRecordsByIds(collection, scenario.ids, scenario.filter)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if len(records) != scenario.expectTotal {
|
||||
t.Errorf("(%d) Expected %d records, got %d", i, scenario.expectTotal, len(records))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, r := range records {
|
||||
if !list.ExistInSlice(r.Id, scenario.ids) {
|
||||
t.Errorf("(%d) Couldn't find id %s in %v", i, r.Id, scenario.ids)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRecordsByExpr(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, _ := app.Dao().FindCollectionByNameOrId("demo")
|
||||
|
||||
scenarios := []struct {
|
||||
expression dbx.Expression
|
||||
expectIds []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
nil,
|
||||
[]string{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
dbx.HashExp{"id": 123},
|
||||
[]string{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
dbx.Like("title", "test").Match(true, true),
|
||||
[]string{
|
||||
"848a1dea-5ddd-42d6-a00d-030547bffcfe",
|
||||
"577bd676-aacb-4072-b7da-99d00ee210a4",
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
records, err := app.Dao().FindRecordsByExpr(collection, scenario.expression)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if len(records) != len(scenario.expectIds) {
|
||||
t.Errorf("(%d) Expected %d records, got %d", i, len(scenario.expectIds), len(records))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, r := range records {
|
||||
if !list.ExistInSlice(r.Id, scenario.expectIds) {
|
||||
t.Errorf("(%d) Couldn't find id %s in %v", i, r.Id, scenario.expectIds)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindFirstRecordByData(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, _ := app.Dao().FindCollectionByNameOrId("demo")
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
value any
|
||||
expectId string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
"",
|
||||
"848a1dea-5ddd-42d6-a00d-030547bffcfe",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"id",
|
||||
"invalid",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"id",
|
||||
"848a1dea-5ddd-42d6-a00d-030547bffcfe",
|
||||
"848a1dea-5ddd-42d6-a00d-030547bffcfe",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"title",
|
||||
"lorem",
|
||||
"b5c2ffc2-bafd-48f7-b8b7-090638afe209",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
record, err := app.Dao().FindFirstRecordByData(collection, scenario.key, scenario.value)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !scenario.expectError && record.Id != scenario.expectId {
|
||||
t.Errorf("(%d) Expected record with id %s, got %v", i, scenario.expectId, record.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRecordValueUnique(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, _ := app.Dao().FindCollectionByNameOrId("demo4")
|
||||
|
||||
testManyRelsId1 := "df55c8ff-45ef-4c82-8aed-6e2183fe1125"
|
||||
testManyRelsId2 := "b84cd893-7119-43c9-8505-3c4e22da28a9"
|
||||
|
||||
scenarios := []struct {
|
||||
key string
|
||||
value any
|
||||
excludeId string
|
||||
expected bool
|
||||
}{
|
||||
{"", "", "", false},
|
||||
{"missing", "unique", "", false},
|
||||
{"title", "unique", "", true},
|
||||
{"title", "demo1", "", false},
|
||||
{"title", "demo1", "054f9f24-0a0a-4e09-87b1-bc7ff2b336a2", true},
|
||||
{"manyrels", []string{testManyRelsId2}, "", false},
|
||||
{"manyrels", []any{testManyRelsId2}, "", false},
|
||||
// with exclude
|
||||
{"manyrels", []string{testManyRelsId1, testManyRelsId2}, "b8ba58f9-e2d7-42a0-b0e7-a11efd98236b", true},
|
||||
// reverse order
|
||||
{"manyrels", []string{testManyRelsId2, testManyRelsId1}, "", true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := app.Dao().IsRecordValueUnique(collection, scenario.key, scenario.value, scenario.excludeId)
|
||||
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUserRelatedRecords(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
u0 := &models.User{}
|
||||
u1, _ := app.Dao().FindUserByEmail("test3@example.com")
|
||||
u2, _ := app.Dao().FindUserByEmail("test2@example.com")
|
||||
|
||||
scenarios := []struct {
|
||||
user *models.User
|
||||
expectedIds []string
|
||||
}{
|
||||
{u0, []string{}},
|
||||
{u1, []string{
|
||||
"94568ca2-0bee-49d7-b749-06cb97956fd9", // demo2
|
||||
"fc69274d-ca5c-416a-b9ef-561b101cfbb1", // profile
|
||||
}},
|
||||
{u2, []string{
|
||||
"b2d5e39d-f569-4cc1-b593-3f074ad026bf", // profile
|
||||
}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
records, err := app.Dao().FindUserRelatedRecords(scenario.user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(records) != len(scenario.expectedIds) {
|
||||
t.Errorf("(%d) Expected %d records, got %d (%v)", i, len(scenario.expectedIds), len(records), records)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, r := range records {
|
||||
if !list.ExistInSlice(r.Id, scenario.expectedIds) {
|
||||
t.Errorf("(%d) Couldn't find %s in %v", i, r.Id, scenario.expectedIds)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveRecord(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
collection, _ := app.Dao().FindCollectionByNameOrId("demo")
|
||||
|
||||
// create
|
||||
// ---
|
||||
r1 := models.NewRecord(collection)
|
||||
r1.SetDataValue("title", "test_new")
|
||||
err1 := app.Dao().SaveRecord(r1)
|
||||
if err1 != nil {
|
||||
t.Fatal(err1)
|
||||
}
|
||||
newR1, _ := app.Dao().FindFirstRecordByData(collection, "title", "test_new")
|
||||
if newR1 == nil || newR1.Id != r1.Id || newR1.GetStringDataValue("title") != r1.GetStringDataValue("title") {
|
||||
t.Errorf("Expected to find record %v, got %v", r1, newR1)
|
||||
}
|
||||
|
||||
// update
|
||||
// ---
|
||||
r2, _ := app.Dao().FindFirstRecordByData(collection, "id", "b5c2ffc2-bafd-48f7-b8b7-090638afe209")
|
||||
r2.SetDataValue("title", "test_update")
|
||||
err2 := app.Dao().SaveRecord(r2)
|
||||
if err2 != nil {
|
||||
t.Fatal(err2)
|
||||
}
|
||||
newR2, _ := app.Dao().FindFirstRecordByData(collection, "title", "test_update")
|
||||
if newR2 == nil || newR2.Id != r2.Id || newR2.GetStringDataValue("title") != r2.GetStringDataValue("title") {
|
||||
t.Errorf("Expected to find record %v, got %v", r2, newR2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRecord(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
demo, _ := app.Dao().FindCollectionByNameOrId("demo")
|
||||
demo2, _ := app.Dao().FindCollectionByNameOrId("demo2")
|
||||
|
||||
// delete unsaved record
|
||||
// ---
|
||||
rec1 := models.NewRecord(demo)
|
||||
err1 := app.Dao().DeleteRecord(rec1)
|
||||
if err1 == nil {
|
||||
t.Fatal("(rec1) Didn't expect to succeed deleting new record")
|
||||
}
|
||||
|
||||
// delete existing record while being part of a non-cascade required relation
|
||||
// ---
|
||||
rec2, _ := app.Dao().FindFirstRecordByData(demo, "id", "848a1dea-5ddd-42d6-a00d-030547bffcfe")
|
||||
err2 := app.Dao().DeleteRecord(rec2)
|
||||
if err2 == nil {
|
||||
t.Fatalf("(rec2) Expected error, got nil")
|
||||
}
|
||||
|
||||
// delete existing record
|
||||
// ---
|
||||
rec3, _ := app.Dao().FindFirstRecordByData(demo, "id", "577bd676-aacb-4072-b7da-99d00ee210a4")
|
||||
err3 := app.Dao().DeleteRecord(rec3)
|
||||
if err3 != nil {
|
||||
t.Fatalf("(rec3) Expected nil, got error %v", err3)
|
||||
}
|
||||
|
||||
// check if it was really deleted
|
||||
rec3, _ = app.Dao().FindRecordById(demo, rec3.Id, nil)
|
||||
if rec3 != nil {
|
||||
t.Fatalf("(rec3) Expected record to be deleted, got %v", rec3)
|
||||
}
|
||||
|
||||
// check if the operation cascaded
|
||||
rel, _ := app.Dao().FindFirstRecordByData(demo2, "id", "63c2ab80-84ab-4057-a592-4604a731f78f")
|
||||
if rel != nil {
|
||||
t.Fatalf("(rec3) Expected the delete to cascade, found relation %v", rel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncRecordTableSchema(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
oldCollection, err := app.Dao().FindCollectionByNameOrId("demo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updatedCollection, err := app.Dao().FindCollectionByNameOrId("demo")
|
||||
updatedCollection.Name = "demo_renamed"
|
||||
updatedCollection.Schema.RemoveField(updatedCollection.Schema.GetFieldByName("file").Id)
|
||||
updatedCollection.Schema.AddField(
|
||||
&schema.SchemaField{
|
||||
Name: "new_field",
|
||||
Type: schema.FieldTypeEmail,
|
||||
},
|
||||
)
|
||||
updatedCollection.Schema.AddField(
|
||||
&schema.SchemaField{
|
||||
Id: updatedCollection.Schema.GetFieldByName("title").Id,
|
||||
Name: "title_renamed",
|
||||
Type: schema.FieldTypeEmail,
|
||||
},
|
||||
)
|
||||
|
||||
scenarios := []struct {
|
||||
newCollection *models.Collection
|
||||
oldCollection *models.Collection
|
||||
expectedTableName string
|
||||
expectedColumns []string
|
||||
}{
|
||||
{
|
||||
&models.Collection{
|
||||
Name: "new_table",
|
||||
Schema: schema.NewSchema(
|
||||
&schema.SchemaField{
|
||||
Name: "test",
|
||||
Type: schema.FieldTypeText,
|
||||
},
|
||||
),
|
||||
},
|
||||
nil,
|
||||
"new_table",
|
||||
[]string{"id", "created", "updated", "test"},
|
||||
},
|
||||
// no changes
|
||||
{
|
||||
oldCollection,
|
||||
oldCollection,
|
||||
"demo",
|
||||
[]string{"id", "created", "updated", "title", "file"},
|
||||
},
|
||||
// renamed table, deleted column, renamed columnd and new column
|
||||
{
|
||||
updatedCollection,
|
||||
oldCollection,
|
||||
"demo_renamed",
|
||||
[]string{"id", "created", "updated", "title_renamed", "new_field"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
err := app.Dao().SyncRecordTableSchema(scenario.newCollection, scenario.oldCollection)
|
||||
if err != nil {
|
||||
t.Errorf("(%d) %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !app.Dao().HasTable(scenario.newCollection.Name) {
|
||||
t.Errorf("(%d) Expected table %s to exist", i, scenario.newCollection.Name)
|
||||
}
|
||||
|
||||
cols, _ := app.Dao().GetTableColumns(scenario.newCollection.Name)
|
||||
if len(cols) != len(scenario.expectedColumns) {
|
||||
t.Errorf("(%d) Expected columns %v, got %v", i, scenario.expectedColumns, cols)
|
||||
}
|
||||
|
||||
for _, c := range cols {
|
||||
if !list.ExistInSlice(c, scenario.expectedColumns) {
|
||||
t.Errorf("(%d) Couldn't find column %s in %v", i, c, scenario.expectedColumns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
// RequestQuery returns a new Request logs select query.
|
||||
func (dao *Dao) RequestQuery() *dbx.SelectQuery {
|
||||
return dao.ModelQuery(&models.Request{})
|
||||
}
|
||||
|
||||
// FindRequestById finds a single Request log by its id.
|
||||
func (dao *Dao) FindRequestById(id string) (*models.Request, error) {
|
||||
model := &models.Request{}
|
||||
|
||||
err := dao.RequestQuery().
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
type RequestsStatsItem struct {
|
||||
Total int `db:"total" json:"total"`
|
||||
Date types.DateTime `db:"date" json:"date"`
|
||||
}
|
||||
|
||||
// RequestsStats returns hourly grouped requests logs statistics.
|
||||
func (dao *Dao) RequestsStats(expr dbx.Expression) ([]*RequestsStatsItem, error) {
|
||||
result := []*RequestsStatsItem{}
|
||||
|
||||
query := dao.RequestQuery().
|
||||
Select("count(id) as total", "strftime('%Y-%m-%d %H:00:00', created) as date").
|
||||
GroupBy("date")
|
||||
|
||||
if expr != nil {
|
||||
query.AndWhere(expr)
|
||||
}
|
||||
|
||||
err := query.All(&result)
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// DeleteOldRequests delete all requests that are created before createdBefore.
|
||||
func (dao *Dao) DeleteOldRequests(createdBefore time.Time) error {
|
||||
m := models.Request{}
|
||||
tableName := m.TableName()
|
||||
|
||||
formattedDate := createdBefore.UTC().Format(types.DefaultDateLayout)
|
||||
expr := dbx.NewExp("[[created]] <= {:date}", dbx.Params{"date": formattedDate})
|
||||
|
||||
_, err := dao.DB().Delete(tableName, expr).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SaveRequest upserts the provided Request model.
|
||||
func (dao *Dao) SaveRequest(request *models.Request) error {
|
||||
return dao.Save(request)
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestRequestQuery(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
expected := "SELECT {{_requests}}.* FROM `_requests`"
|
||||
|
||||
sql := app.Dao().RequestQuery().Build().SQL()
|
||||
if sql != expected {
|
||||
t.Errorf("Expected sql %s, got %s", expected, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRequestById(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
tests.MockRequestLogsData(app)
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"invalid", true},
|
||||
{"00000000-9f38-44fb-bf82-c8f53b310d91", true},
|
||||
{"873f2133-9f38-44fb-bf82-c8f53b310d91", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
admin, err := app.LogsDao().FindRequestById(scenario.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if admin != nil && admin.Id != scenario.id {
|
||||
t.Errorf("(%d) Expected admin with id %s, got %s", i, scenario.id, admin.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestsStats(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
tests.MockRequestLogsData(app)
|
||||
|
||||
expected := `[{"total":1,"date":"2022-05-01 10:00:00.000"},{"total":1,"date":"2022-05-02 10:00:00.000"}]`
|
||||
|
||||
now := time.Now().UTC().Format(types.DefaultDateLayout)
|
||||
exp := dbx.NewExp("[[created]] <= {:date}", dbx.Params{"date": now})
|
||||
result, err := app.LogsDao().RequestsStats(exp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(result)
|
||||
if string(encoded) != expected {
|
||||
t.Fatalf("Expected %s, got %s", expected, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteOldRequests(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
tests.MockRequestLogsData(app)
|
||||
|
||||
scenarios := []struct {
|
||||
date string
|
||||
expectedTotal int
|
||||
}{
|
||||
{"2022-01-01 10:00:00.000", 2}, // no requests to delete before that time
|
||||
{"2022-05-01 11:00:00.000", 1}, // only 1 request should have left
|
||||
{"2022-05-03 11:00:00.000", 0}, // no more requests should have left
|
||||
{"2022-05-04 11:00:00.000", 0}, // no more requests should have left
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
date, dateErr := time.Parse(types.DefaultDateLayout, scenario.date)
|
||||
if dateErr != nil {
|
||||
t.Errorf("(%d) Date error %v", i, dateErr)
|
||||
}
|
||||
|
||||
deleteErr := app.LogsDao().DeleteOldRequests(date)
|
||||
if deleteErr != nil {
|
||||
t.Errorf("(%d) Delete error %v", i, deleteErr)
|
||||
}
|
||||
|
||||
// check total remaining requests
|
||||
var total int
|
||||
countErr := app.LogsDao().RequestQuery().Select("count(*)").Row(&total)
|
||||
if countErr != nil {
|
||||
t.Errorf("(%d) Count error %v", i, countErr)
|
||||
}
|
||||
|
||||
if total != scenario.expectedTotal {
|
||||
t.Errorf("(%d) Expected %d remaining requests, got %d", i, scenario.expectedTotal, total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveRequest(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
tests.MockRequestLogsData(app)
|
||||
|
||||
// create new request
|
||||
newRequest := &models.Request{}
|
||||
newRequest.Method = "get"
|
||||
newRequest.Meta = types.JsonMap{}
|
||||
createErr := app.LogsDao().SaveRequest(newRequest)
|
||||
if createErr != nil {
|
||||
t.Fatal(createErr)
|
||||
}
|
||||
|
||||
// check if it was really created
|
||||
existingRequest, fetchErr := app.LogsDao().FindRequestById(newRequest.Id)
|
||||
if fetchErr != nil {
|
||||
t.Fatal(fetchErr)
|
||||
}
|
||||
|
||||
existingRequest.Method = "post"
|
||||
updateErr := app.LogsDao().SaveRequest(existingRequest)
|
||||
if updateErr != nil {
|
||||
t.Fatal(updateErr)
|
||||
}
|
||||
// refresh instance to check if it was really updated
|
||||
existingRequest, _ = app.LogsDao().FindRequestById(existingRequest.Id)
|
||||
if existingRequest.Method != "post" {
|
||||
t.Fatalf("Expected request method to be %s, got %s", "post", existingRequest.Method)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// HasTable checks if a table with the provided name exists (case insensitive).
|
||||
func (dao *Dao) HasTable(tableName string) bool {
|
||||
var exists bool
|
||||
|
||||
err := dao.DB().Select("count(*)").
|
||||
From("sqlite_schema").
|
||||
AndWhere(dbx.HashExp{"type": "table"}).
|
||||
AndWhere(dbx.NewExp("LOWER([[name]])=LOWER({:tableName})", dbx.Params{"tableName": tableName})).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && exists
|
||||
}
|
||||
|
||||
// GetTableColumns returns all column names of a single table by its name.
|
||||
func (dao *Dao) GetTableColumns(tableName string) ([]string, error) {
|
||||
columns := []string{}
|
||||
|
||||
err := dao.DB().NewQuery("SELECT name FROM PRAGMA_TABLE_INFO({:tableName})").
|
||||
Bind(dbx.Params{"tableName": tableName}).
|
||||
Column(&columns)
|
||||
|
||||
return columns, err
|
||||
}
|
||||
|
||||
// DeleteTable drops the specified table.
|
||||
func (dao *Dao) DeleteTable(tableName string) error {
|
||||
_, err := dao.DB().DropTable(tableName).Execute()
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
)
|
||||
|
||||
func TestHasTable(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expected bool
|
||||
}{
|
||||
{"", false},
|
||||
{"test", false},
|
||||
{"_admins", true},
|
||||
{"demo3", true},
|
||||
{"DEMO3", true}, // table names are case insensitives by default
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := app.Dao().HasTable(scenario.tableName)
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTableColumns(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expected []string
|
||||
}{
|
||||
{"", nil},
|
||||
{"_params", []string{"id", "key", "value", "created", "updated"}},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
columns, _ := app.Dao().GetTableColumns(scenario.tableName)
|
||||
|
||||
if len(columns) != len(scenario.expected) {
|
||||
t.Errorf("(%d) Expected columns %v, got %v", i, scenario.expected, columns)
|
||||
}
|
||||
|
||||
for _, c := range columns {
|
||||
if !list.ExistInSlice(c, scenario.expected) {
|
||||
t.Errorf("(%d) Didn't expect column %s", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteTable(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
tableName string
|
||||
expectError bool
|
||||
}{
|
||||
{"", true},
|
||||
{"test", true},
|
||||
{"_admins", false},
|
||||
{"demo3", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
err := app.Dao().DeleteTable(scenario.tableName)
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr %v, got %v", i, scenario.expectError, hasErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
+281
@@ -0,0 +1,281 @@
|
||||
package daos
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/models/schema"
|
||||
"github.com/pocketbase/pocketbase/tools/list"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
|
||||
// UserQuery returns a new User model select query.
|
||||
func (dao *Dao) UserQuery() *dbx.SelectQuery {
|
||||
return dao.ModelQuery(&models.User{})
|
||||
}
|
||||
|
||||
// LoadProfile loads the profile record associated to the provided user.
|
||||
func (dao *Dao) LoadProfile(user *models.User) error {
|
||||
collection, err := dao.FindCollectionByNameOrId(models.ProfileCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
profile, err := dao.FindFirstRecordByData(collection, models.ProfileCollectionUserFieldName, user.Id)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Profile = profile
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadProfiles loads the profile records associated to the provied users list.
|
||||
func (dao *Dao) LoadProfiles(users []*models.User) error {
|
||||
collection, err := dao.FindCollectionByNameOrId(models.ProfileCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// extract user ids
|
||||
ids := []string{}
|
||||
usersMap := map[string]*models.User{}
|
||||
for _, user := range users {
|
||||
ids = append(ids, user.Id)
|
||||
usersMap[user.Id] = user
|
||||
}
|
||||
|
||||
profiles, err := dao.FindRecordsByExpr(collection, dbx.HashExp{
|
||||
models.ProfileCollectionUserFieldName: list.ToInterfaceSlice(ids),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// populate each user.Profile member
|
||||
for _, profile := range profiles {
|
||||
userId := profile.GetStringDataValue(models.ProfileCollectionUserFieldName)
|
||||
user, ok := usersMap[userId]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
user.Profile = profile
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindUserById finds a single User model by its id.
|
||||
//
|
||||
// This method also auto loads the related user profile record
|
||||
// into the found model.
|
||||
func (dao *Dao) FindUserById(id string) (*models.User, error) {
|
||||
model := &models.User{}
|
||||
|
||||
err := dao.UserQuery().
|
||||
AndWhere(dbx.HashExp{"id": id}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// try to load the user profile (if exist)
|
||||
if err := dao.LoadProfile(model); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// FindUserByEmail finds a single User model by its email address.
|
||||
//
|
||||
// This method also auto loads the related user profile record
|
||||
// into the found model.
|
||||
func (dao *Dao) FindUserByEmail(email string) (*models.User, error) {
|
||||
model := &models.User{}
|
||||
|
||||
err := dao.UserQuery().
|
||||
AndWhere(dbx.HashExp{"email": email}).
|
||||
Limit(1).
|
||||
One(model)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// try to load the user profile (if exist)
|
||||
if err := dao.LoadProfile(model); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// FindUserByToken finds the user associated with the provided JWT token.
|
||||
// Returns an error if the JWT token is invalid or expired.
|
||||
//
|
||||
// This method also auto loads the related user profile record
|
||||
// into the found model.
|
||||
func (dao *Dao) FindUserByToken(token string, baseTokenKey string) (*models.User, error) {
|
||||
unverifiedClaims, err := security.ParseUnverifiedJWT(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check required claims
|
||||
id, _ := unverifiedClaims["id"].(string)
|
||||
if id == "" {
|
||||
return nil, errors.New("Missing or invalid token claims.")
|
||||
}
|
||||
|
||||
user, err := dao.FindUserById(id)
|
||||
if err != nil || user == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verificationKey := user.TokenKey + baseTokenKey
|
||||
|
||||
// verify token signature
|
||||
if _, err := security.ParseJWT(token, verificationKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// IsUserEmailUnique checks if the provided email address is not
|
||||
// already in use by other users.
|
||||
func (dao *Dao) IsUserEmailUnique(email string, excludeId string) bool {
|
||||
if email == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var exists bool
|
||||
err := dao.UserQuery().
|
||||
Select("count(*)").
|
||||
AndWhere(dbx.Not(dbx.HashExp{"id": excludeId})).
|
||||
AndWhere(dbx.HashExp{"email": email}).
|
||||
Limit(1).
|
||||
Row(&exists)
|
||||
|
||||
return err == nil && !exists
|
||||
}
|
||||
|
||||
// DeleteUser deletes the provided User model.
|
||||
//
|
||||
// This method will also cascade the delete operation to all
|
||||
// Record models that references the provided User model
|
||||
// (delete or set to NULL, depending on the related user shema field settings).
|
||||
//
|
||||
// The delete operation may fail if the user is part of a required
|
||||
// reference in another Record model (aka. cannot be deleted or set to NULL).
|
||||
func (dao *Dao) DeleteUser(user *models.User) error {
|
||||
// fetch related records
|
||||
// note: the select is outside of the transaction to prevent SQLITE_LOCKED error when mixing read&write in a single transaction
|
||||
relatedRecords, err := dao.FindUserRelatedRecords(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dao.RunInTransaction(func(txDao *Dao) error {
|
||||
// check if related records has to be deleted (if `CascadeDelete` is set)
|
||||
// OR
|
||||
// just unset the user related fields (if they are not required)
|
||||
// -----------------------------------------------------------
|
||||
recordsLoop:
|
||||
for _, record := range relatedRecords {
|
||||
var needSave bool
|
||||
|
||||
for _, field := range record.Collection().Schema.Fields() {
|
||||
if field.Type != schema.FieldTypeUser {
|
||||
continue // not a user field
|
||||
}
|
||||
|
||||
ids := record.GetStringSliceDataValue(field.Name)
|
||||
|
||||
// unset the user id
|
||||
for i := len(ids) - 1; i >= 0; i-- {
|
||||
if ids[i] == user.Id {
|
||||
ids = append(ids[:i], ids[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
options, _ := field.Options.(*schema.UserOptions)
|
||||
|
||||
// cascade delete
|
||||
// (only if there are no other user references in case of multiple select)
|
||||
if options.CascadeDelete && len(ids) == 0 {
|
||||
if err := txDao.DeleteRecord(record); err != nil {
|
||||
return err
|
||||
}
|
||||
// no need to further iterate the user fields (the record is deleted)
|
||||
continue recordsLoop
|
||||
}
|
||||
|
||||
if field.Required && len(ids) == 0 {
|
||||
return fmt.Errorf("Failed delete the user because a record exist with required user reference to the current model (%q, %q).", record.Id, record.Collection().Name)
|
||||
}
|
||||
|
||||
// apply the reference changes
|
||||
record.SetDataValue(field.Name, field.PrepareValue(ids))
|
||||
needSave = true
|
||||
}
|
||||
|
||||
if needSave {
|
||||
if err := txDao.SaveRecord(record); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
// -----------------------------------------------------------
|
||||
|
||||
return txDao.Delete(user)
|
||||
})
|
||||
}
|
||||
|
||||
// SaveUser upserts the provided User model.
|
||||
//
|
||||
// An empty profile record will be created if the user
|
||||
// doesn't have a profile record set yet.
|
||||
func (dao *Dao) SaveUser(user *models.User) error {
|
||||
profileCollection, err := dao.FindCollectionByNameOrId(models.ProfileCollectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// fetch the related user profile record (if exist)
|
||||
var userProfile *models.Record
|
||||
if user.HasId() {
|
||||
userProfile, _ = dao.FindFirstRecordByData(
|
||||
profileCollection,
|
||||
models.ProfileCollectionUserFieldName,
|
||||
user.Id,
|
||||
)
|
||||
}
|
||||
|
||||
return dao.RunInTransaction(func(txDao *Dao) error {
|
||||
if err := txDao.Save(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create default/empty profile record if doesn't exist
|
||||
if userProfile == nil {
|
||||
userProfile = models.NewRecord(profileCollection)
|
||||
userProfile.SetDataValue(models.ProfileCollectionUserFieldName, user.Id)
|
||||
if err := txDao.Save(userProfile); err != nil {
|
||||
return err
|
||||
}
|
||||
user.Profile = userProfile
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package daos_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/tests"
|
||||
)
|
||||
|
||||
func TestUserQuery(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
expected := "SELECT {{_users}}.* FROM `_users`"
|
||||
|
||||
sql := app.Dao().UserQuery().Build().SQL()
|
||||
if sql != expected {
|
||||
t.Errorf("Expected sql %s, got %s", expected, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProfile(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// try to load missing profile (shouldn't return an error)
|
||||
// ---
|
||||
newUser := &models.User{}
|
||||
err1 := app.Dao().LoadProfile(newUser)
|
||||
if err1 != nil {
|
||||
t.Fatalf("Expected nil, got error %v", err1)
|
||||
}
|
||||
|
||||
// try to load existing profile
|
||||
// ---
|
||||
existingUser, _ := app.Dao().FindUserByEmail("test@example.com")
|
||||
existingUser.Profile = nil // reset
|
||||
|
||||
err2 := app.Dao().LoadProfile(existingUser)
|
||||
if err2 != nil {
|
||||
t.Fatal(err2)
|
||||
}
|
||||
|
||||
if existingUser.Profile == nil {
|
||||
t.Fatal("Expected user profile to be loaded, got nil")
|
||||
}
|
||||
|
||||
if existingUser.Profile.GetStringDataValue("name") != "test" {
|
||||
t.Fatalf("Expected profile.name to be 'test', got %s", existingUser.Profile.GetStringDataValue("name"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProfiles(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
u0 := &models.User{}
|
||||
u1, _ := app.Dao().FindUserByEmail("test@example.com")
|
||||
u2, _ := app.Dao().FindUserByEmail("test2@example.com")
|
||||
|
||||
users := []*models.User{u0, u1, u2}
|
||||
|
||||
err := app.Dao().LoadProfiles(users)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if u0.Profile != nil {
|
||||
t.Errorf("Expected profile to be nil for u0, got %v", u0.Profile)
|
||||
}
|
||||
if u1.Profile == nil {
|
||||
t.Errorf("Expected profile to be set for u1, got nil")
|
||||
}
|
||||
if u2.Profile == nil {
|
||||
t.Errorf("Expected profile to be set for u2, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUserById(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
id string
|
||||
expectError bool
|
||||
}{
|
||||
{"00000000-2b4a-a26b-4d01-42d3c3d77bc8", true},
|
||||
{"97cc3d3d-6ba2-383f-b42a-7bc84d27410c", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
user, err := app.Dao().FindUserById(scenario.id)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
}
|
||||
|
||||
if user != nil && user.Id != scenario.id {
|
||||
t.Errorf("(%d) Expected user with id %s, got %s", i, scenario.id, user.Id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUserByEmail(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
email string
|
||||
expectError bool
|
||||
}{
|
||||
{"invalid", true},
|
||||
{"missing@example.com", true},
|
||||
{"test@example.com", false},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
user, err := app.Dao().FindUserByEmail(scenario.email)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !scenario.expectError && user.Email != scenario.email {
|
||||
t.Errorf("(%d) Expected user with email %s, got %s", i, scenario.email, user.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindUserByToken(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
token string
|
||||
baseKey string
|
||||
expectedEmail string
|
||||
expectError bool
|
||||
}{
|
||||
// invalid base key (password reset key for auth token)
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRkMDE5N2NjLTJiNGEtM2Y4My1hMjZiLWQ3N2JjODQyM2QzYyIsInR5cGUiOiJ1c2VyIiwiZXhwIjoxODkzNDc0MDAwfQ.Wq5ac1q1f5WntIzEngXk22ydMj-eFgvfSRg7dhmPKic",
|
||||
app.Settings().UserPasswordResetToken.Secret,
|
||||
"",
|
||||
true,
|
||||
},
|
||||
// expired token
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRkMDE5N2NjLTJiNGEtM2Y4My1hMjZiLWQ3N2JjODQyM2QzYyIsInR5cGUiOiJ1c2VyIiwiZXhwIjoxNjQwOTkxNjYxfQ.RrSG5NwysI38DEZrIQiz3lUgI6sEuYGTll_jLRbBSiw",
|
||||
app.Settings().UserAuthToken.Secret,
|
||||
"",
|
||||
true,
|
||||
},
|
||||
// valid token
|
||||
{
|
||||
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjRkMDE5N2NjLTJiNGEtM2Y4My1hMjZiLWQ3N2JjODQyM2QzYyIsInR5cGUiOiJ1c2VyIiwiZXhwIjoxODkzNDc0MDAwfQ.Wq5ac1q1f5WntIzEngXk22ydMj-eFgvfSRg7dhmPKic",
|
||||
app.Settings().UserAuthToken.Secret,
|
||||
"test@example.com",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
user, err := app.Dao().FindUserByToken(scenario.token, scenario.baseKey)
|
||||
|
||||
hasErr := err != nil
|
||||
if hasErr != scenario.expectError {
|
||||
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !scenario.expectError && user.Email != scenario.expectedEmail {
|
||||
t.Errorf("(%d) Expected user model %s, got %s", i, scenario.expectedEmail, user.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsUserEmailUnique(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
scenarios := []struct {
|
||||
email string
|
||||
excludeId string
|
||||
expected bool
|
||||
}{
|
||||
{"", "", false},
|
||||
{"test@example.com", "", false},
|
||||
{"new@example.com", "", true},
|
||||
{"test@example.com", "4d0197cc-2b4a-3f83-a26b-d77bc8423d3c", true},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
result := app.Dao().IsUserEmailUnique(scenario.email, scenario.excludeId)
|
||||
if result != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// try to delete unsaved user
|
||||
// ---
|
||||
err1 := app.Dao().DeleteUser(&models.User{})
|
||||
if err1 == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
|
||||
// try to delete existing user
|
||||
// ---
|
||||
user, _ := app.Dao().FindUserByEmail("test3@example.com")
|
||||
err2 := app.Dao().DeleteUser(user)
|
||||
if err2 != nil {
|
||||
t.Fatalf("Expected nil, got error %v", err2)
|
||||
}
|
||||
|
||||
// check if the delete operation was cascaded to the profiles collection (record delete)
|
||||
profilesCol, _ := app.Dao().FindCollectionByNameOrId(models.ProfileCollectionName)
|
||||
profile, _ := app.Dao().FindRecordById(profilesCol, user.Profile.Id, nil)
|
||||
if profile != nil {
|
||||
t.Fatalf("Expected user profile to be deleted, got %v", profile)
|
||||
}
|
||||
|
||||
// check if delete operation was cascaded to the related demo2 collection (null set)
|
||||
demo2Col, _ := app.Dao().FindCollectionByNameOrId("demo2")
|
||||
record, _ := app.Dao().FindRecordById(demo2Col, "94568ca2-0bee-49d7-b749-06cb97956fd9", nil)
|
||||
if record == nil {
|
||||
t.Fatal("Expected to found related record, got nil")
|
||||
}
|
||||
if record.GetStringDataValue("user") != "" {
|
||||
t.Fatalf("Expected user field to be set to empty string, got %v", record.GetStringDataValue("user"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveUser(t *testing.T) {
|
||||
app, _ := tests.NewTestApp()
|
||||
defer app.Cleanup()
|
||||
|
||||
// create
|
||||
// ---
|
||||
u1 := &models.User{}
|
||||
u1.Email = "new@example.com"
|
||||
u1.SetPassword("123456")
|
||||
err1 := app.Dao().SaveUser(u1)
|
||||
if err1 != nil {
|
||||
t.Fatal(err1)
|
||||
}
|
||||
u1, refreshErr1 := app.Dao().FindUserByEmail("new@example.com")
|
||||
if refreshErr1 != nil {
|
||||
t.Fatalf("Expected user with email new@example.com to have been created, got error %v", refreshErr1)
|
||||
}
|
||||
if u1.Profile == nil {
|
||||
t.Fatalf("Expected creating a user to create also an empty profile record")
|
||||
}
|
||||
|
||||
// update
|
||||
// ---
|
||||
u2, _ := app.Dao().FindUserByEmail("test@example.com")
|
||||
u2.Email = "test_update@example.com"
|
||||
err2 := app.Dao().SaveUser(u2)
|
||||
if err2 != nil {
|
||||
t.Fatal(err2)
|
||||
}
|
||||
u2, refreshErr2 := app.Dao().FindUserByEmail("test_update@example.com")
|
||||
if u2 == nil {
|
||||
t.Fatalf("Couldn't find user with email test_update@example.com (%v)", refreshErr2)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user