added view collection type

This commit is contained in:
Gani Georgiev
2023-02-18 19:33:42 +02:00
parent 0052e2ab2a
commit a07f67002f
98 changed files with 3259 additions and 829 deletions
+82 -14
View File
@@ -129,13 +129,23 @@ func (dao *Dao) DeleteCollection(collection *models.Collection) error {
return err
}
if total := len(result); total > 0 {
return fmt.Errorf("The collection %q has external relation field references (%d).", collection.Name, total)
names := make([]string, 0, len(result))
for ref := range result {
names = append(names, ref.Name)
}
return fmt.Errorf("The collection %q has external relation field references (%s).", collection.Name, strings.Join(names, ", "))
}
return dao.RunInTransaction(func(txDao *Dao) error {
// delete the related records table
if err := txDao.DeleteTable(collection.Name); err != nil {
return err
// delete the related view or records table
if collection.IsView() {
if err := txDao.DeleteView(collection.Name); err != nil {
return err
}
} else {
if err := txDao.DeleteTable(collection.Name); err != nil {
return err
}
}
return txDao.Delete(collection)
@@ -163,13 +173,18 @@ func (dao *Dao) SaveCollection(collection *models.Collection) error {
collection.Type = models.CollectionTypeBase
}
// persist the collection model
if err := txDao.Save(collection); err != nil {
return err
}
switch collection.Type {
case models.CollectionTypeView:
return txDao.saveViewCollection(collection, oldCollection)
default:
// persist the collection model
if err := txDao.Save(collection); err != nil {
return err
}
// sync the changes with the related records table
return txDao.SyncRecordTableSchema(collection, oldCollection)
// sync the changes with the related records table
return txDao.SyncRecordTableSchema(collection, oldCollection)
}
})
}
@@ -274,8 +289,14 @@ func (dao *Dao) ImportCollections(
continue // exist
}
if err := txDao.DeleteTable(existing.Name); err != nil {
return err
if existing.IsView() {
if err := txDao.DeleteView(existing.Name); err != nil {
return err
}
} else {
if err := txDao.DeleteTable(existing.Name); err != nil {
return err
}
}
}
}
@@ -283,11 +304,58 @@ func (dao *Dao) ImportCollections(
// sync the upserted collections with the related records table
for _, imported := range importedCollections {
existing := mappedExisting[imported.GetId()]
if err := txDao.SyncRecordTableSchema(imported, existing); err != nil {
return err
if imported.IsView() {
if err := txDao.saveViewCollection(imported, existing); err != nil {
return err
}
} else {
if err := txDao.SyncRecordTableSchema(imported, existing); err != nil {
return err
}
}
}
return nil
})
}
// saveViewCollection persists the provided View collection changes:
// - deletes the old related SQL view (if any)
// - creates a new SQL view with the latest newCollection.Options.Query
// - generates a new schema based on newCollection.Options.Query
// - updates newCollection.Schema based on the generated view table info and query
// - saves the newCollection
//
// This method returns an error if newCollection is not a "view".
func (dao *Dao) saveViewCollection(newCollection *models.Collection, oldCollection *models.Collection) error {
if newCollection.IsAuth() {
return errors.New("not a view collection")
}
return dao.RunInTransaction(func(txDao *Dao) error {
query := newCollection.ViewOptions().Query
// generate collection schema from the query
schema, err := txDao.CreateViewSchema(query)
if err != nil {
return err
}
// delete old renamed view
if oldCollection != nil && newCollection.Name != oldCollection.Name {
if err := txDao.DeleteView(oldCollection.Name); err != nil {
return err
}
}
// (re)create the view
if err := txDao.SaveView(newCollection.Name, query); err != nil {
return err
}
newCollection.Schema = schema
return txDao.Save(newCollection)
})
}
+60 -32
View File
@@ -45,16 +45,16 @@ func TestFindCollectionsByType(t *testing.T) {
hasErr := err != nil
if hasErr != scenario.expectError {
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
t.Errorf("[%d] Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
}
if len(collections) != scenario.expectTotal {
t.Errorf("(%d) Expected %d collections, got %d", i, scenario.expectTotal, len(collections))
t.Errorf("[%d] Expected %d collections, got %d", i, scenario.expectTotal, len(collections))
}
for _, c := range collections {
if c.Type != scenario.collectionType {
t.Errorf("(%d) Expected collection with type %s, got %s: \n%v", i, scenario.collectionType, c.Type, c)
t.Errorf("[%d] Expected collection with type %s, got %s: \n%v", i, scenario.collectionType, c.Type, c)
}
}
}
@@ -80,11 +80,11 @@ func TestFindCollectionByNameOrId(t *testing.T) {
hasErr := err != nil
if hasErr != scenario.expectError {
t.Errorf("(%d) Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
t.Errorf("[%d] Expected hasErr to be %v, got %v (%v)", i, scenario.expectError, hasErr, err)
}
if model != nil && model.Id != scenario.nameOrId && !strings.EqualFold(model.Name, scenario.nameOrId) {
t.Errorf("(%d) Expected model with identifier %s, got %v", i, scenario.nameOrId, model)
t.Errorf("[%d] Expected model with identifier %s, got %v", i, scenario.nameOrId, model)
}
}
}
@@ -108,7 +108,7 @@ func TestIsCollectionNameUnique(t *testing.T) {
for i, scenario := range scenarios {
result := app.Dao().IsCollectionNameUnique(scenario.name, scenario.excludeId)
if result != scenario.expected {
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, result)
t.Errorf("[%d] Expected %v, got %v", i, scenario.expected, result)
}
}
}
@@ -155,7 +155,7 @@ func TestFindCollectionReferences(t *testing.T) {
}
for i, f := range fields {
if !list.ExistInSlice(f.Name, expectedFields) {
t.Fatalf("(%d) Didn't expect field %v", i, f)
t.Fatalf("[%d] Didn't expect field %v", i, f)
}
}
}
@@ -165,21 +165,29 @@ func TestDeleteCollection(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
c0 := &models.Collection{}
c1, err := app.Dao().FindCollectionByNameOrId("clients")
colEmpty := &models.Collection{}
colAuth, err := app.Dao().FindCollectionByNameOrId("clients")
if err != nil {
t.Fatal(err)
}
c2, err := app.Dao().FindCollectionByNameOrId("demo2")
colReferenced, err := app.Dao().FindCollectionByNameOrId("demo2")
if err != nil {
t.Fatal(err)
}
c3, err := app.Dao().FindCollectionByNameOrId("demo1")
colSystem, err := app.Dao().FindCollectionByNameOrId("demo3")
if err != nil {
t.Fatal(err)
}
c3.System = true
if err := app.Dao().Save(c3); err != nil {
colSystem.System = true
if err := app.Dao().Save(colSystem); err != nil {
t.Fatal(err)
}
colView, err := app.Dao().FindCollectionByNameOrId("view1")
if err != nil {
t.Fatal(err)
}
@@ -187,18 +195,28 @@ func TestDeleteCollection(t *testing.T) {
model *models.Collection
expectError bool
}{
{c0, true},
{c1, false},
{c2, true}, // is part of a reference
{c3, true}, // system
{colEmpty, true},
{colAuth, false},
{colReferenced, true},
{colSystem, true},
{colView, false},
}
for i, scenario := range scenarios {
err := app.Dao().DeleteCollection(scenario.model)
hasErr := err != nil
for i, s := range scenarios {
err := app.Dao().DeleteCollection(s.model)
if hasErr != scenario.expectError {
t.Errorf("(%d) Expected hasErr %v, got %v", i, scenario.expectError, hasErr)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%d] Expected hasErr %v, got %v (%v)", i, s.expectError, hasErr, err)
continue
}
if hasErr {
continue
}
if app.Dao().HasTable(s.model.Name) {
t.Errorf("[%d] Expected table/view %s to be deleted", i, s.model.Name)
}
}
}
@@ -244,7 +262,7 @@ func TestSaveCollectionCreate(t *testing.T) {
}
for i, c := range columns {
if !list.ExistInSlice(c, expectedColumns) {
t.Fatalf("(%d) Didn't expect record column %s", i, c)
t.Fatalf("[%d] Didn't expect record column %s", i, c)
}
}
}
@@ -282,12 +300,14 @@ func TestSaveCollectionUpdate(t *testing.T) {
}
for i, c := range columns {
if !list.ExistInSlice(c, expectedColumns) {
t.Fatalf("(%d) Didn't expect record column %s", i, c)
t.Fatalf("[%d] Didn't expect record column %s", i, c)
}
}
}
func TestImportCollections(t *testing.T) {
totalCollections := 10
scenarios := []struct {
name string
jsonData string
@@ -302,7 +322,7 @@ func TestImportCollections(t *testing.T) {
name: "empty collections",
jsonData: `[]`,
expectError: true,
expectCollectionsCount: 8,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import",
@@ -312,7 +332,7 @@ func TestImportCollections(t *testing.T) {
]`,
deleteMissing: false,
expectError: false,
expectCollectionsCount: 10,
expectCollectionsCount: totalCollections + 2,
},
{
name: "minimal collection import + failed beforeRecordsSync",
@@ -324,7 +344,7 @@ func TestImportCollections(t *testing.T) {
},
deleteMissing: false,
expectError: true,
expectCollectionsCount: 8,
expectCollectionsCount: totalCollections,
},
{
name: "minimal collection import + successful beforeRecordsSync",
@@ -336,7 +356,7 @@ func TestImportCollections(t *testing.T) {
},
deleteMissing: false,
expectError: false,
expectCollectionsCount: 9,
expectCollectionsCount: totalCollections + 1,
},
{
name: "new + update + delete system collection",
@@ -372,7 +392,7 @@ func TestImportCollections(t *testing.T) {
]`,
deleteMissing: true,
expectError: true,
expectCollectionsCount: 8,
expectCollectionsCount: totalCollections,
},
{
name: "new + update + delete non-system collection",
@@ -442,11 +462,19 @@ func TestImportCollections(t *testing.T) {
"type":"bool"
}
]
},
{
"id": "test_new_view",
"name": "new_view",
"type": "view",
"options": {
"query": "select id from demo2"
}
}
]`,
deleteMissing: true,
expectError: false,
expectCollectionsCount: 3,
expectCollectionsCount: 4,
},
{
name: "test with deleteMissing: false",
@@ -501,7 +529,7 @@ func TestImportCollections(t *testing.T) {
]`,
deleteMissing: false,
expectError: false,
expectCollectionsCount: 9,
expectCollectionsCount: totalCollections + 1,
afterTestFunc: func(testApp *tests.TestApp, resultCollections []*models.Collection) {
expectedCollectionFields := map[string]int{
"nologin": 1,
@@ -509,7 +537,7 @@ func TestImportCollections(t *testing.T) {
"demo2": 2,
"demo3": 2,
"demo4": 11,
"demo5": 5,
"demo5": 6,
"new_import": 1,
}
for name, expectedCount := range expectedCollectionFields {
+10 -2
View File
@@ -128,7 +128,11 @@ func (dao *Dao) FindRecordsByExpr(collectionNameOrId string, exprs ...dbx.Expres
// FindFirstRecordByData returns the first found record matching
// the provided key-value pair.
func (dao *Dao) FindFirstRecordByData(collectionNameOrId string, key string, value any) (*models.Record, error) {
func (dao *Dao) FindFirstRecordByData(
collectionNameOrId string,
key string,
value any,
) (*models.Record, error) {
collection, err := dao.FindCollectionByNameOrId(collectionNameOrId)
if err != nil {
return nil, err
@@ -396,11 +400,15 @@ func (dao *Dao) cascadeRecordDelete(mainRecord *models.Record, refs map[*models.
uniqueJsonEachAlias := "__je__" + security.PseudorandomString(4)
for refCollection, fields := range refs {
if refCollection.IsView() {
continue // skip view collections
}
for _, field := range fields {
recordTableName := inflector.Columnify(refCollection.Name)
prefixedFieldName := recordTableName + "." + inflector.Columnify(field.Name)
// @todo optimize single relation lookup in v0.12+
// @todo optimize single relation lookup
query := dao.RecordQuery(refCollection).
Distinct(true).
InnerJoin(fmt.Sprintf(
+23 -1
View File
@@ -1,7 +1,10 @@
package daos
import (
"fmt"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/models"
)
// HasTable checks if a table with the provided name exists (case insensitive).
@@ -29,9 +32,28 @@ func (dao *Dao) GetTableColumns(tableName string) ([]string, error) {
return columns, err
}
// GetTableInfo returns the `table_info` pragma result for the specified table.
func (dao *Dao) GetTableInfo(tableName string) ([]*models.TableInfoRow, error) {
info := []*models.TableInfoRow{}
err := dao.DB().NewQuery("SELECT * FROM PRAGMA_TABLE_INFO({:tableName})").
Bind(dbx.Params{"tableName": tableName}).
All(&info)
return info, err
}
// DeleteTable drops the specified table.
//
// This method is a no-op if a table with the provided name doesn't exist.
//
// Be aware that this method is vulnerable to SQL injection and the
// "tableName" argument must come only from trusted input!
func (dao *Dao) DeleteTable(tableName string) error {
_, err := dao.DB().DropTable(tableName).Execute()
_, err := dao.DB().NewQuery(fmt.Sprintf(
"DROP TABLE IF EXISTS {{%s}}",
tableName,
)).Execute()
return err
}
+43 -12
View File
@@ -3,6 +3,7 @@ package daos_test
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
@@ -28,7 +29,7 @@ func TestHasTable(t *testing.T) {
for i, scenario := range scenarios {
result := app.Dao().HasTable(scenario.tableName)
if result != scenario.expected {
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, result)
t.Errorf("[%d] Expected %v, got %v", i, scenario.expected, result)
}
}
}
@@ -45,21 +46,50 @@ func TestGetTableColumns(t *testing.T) {
{"_params", []string{"id", "key", "value", "created", "updated"}},
}
for i, scenario := range scenarios {
columns, _ := app.Dao().GetTableColumns(scenario.tableName)
for i, s := range scenarios {
columns, _ := app.Dao().GetTableColumns(s.tableName)
if len(columns) != len(scenario.expected) {
t.Errorf("(%d) Expected columns %v, got %v", i, scenario.expected, columns)
if len(columns) != len(s.expected) {
t.Errorf("[%d] Expected columns %v, got %v", i, s.expected, columns)
continue
}
for _, c := range columns {
if !list.ExistInSlice(c, scenario.expected) {
t.Errorf("(%d) Didn't expect column %s", i, c)
if !list.ExistInSlice(c, s.expected) {
t.Errorf("[%d] Didn't expect column %s", i, c)
}
}
}
}
func TestGetTableInfo(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
tableName string
expected string
}{
{"", "[]"},
{"missing", "[]"},
{
"_admins",
`[{"PK":1,"Index":0,"Name":"id","Type":"TEXT","NotNull":false,"DefaultValue":null},{"PK":0,"Index":1,"Name":"avatar","Type":"INTEGER","NotNull":true,"DefaultValue":0},{"PK":0,"Index":2,"Name":"email","Type":"TEXT","NotNull":true,"DefaultValue":null},{"PK":0,"Index":3,"Name":"tokenKey","Type":"TEXT","NotNull":true,"DefaultValue":null},{"PK":0,"Index":4,"Name":"passwordHash","Type":"TEXT","NotNull":true,"DefaultValue":null},{"PK":0,"Index":5,"Name":"lastResetSentAt","Type":"TEXT","NotNull":true,"DefaultValue":""},{"PK":0,"Index":6,"Name":"created","Type":"TEXT","NotNull":true,"DefaultValue":""},{"PK":0,"Index":7,"Name":"updated","Type":"TEXT","NotNull":true,"DefaultValue":""}]`,
},
}
for i, s := range scenarios {
rows, _ := app.Dao().GetTableInfo(s.tableName)
raw, _ := json.Marshal(rows)
rawStr := string(raw)
if rawStr != s.expected {
t.Errorf("[%d] Expected \n%v, \ngot \n%v", i, s.expected, rawStr)
}
}
}
func TestDeleteTable(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
@@ -69,16 +99,17 @@ func TestDeleteTable(t *testing.T) {
expectError bool
}{
{"", true},
{"test", true},
{"test", false}, // missing tables are ignored
{"_admins", false},
{"demo3", false},
}
for i, scenario := range scenarios {
err := app.Dao().DeleteTable(scenario.tableName)
for i, s := range scenarios {
err := app.Dao().DeleteTable(s.tableName)
hasErr := err != nil
if hasErr != scenario.expectError {
t.Errorf("(%d) Expected hasErr %v, got %v", i, scenario.expectError, hasErr)
if hasErr != s.expectError {
t.Errorf("[%d] Expected hasErr %v, got %v", i, s.expectError, hasErr)
}
}
}
+583
View File
@@ -0,0 +1,583 @@
package daos
import (
"errors"
"fmt"
"io"
"regexp"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/tokenizer"
"github.com/pocketbase/pocketbase/tools/types"
)
// DeleteView drops the specified view name.
//
// This method is a no-op if a view with the provided name doesn't exist.
//
// Be aware that this method is vulnerable to SQL injection and the
// "name" argument must come only from trusted input!
func (dao *Dao) DeleteView(name string) error {
_, err := dao.DB().NewQuery(fmt.Sprintf(
"DROP VIEW IF EXISTS {{%s}}",
name,
)).Execute()
return err
}
// SaveView creates (or updates already existing) persistent SQL view.
//
// Be aware that this method is vulnerable to SQL injection and the
// "selectQuery" argument must come only from trusted input!
func (dao *Dao) SaveView(name string, selectQuery string) error {
return dao.RunInTransaction(func(txDao *Dao) error {
// delete old view (if exists)
if err := txDao.DeleteView(name); err != nil {
return err
}
trimmed := strings.Trim(selectQuery, ";")
// try to eagerly detect multiple inline statements
tk := tokenizer.NewFromString(trimmed)
tk.Separators(';')
if queryParts, _ := tk.ScanAll(); len(queryParts) > 1 {
return errors.New("multiple statements are not supported")
}
// (re)create the view
//
// note: the query is wrapped in a secondary SELECT as a rudimentary
// measure to discourage multiple inline sql statements execution.
viewQuery := fmt.Sprintf("CREATE VIEW {{%s}} AS SELECT * FROM (%s)", name, trimmed)
if _, err := txDao.DB().NewQuery(viewQuery).Execute(); err != nil {
return err
}
// fetch the view table info to ensure that the view was created
// because missing tables or columns won't return an error
if _, err := txDao.GetTableInfo(name); err != nil {
return err
}
return nil
})
}
// CreateViewSchema creates a new view schema from the provided select query.
//
// There are some caveats:
// - The select query must have an "id" column.
// - Wildcard ("*") columns are not supported to avoid accidentally leaking sensitive data.
func (dao *Dao) CreateViewSchema(selectQuery string) (schema.Schema, error) {
result := schema.NewSchema()
suggestedFields, err := dao.parseQueryToFields(selectQuery)
if err != nil {
return result, err
}
// note wrap in a transaction in case the selectQuery contains
// multiple statements allowing us to rollback on any error
txErr := dao.RunInTransaction(func(txDao *Dao) error {
tempView := "_temp_" + security.PseudorandomString(5)
// create a temp view with the provided query
if err := txDao.SaveView(tempView, selectQuery); err != nil {
return err
}
defer txDao.DeleteView(tempView)
// extract the generated view table info
info, err := txDao.GetTableInfo(tempView)
if err != nil {
return err
}
var hasId bool
for _, row := range info {
if row.Name == schema.FieldNameId {
hasId = true
}
if list.ExistInSlice(row.Name, schema.BaseModelFieldNames()) {
continue // skip base model fields since they are not part of the schema
}
var field *schema.SchemaField
if f, ok := suggestedFields[row.Name]; ok {
field = f.field
} else {
field = defaultViewField(row.Name)
}
result.AddField(field)
}
if !hasId {
return errors.New("missing required id column (you ca use `(ROW_NUMBER() OVER()) as id` if you don't have one)")
}
return nil
})
return result, txErr
}
// FindRecordByViewFile returns the original models.Record of the
// provided view collection file.
func (dao *Dao) FindRecordByViewFile(
viewCollectionNameOrId string,
fileFieldName string,
filename string,
) (*models.Record, error) {
view, err := dao.FindCollectionByNameOrId(viewCollectionNameOrId)
if err != nil {
return nil, err
}
if !view.IsView() {
return nil, errors.New("not a view collection")
}
var findFirstNonViewQueryFileField func(int) (*queryField, error)
findFirstNonViewQueryFileField = func(level int) (*queryField, error) {
// check the level depth to prevent infinite circular recursion
// (the limit is arbitrary and may change in the future)
if level > 5 {
return nil, errors.New("reached the max recursion level of view collection file field queries")
}
queryFields, err := dao.parseQueryToFields(view.ViewOptions().Query)
if err != nil {
return nil, err
}
for _, item := range queryFields {
if item.collection == nil ||
item.original == nil ||
item.field.Name != fileFieldName {
continue
}
if item.collection.IsView() {
view = item.collection
fileFieldName = item.original.Name
return findFirstNonViewQueryFileField(level + 1)
}
return item, nil
}
return nil, errors.New("no query file field found")
}
qf, err := findFirstNonViewQueryFileField(1)
if err != nil {
return nil, err
}
cleanFieldName := inflector.Columnify(qf.original.Name)
row := dbx.NullStringMap{}
err = dao.RecordQuery(qf.collection).
InnerJoin(fmt.Sprintf(
// note: the case is used to normalize the value access
`json_each(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE json_array([[%s]]) END) as {{_je_file}}`,
cleanFieldName, cleanFieldName, cleanFieldName,
), dbx.HashExp{"_je_file.value": filename}).
Limit(1).
One(row)
if err != nil {
return nil, err
}
return models.NewRecordFromNullStringMap(qf.collection, row), nil
}
// -------------------------------------------------------------------
// Raw query to schema helpers
// -------------------------------------------------------------------
type queryField struct {
// field is the final resolved field.
field *schema.SchemaField
// collection refers to the original field's collection model.
// It could be nil if the found query field is not from a collection schema.
collection *models.Collection
// original is the original found collection field.
// It could be nil if the found query field is not from a collection schema.
original *schema.SchemaField
}
func defaultViewField(name string) *schema.SchemaField {
return &schema.SchemaField{
Name: name,
Type: schema.FieldTypeJson,
}
}
func (dao *Dao) parseQueryToFields(selectQuery string) (map[string]*queryField, error) {
p := new(identifiersParser)
if err := p.parse(selectQuery); err != nil {
return nil, err
}
collections, err := dao.findCollectionsByIdentifiers(p.tables)
if err != nil {
return nil, err
}
result := make(map[string]*queryField, len(p.columns))
var mainTable identifier
if len(p.tables) > 0 {
mainTable = p.tables[0]
}
for _, col := range p.columns {
colLower := strings.ToLower(col.original)
// numeric expression cast
if strings.Contains(colLower, "(") &&
(strings.HasPrefix(colLower, "count(") ||
strings.HasPrefix(colLower, "total(") ||
strings.Contains(colLower, " as numeric") ||
strings.Contains(colLower, " as real") ||
strings.Contains(colLower, " as int") ||
strings.Contains(colLower, " as integer") ||
strings.Contains(colLower, " as decimal")) {
result[col.alias] = &queryField{
field: &schema.SchemaField{
Name: col.alias,
Type: schema.FieldTypeNumber,
},
}
continue
}
parts := strings.Split(col.original, ".")
var fieldName string
var collection *models.Collection
var isMainTableField bool
if len(parts) == 2 {
fieldName = parts[1]
collection = collections[parts[0]]
isMainTableField = parts[0] == mainTable.alias
} else {
fieldName = parts[0]
collection = collections[mainTable.alias]
isMainTableField = true
}
// fallback to the default field if the found column is not from a collection schema
if collection == nil {
result[col.alias] = &queryField{
field: defaultViewField(col.alias),
}
continue
}
if fieldName == "*" {
return nil, errors.New("dynamic column names are not supported")
}
// find the first field by name (case insensitive)
var field *schema.SchemaField
for _, f := range collection.Schema.Fields() {
if strings.EqualFold(f.Name, fieldName) {
field = f
break
}
}
if field != nil {
clone := *field
clone.Name = col.alias
result[col.alias] = &queryField{
field: &clone,
collection: collection,
original: field,
}
continue
}
if fieldName == schema.FieldNameId && !isMainTableField {
// convert to relation since it is a direct id reference to non-maintable collection
result[col.alias] = &queryField{
field: &schema.SchemaField{
Name: col.alias,
Type: schema.FieldTypeRelation,
Options: &schema.RelationOptions{
MaxSelect: types.Pointer(1),
CollectionId: collection.Id,
},
},
collection: collection,
}
} else if fieldName == schema.FieldNameCreated || fieldName == schema.FieldNameUpdated {
result[col.alias] = &queryField{
field: &schema.SchemaField{
Name: col.alias,
Type: schema.FieldTypeDate,
},
collection: collection,
}
} else if fieldName == schema.FieldNameUsername && collection.IsAuth() {
result[col.alias] = &queryField{
field: &schema.SchemaField{
Name: col.alias,
Type: schema.FieldTypeText,
},
collection: collection,
}
} else if fieldName == schema.FieldNameEmail && collection.IsAuth() {
result[col.alias] = &queryField{
field: &schema.SchemaField{
Name: col.alias,
Type: schema.FieldTypeEmail,
},
collection: collection,
}
} else if (fieldName == schema.FieldNameVerified || fieldName == schema.FieldNameEmailVisibility) && collection.IsAuth() {
result[col.alias] = &queryField{
field: &schema.SchemaField{
Name: col.alias,
Type: schema.FieldTypeBool,
},
collection: collection,
}
} else {
result[col.alias] = &queryField{
field: defaultViewField(col.alias),
collection: collection,
}
}
}
return result, nil
}
func (dao *Dao) findCollectionsByIdentifiers(tables []identifier) (map[string]*models.Collection, error) {
names := make([]any, 0, len(tables))
for _, table := range tables {
if strings.Contains(table.alias, "(") {
continue // skip expressions
}
names = append(names, table.original)
}
if len(names) == 0 {
return nil, nil
}
result := make(map[string]*models.Collection, len(names))
collections := make([]*models.Collection, 0, len(names))
err := dao.CollectionQuery().
AndWhere(dbx.In("name", names...)).
All(&collections)
if err != nil {
return nil, err
}
for _, table := range tables {
for _, collection := range collections {
if collection.Name == table.original {
result[table.alias] = collection
}
}
}
return result, nil
}
// -------------------------------------------------------------------
// Raw query identifiers parser
// -------------------------------------------------------------------
var joinReplaceRegex = regexp.MustCompile(`(?im)\s+(inner join|outer join|left join|right join|join)\s+?`)
var discardReplaceRegex = regexp.MustCompile(`(?im)\s+(where|group by|having|order|limit|with)\s+?`)
var commentsReplaceRegex = regexp.MustCompile(`(?m)(\/\*[\s\S]+\*\/)|(--.+$)`)
type identifier struct {
original string
alias string
}
type identifiersParser struct {
columns []identifier
tables []identifier
}
func (p *identifiersParser) parse(selectQuery string) error {
str := strings.Trim(selectQuery, ";")
str = joinReplaceRegex.ReplaceAllString(str, " _join_ ")
str = discardReplaceRegex.ReplaceAllString(str, " _discard_ ")
str = commentsReplaceRegex.ReplaceAllString(str, "")
tk := tokenizer.NewFromString(str)
tk.Separators(',', ' ', '\n', '\t')
tk.KeepSeparator(true)
var skip bool
var partType string
var activeBuilder *strings.Builder
var selectParts strings.Builder
var fromParts strings.Builder
var joinParts strings.Builder
for {
token, err := tk.Scan()
if err != nil {
if err != io.EOF {
return err
}
break
}
trimmed := strings.ToLower(strings.TrimSpace(token))
switch trimmed {
case "select":
skip = false
partType = "select"
activeBuilder = &selectParts
case "from":
skip = false
partType = "from"
activeBuilder = &fromParts
case "_join_":
skip = false
// the previous part was also a join
if partType == "join" {
joinParts.WriteString(",")
}
partType = "join"
activeBuilder = &joinParts
case "_discard_":
// do nothing...
skip = true
default:
isJoin := partType == "join"
if isJoin && trimmed == "on" {
skip = true
}
if !skip && activeBuilder != nil {
activeBuilder.WriteString(" ")
activeBuilder.WriteString(token)
}
}
}
selects, err := extractIdentifiers(selectParts.String())
if err != nil {
return err
}
froms, err := extractIdentifiers(fromParts.String())
if err != nil {
return err
}
joins, err := extractIdentifiers(joinParts.String())
if err != nil {
return err
}
p.columns = selects
p.tables = froms
p.tables = append(p.tables, joins...)
return nil
}
func extractIdentifiers(rawExpression string) ([]identifier, error) {
rawTk := tokenizer.NewFromString(rawExpression)
rawTk.Separators(',')
rawIdentifiers, err := rawTk.ScanAll()
if err != nil {
return nil, err
}
result := make([]identifier, 0, len(rawIdentifiers))
for _, rawIdentifier := range rawIdentifiers {
tk := tokenizer.NewFromString(rawIdentifier)
tk.Separators(' ', '\n', '\t')
parts, err := tk.ScanAll()
if err != nil {
return nil, err
}
resolved, err := identifierFromParts(parts)
if err != nil {
return nil, err
}
result = append(result, resolved)
}
return result, nil
}
func identifierFromParts(parts []string) (identifier, error) {
var result identifier
switch len(parts) {
case 3:
if !strings.EqualFold(parts[1], "as") {
return result, fmt.Errorf(`invalid identifier part - expected "as", got %v`, parts[1])
}
result.original = parts[0]
result.alias = parts[2]
case 2:
result.original = parts[0]
result.alias = parts[1]
case 1:
subParts := strings.Split(parts[0], ".")
result.original = parts[0]
result.alias = subParts[len(subParts)-1]
default:
return result, fmt.Errorf(`invalid identifier parts %v`, parts)
}
result.original = trimRawIdentifier(result.original)
result.alias = trimRawIdentifier(result.alias)
return result, nil
}
func trimRawIdentifier(rawIdentifier string) string {
const trimChars = "`\"[];"
parts := strings.Split(rawIdentifier, ".")
for i := range parts {
parts[i] = strings.Trim(parts[i], trimChars)
}
return strings.Join(parts, ".")
}
+539
View File
@@ -0,0 +1,539 @@
package daos_test
import (
"encoding/json"
"fmt"
"testing"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
"github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/list"
)
func TestDeleteView(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
viewName string
expectError bool
}{
{"", true},
{"demo1", true}, // not a view table
{"missing", false}, // missing or already deleted
{"view1", false}, // existing
{"VieW1", false}, // view names are case insensitives
}
for i, s := range scenarios {
err := app.Dao().DeleteView(s.viewName)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%d - %q] Expected hasErr %v, got %v (%v)", i, s.viewName, s.expectError, hasErr, err)
}
}
}
func TestSaveView(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
scenarioName string
viewName string
query string
expectError bool
expectColumns []string
}{
{
"empty name and query",
"",
"",
true,
nil,
},
{
"empty name",
"",
"select * from _admins",
true,
nil,
},
{
"empty query",
"123Test",
"",
true,
nil,
},
{
"invalid query",
"123Test",
"123 456",
true,
nil,
},
{
"missing table",
"123Test",
"select * from missing",
true,
nil,
},
{
"non select query",
"123Test",
"drop table _admins",
true,
nil,
},
{
"multiple select queries",
"123Test",
"select *, count(id) as c from _admins; select * from demo1;",
true,
nil,
},
{
"try to break the parent parenthesis",
"123Test",
"select *, count(id) as c from `_admins`)",
true,
nil,
},
{
"simple select query (+ trimmed semicolon)",
"123Test",
";select *, count(id) as c from _admins;",
false,
[]string{
"id", "created", "updated",
"passwordHash", "tokenKey", "email",
"lastResetSentAt", "avatar", "c",
},
},
{
"update old view with new query",
"123Test",
"select 1 as test from _admins",
false,
[]string{"test"},
},
}
for _, s := range scenarios {
err := app.Dao().SaveView(s.viewName, s.query)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.scenarioName, s.expectError, hasErr, err)
continue
}
if hasErr {
continue
}
infoRows, err := app.Dao().GetTableInfo(s.viewName)
if err != nil {
t.Errorf("[%s] Failed to fetch table info for %s: %v", s.scenarioName, s.viewName, err)
continue
}
if len(s.expectColumns) != len(infoRows) {
t.Errorf("[%s] Expected %d columns, got %d", s.scenarioName, len(s.expectColumns), len(infoRows))
continue
}
for _, row := range infoRows {
if !list.ExistInSlice(row.Name, s.expectColumns) {
t.Errorf("[%s] Missing %q column in %v", s.scenarioName, row.Name, s.expectColumns)
}
}
}
}
func TestCreateViewSchema(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
scenarios := []struct {
name string
query string
expectError bool
expectFields map[string]string // name-type pairs
}{
{
"empty query",
"",
true,
nil,
},
{
"invalid query",
"test 123456",
true,
nil,
},
{
"missing table",
"select * from missing",
true,
nil,
},
{
"query with wildcard column",
"select a.id, a.* from demo1 a",
true,
nil,
},
{
"query without id",
"select text, url, created, updated from demo1",
true,
nil,
},
{
"query with comments",
`
select
-- test single line
id,
text,
/* multi
line comment */
url, created, updated from demo1
`,
false,
map[string]string{
"text": schema.FieldTypeText,
"url": schema.FieldTypeUrl,
},
},
{
"query with all fields and quoted identifiers",
`
select
"id",
"created",
"updated",
[text],
` + "`bool`" + `,
"url",
"select_one",
"select_many",
"file_one",
"demo1"."file_many",
` + "`demo1`." + "`number`" + ` number_alias,
"email",
"datetime",
"json",
"rel_one",
"rel_many"
from demo1
`,
false,
map[string]string{
"text": schema.FieldTypeText,
"bool": schema.FieldTypeBool,
"url": schema.FieldTypeUrl,
"select_one": schema.FieldTypeSelect,
"select_many": schema.FieldTypeSelect,
"file_one": schema.FieldTypeFile,
"file_many": schema.FieldTypeFile,
"number_alias": schema.FieldTypeNumber,
"email": schema.FieldTypeEmail,
"datetime": schema.FieldTypeDate,
"json": schema.FieldTypeJson,
"rel_one": schema.FieldTypeRelation,
"rel_many": schema.FieldTypeRelation,
},
},
{
"query with indirect relations fields",
"select a.id, b.id as bid, b.created from demo1 as a left join demo2 b",
false,
map[string]string{
"bid": schema.FieldTypeRelation,
},
},
{
"query with multiple froms, joins and style of aliasses",
`
select
a.id as id,
b.id as bid,
lj.id cid,
ij.id as did,
a.bool,
_admins.id as eid,
_admins.email
from demo1 a, demo2 as b
left join demo3 lj on lj.id = 123
inner join demo4 as ij on ij.id = 123
join _admins
where 1=1
group by a.id
limit 10
`,
false,
map[string]string{
"bid": schema.FieldTypeRelation,
"cid": schema.FieldTypeRelation,
"did": schema.FieldTypeRelation,
"bool": schema.FieldTypeBool,
"eid": schema.FieldTypeJson, // not from collection
"email": schema.FieldTypeJson, // not from collection
},
},
{
"query with numeric casts",
`select
a.id,
count(a.id) count,
cast(a.id as int) cast_int,
cast(a.id as integer) cast_integer,
cast(a.id as real) cast_real,
cast(a.id as decimal) cast_decimal,
cast(a.id as numeric) cast_numeric,
avg(a.id) avg,
sum(a.id) sum,
total(a.id) total,
min(a.id) min,
max(a.id) max
from demo1 a`,
false,
map[string]string{
"count": schema.FieldTypeNumber,
"total": schema.FieldTypeNumber,
"cast_int": schema.FieldTypeNumber,
"cast_integer": schema.FieldTypeNumber,
"cast_real": schema.FieldTypeNumber,
"cast_decimal": schema.FieldTypeNumber,
"cast_numeric": schema.FieldTypeNumber,
// json because they are nullable
"sum": schema.FieldTypeJson,
"avg": schema.FieldTypeJson,
"min": schema.FieldTypeJson,
"max": schema.FieldTypeJson,
},
},
{
"query with reserved auth collection fields",
`
select
a.id,
a.username,
a.email,
a.emailVisibility,
a.verified,
demo1.id relid
from users a
left join demo1
`,
false,
map[string]string{
"username": schema.FieldTypeText,
"email": schema.FieldTypeEmail,
"emailVisibility": schema.FieldTypeBool,
"verified": schema.FieldTypeBool,
"relid": schema.FieldTypeRelation,
},
},
{
"query with unknown fields and aliases",
`select
id,
id as id2,
text as text_alias,
url as url_alias,
"demo1"."bool" as bool_alias,
number as number_alias,
created created_alias,
updated updated_alias,
123 as custom
from demo1
`,
false,
map[string]string{
"id2": schema.FieldTypeJson,
"text_alias": schema.FieldTypeText,
"url_alias": schema.FieldTypeUrl,
"bool_alias": schema.FieldTypeBool,
"number_alias": schema.FieldTypeNumber,
"created_alias": schema.FieldTypeDate,
"updated_alias": schema.FieldTypeDate,
"custom": schema.FieldTypeJson,
},
},
}
for _, s := range scenarios {
result, err := app.Dao().CreateViewSchema(s.query)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
continue
}
if hasErr {
continue
}
if len(s.expectFields) != len(result.Fields()) {
serialized, _ := json.Marshal(result)
t.Errorf("[%s] Expected %d fields, got %d: \n%s", s.name, len(s.expectFields), len(result.Fields()), serialized)
continue
}
for name, typ := range s.expectFields {
field := result.GetFieldByName(name)
if field == nil {
t.Errorf("[%s] Expected to find field %s, got nil", s.name, name)
continue
}
if field.Type != typ {
t.Errorf("[%s] Expected field %s to be %q, got %s", s.name, name, typ, field.Type)
continue
}
}
}
}
func TestFindRecordByViewFile(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
prevCollection, err := app.Dao().FindCollectionByNameOrId("demo1")
if err != nil {
t.Fatal(err)
}
totalLevels := 6
// create collection view mocks
fileOneAlias := "file_one one0"
fileManyAlias := "file_many many0"
mockCollections := make([]*models.Collection, 0, totalLevels)
for i := 0; i <= totalLevels; i++ {
view := new(models.Collection)
view.Type = models.CollectionTypeView
view.Name = fmt.Sprintf("_test_view%d", i)
view.SetOptions(&models.CollectionViewOptions{
Query: fmt.Sprintf(
"select id, %s, %s from %s",
fileOneAlias,
fileManyAlias,
prevCollection.Name,
),
})
// save view
if err := app.Dao().SaveCollection(view); err != nil {
t.Fatalf("Failed to save view%d: %v", i, err)
}
mockCollections = append(mockCollections, view)
prevCollection = view
fileOneAlias = fmt.Sprintf("one%d one%d", i, i+1)
fileManyAlias = fmt.Sprintf("many%d many%d", i, i+1)
}
fileOneName := "test_d61b33QdDU.txt"
fileManyName := "test_QZFjKjXchk.txt"
expectedRecordId := "84nmscqy84lsi1t"
scenarios := []struct {
name string
collectionNameOrId string
fileFieldName string
filename string
expectError bool
expectRecordId string
}{
{
"missing collection",
"missing",
"a",
fileOneName,
true,
"",
},
{
"non-view collection",
"demo1",
"file_one",
fileOneName,
true,
"",
},
{
"view collection after the max recursion limit",
mockCollections[totalLevels-1].Name,
fmt.Sprintf("one%d", totalLevels-1),
fileOneName,
true,
"",
},
{
"first view collection (single file)",
mockCollections[0].Name,
"one0",
fileOneName,
false,
expectedRecordId,
},
{
"first view collection (many files)",
mockCollections[0].Name,
"many0",
fileManyName,
false,
expectedRecordId,
},
{
"last view collection before the recursion limit (single file)",
mockCollections[totalLevels-2].Name,
fmt.Sprintf("one%d", totalLevels-2),
fileOneName,
false,
expectedRecordId,
},
{
"last view collection before the recursion limit (many files)",
mockCollections[totalLevels-2].Name,
fmt.Sprintf("many%d", totalLevels-2),
fileManyName,
false,
expectedRecordId,
},
}
for _, s := range scenarios {
record, err := app.Dao().FindRecordByViewFile(
s.collectionNameOrId,
s.fileFieldName,
s.filename,
)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%s] Expected hasErr %v, got %v (%v)", s.name, s.expectError, hasErr, err)
continue
}
if hasErr {
continue
}
if record.Id != s.expectRecordId {
t.Errorf("[%s] Expected recordId %q, got %q", s.name, s.expectRecordId, record.Id)
}
}
}