initial v0.8 pre-release
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"image"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -116,6 +117,39 @@ func (s *System) Upload(content []byte, fileKey string) error {
|
||||
return w.Close()
|
||||
}
|
||||
|
||||
// UploadMultipart upload the provided multipart file to the fileKey location.
|
||||
func (s *System) UploadMultipart(fh *multipart.FileHeader, fileKey string) error {
|
||||
f, err := fh.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mt, err := mimetype.DetectReader(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rewind
|
||||
f.Seek(0, io.SeekStart)
|
||||
|
||||
opts := &blob.WriterOptions{
|
||||
ContentType: mt.String(),
|
||||
}
|
||||
|
||||
w, err := s.bucket.NewWriter(s.ctx, fileKey, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := w.ReadFrom(f); err != nil {
|
||||
w.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return w.Close()
|
||||
}
|
||||
|
||||
// Delete deletes stored file at fileKey location.
|
||||
func (s *System) Delete(fileKey string) error {
|
||||
return s.bucket.Delete(s.ctx, fileKey)
|
||||
@@ -233,7 +267,7 @@ func (s *System) Serve(response http.ResponseWriter, fileKey string, name string
|
||||
response.Header().Set("Content-Length", strconv.FormatInt(r.Size(), 10))
|
||||
response.Header().Set("Content-Security-Policy", "default-src 'none'; media-src 'self'; style-src 'unsafe-inline'; sandbox")
|
||||
|
||||
// All HTTP date/time stamps MUST be represented in Greenwich Mean Time (GMT)
|
||||
// all HTTP date/time stamps MUST be represented in Greenwich Mean Time (GMT)
|
||||
// (see https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.3.1)
|
||||
//
|
||||
// NB! time.LoadLocation may fail on non-Unix systems (see https://github.com/pocketbase/pocketbase/issues/45)
|
||||
@@ -242,6 +276,13 @@ func (s *System) Serve(response http.ResponseWriter, fileKey string, name string
|
||||
response.Header().Set("Last-Modified", r.ModTime().In(location).Format("Mon, 02 Jan 06 15:04:05 MST"))
|
||||
}
|
||||
|
||||
// set a default cache-control header
|
||||
// (valid for 30 days but the cache is allowed to reuse the file for any requests
|
||||
// that are made in the last day while revalidating the response in the background)
|
||||
if response.Header().Get("Cache-Control") == "" {
|
||||
response.Header().Set("Cache-Control", "max-age=2592000, stale-while-revalidate=86400")
|
||||
}
|
||||
|
||||
// copy from the read range to response.
|
||||
_, err := io.Copy(response, r)
|
||||
|
||||
@@ -282,6 +323,7 @@ func (s *System) CreateThumb(originalKey string, thumbKey, thumbSize string) err
|
||||
defer r.Close()
|
||||
|
||||
// create imaging object from the original reader
|
||||
// (note: only the first frame for animated image formats)
|
||||
img, decodeErr := imaging.Decode(r, imaging.AutoOrientation(true))
|
||||
if decodeErr != nil {
|
||||
return decodeErr
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package filesystem_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/png"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -128,6 +131,46 @@ func TestFileSystemDeletePrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemUploadMultipart(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// create multipart form file
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
w, err := mp.CreateFormFile("test", "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed creating form file: %v", err)
|
||||
}
|
||||
w.Write([]byte("demo"))
|
||||
mp.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
file, fh, err := req.FormFile("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch form file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
// ---
|
||||
|
||||
fs, err := filesystem.NewLocal(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer fs.Close()
|
||||
|
||||
uploadErr := fs.UploadMultipart(fh, "newdir/newkey.txt")
|
||||
if uploadErr != nil {
|
||||
t.Fatal(uploadErr)
|
||||
}
|
||||
|
||||
if exists, _ := fs.Exists("newdir/newkey.txt"); !exists {
|
||||
t.Fatalf("Expected newdir/newkey.txt to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemUpload(t *testing.T) {
|
||||
dir := createTestDir(t)
|
||||
defer os.RemoveAll(dir)
|
||||
@@ -232,6 +275,10 @@ func TestFileSystemServe(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
if scenario.expectError {
|
||||
continue
|
||||
}
|
||||
|
||||
result := r.Result()
|
||||
|
||||
for hName, hValue := range scenario.expectHeaders {
|
||||
@@ -244,6 +291,10 @@ func TestFileSystemServe(t *testing.T) {
|
||||
if v := result.Header.Get("X-Frame-Options"); v != "" {
|
||||
t.Errorf("(%s) Expected the X-Frame-Options header to be unset, got %v", scenario.path, v)
|
||||
}
|
||||
|
||||
if v := result.Header.Get("Cache-Control"); v == "" {
|
||||
t.Errorf("(%s) Expected Cache-Control header to be set, got empty string", scenario.path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+2
-2
@@ -54,12 +54,12 @@ func (h *Hook[T]) Reset() {
|
||||
// - hook.StopPropagation is returned in one of the handlers
|
||||
// - any non-nil error is returned in one of the handlers
|
||||
func (h *Hook[T]) Trigger(data T, oneOffHandlers ...Handler[T]) error {
|
||||
h.mux.Lock()
|
||||
h.mux.RLock()
|
||||
handlers := make([]Handler[T], 0, len(h.handlers)+len(oneOffHandlers))
|
||||
handlers = append(handlers, h.handlers...)
|
||||
handlers = append(handlers, oneOffHandlers...)
|
||||
// unlock is not deferred to avoid deadlocks when Trigger is called recursive by the handlers
|
||||
h.mux.Unlock()
|
||||
h.mux.RUnlock()
|
||||
|
||||
for _, fn := range handlers {
|
||||
err := fn(data)
|
||||
|
||||
@@ -46,7 +46,10 @@ func (m *SmtpClient) Send(
|
||||
htmlContent string,
|
||||
attachments map[string]io.Reader,
|
||||
) error {
|
||||
smtpAuth := smtp.PlainAuth("", m.username, m.password, m.host)
|
||||
var smtpAuth smtp.Auth
|
||||
if m.username != "" || m.password != "" {
|
||||
smtpAuth = smtp.PlainAuth("", m.username, m.password, m.host)
|
||||
}
|
||||
|
||||
// create mail instance
|
||||
var yak *mailyak.MailYak
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
)
|
||||
|
||||
// ApiError defines the properties for a basic api error response.
|
||||
type ApiError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]any `json:"data"`
|
||||
|
||||
// stores unformatted error data (could be an internal error, text, etc.)
|
||||
rawData any
|
||||
}
|
||||
|
||||
// Error makes it compatible with the `error` interface.
|
||||
func (e *ApiError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e *ApiError) RawData() any {
|
||||
return e.rawData
|
||||
}
|
||||
|
||||
// NewNotFoundError creates and returns 404 `ApiError`.
|
||||
func NewNotFoundError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "The requested resource wasn't found."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusNotFound, message, data)
|
||||
}
|
||||
|
||||
// NewBadRequestError creates and returns 400 `ApiError`.
|
||||
func NewBadRequestError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Something went wrong while processing your request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusBadRequest, message, data)
|
||||
}
|
||||
|
||||
// NewForbiddenError creates and returns 403 `ApiError`.
|
||||
func NewForbiddenError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "You are not allowed to perform this request."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusForbidden, message, data)
|
||||
}
|
||||
|
||||
// NewUnauthorizedError creates and returns 401 `ApiError`.
|
||||
func NewUnauthorizedError(message string, data any) *ApiError {
|
||||
if message == "" {
|
||||
message = "Missing or invalid authentication token."
|
||||
}
|
||||
|
||||
return NewApiError(http.StatusUnauthorized, message, data)
|
||||
}
|
||||
|
||||
// NewApiError creates and returns new normalized `ApiError` instance.
|
||||
func NewApiError(status int, message string, data any) *ApiError {
|
||||
message = inflector.Sentenize(message)
|
||||
|
||||
formattedData := map[string]any{}
|
||||
|
||||
if v, ok := data.(validation.Errors); ok {
|
||||
formattedData = resolveValidationErrors(v)
|
||||
}
|
||||
|
||||
return &ApiError{
|
||||
rawData: data,
|
||||
Data: formattedData,
|
||||
Code: status,
|
||||
Message: strings.TrimSpace(message),
|
||||
}
|
||||
}
|
||||
|
||||
func resolveValidationErrors(validationErrors validation.Errors) map[string]any {
|
||||
result := map[string]any{}
|
||||
|
||||
// extract from each validation error its error code and message.
|
||||
for name, err := range validationErrors {
|
||||
// check for nested errors
|
||||
if nestedErrs, ok := err.(validation.Errors); ok {
|
||||
result[name] = resolveValidationErrors(nestedErrs)
|
||||
continue
|
||||
}
|
||||
|
||||
errCode := "validation_invalid_value" // default
|
||||
if errObj, ok := err.(validation.ErrorObject); ok {
|
||||
errCode = errObj.Code()
|
||||
}
|
||||
|
||||
result[name] = map[string]string{
|
||||
"code": errCode,
|
||||
"message": inflector.Sentenize(err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
)
|
||||
|
||||
func TestNewApiErrorWithRawData(t *testing.T) {
|
||||
e := rest.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
"rawData_test",
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"code":300,"message":"Message_test.","data":{}}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() != "rawData_test" {
|
||||
t.Errorf("Expected rawData %v, got %v", "rawData_test", e.RawData())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewApiErrorWithValidationData(t *testing.T) {
|
||||
e := rest.NewApiError(
|
||||
300,
|
||||
"message_test",
|
||||
validation.Errors{
|
||||
"err1": errors.New("test error"),
|
||||
"err2": validation.ErrRequired,
|
||||
"err3": validation.Errors{
|
||||
"sub1": errors.New("test error"),
|
||||
"sub2": validation.ErrRequired,
|
||||
"sub3": validation.Errors{
|
||||
"sub11": validation.ErrRequired,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, _ := json.Marshal(e)
|
||||
expected := `{"code":300,"message":"Message_test.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."},"err2":{"code":"validation_required","message":"Cannot be blank."},"err3":{"sub1":{"code":"validation_invalid_value","message":"Test error."},"sub2":{"code":"validation_required","message":"Cannot be blank."},"sub3":{"sub11":{"code":"validation_required","message":"Cannot be blank."}}}}}`
|
||||
|
||||
if string(result) != expected {
|
||||
t.Errorf("Expected %v, got %v", expected, string(result))
|
||||
}
|
||||
|
||||
if e.Error() != "Message_test." {
|
||||
t.Errorf("Expected %q, got %q", "Message_test.", e.Error())
|
||||
}
|
||||
|
||||
if e.RawData() == nil {
|
||||
t.Error("Expected non-nil rawData")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":404,"message":"The requested resource wasn't found.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":404,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":404,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewNotFoundError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadRequestError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":400,"message":"Something went wrong while processing your request.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":400,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":400,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewBadRequestError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForbiddenError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":403,"message":"You are not allowed to perform this request.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":403,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":403,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewForbiddenError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUnauthorizedError(t *testing.T) {
|
||||
scenarios := []struct {
|
||||
message string
|
||||
data any
|
||||
expected string
|
||||
}{
|
||||
{"", nil, `{"code":401,"message":"Missing or invalid authentication token.","data":{}}`},
|
||||
{"demo", "rawData_test", `{"code":401,"message":"Demo.","data":{}}`},
|
||||
{"demo", validation.Errors{"err1": errors.New("test error")}, `{"code":401,"message":"Demo.","data":{"err1":{"code":"validation_invalid_value","message":"Test error."}}}`},
|
||||
}
|
||||
|
||||
for i, scenario := range scenarios {
|
||||
e := rest.NewUnauthorizedError(scenario.message, scenario.data)
|
||||
result, _ := json.Marshal(e)
|
||||
|
||||
if string(result) != scenario.expected {
|
||||
t.Errorf("(%d) Expected %v, got %v", i, scenario.expected, string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func BindBody(c echo.Context, i interface{}) error {
|
||||
ctype := req.Header.Get(echo.HeaderContentType)
|
||||
switch {
|
||||
case strings.HasPrefix(ctype, echo.MIMEApplicationJSON):
|
||||
err := ReadJsonBodyCopy(c.Request(), i)
|
||||
err := CopyJsonBody(c.Request(), i)
|
||||
if err != nil {
|
||||
return echo.NewHTTPErrorWithInternal(http.StatusBadRequest, err, err.Error())
|
||||
}
|
||||
@@ -34,9 +34,9 @@ func BindBody(c echo.Context, i interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
// ReadJsonBodyCopy reads the request body into i by
|
||||
// CopyJsonBody reads the request body into i by
|
||||
// creating a copy of `r.Body` to allow multiple reads.
|
||||
func ReadJsonBodyCopy(r *http.Request, i interface{}) error {
|
||||
func CopyJsonBody(r *http.Request, i interface{}) error {
|
||||
body := r.Body
|
||||
|
||||
// this usually shouldn't be needed because the Server calls close for us
|
||||
|
||||
@@ -75,14 +75,14 @@ func TestBindBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadJsonBodyCopy(t *testing.T) {
|
||||
func TestCopyJsonBody(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"test":"test123"}`))
|
||||
|
||||
// simulate multiple reads from the same request
|
||||
result1 := map[string]string{}
|
||||
rest.ReadJsonBodyCopy(req, &result1)
|
||||
rest.CopyJsonBody(req, &result1)
|
||||
result2 := map[string]string{}
|
||||
rest.ReadJsonBodyCopy(req, &result2)
|
||||
rest.CopyJsonBody(req, &result2)
|
||||
|
||||
if len(result1) == 0 {
|
||||
t.Error("Expected result1 to be filled")
|
||||
|
||||
+17
-19
@@ -1,15 +1,14 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/pocketbase/pocketbase/tools/inflector"
|
||||
"github.com/pocketbase/pocketbase/tools/security"
|
||||
)
|
||||
@@ -24,7 +23,6 @@ var extensionInvalidCharsRegex = regexp.MustCompile(`[^\w\.\*\-\+\=\#]+`)
|
||||
type UploadedFile struct {
|
||||
name string
|
||||
header *multipart.FileHeader
|
||||
bytes []byte
|
||||
}
|
||||
|
||||
// Name returns an assigned unique name to the uploaded file.
|
||||
@@ -37,11 +35,6 @@ func (f *UploadedFile) Header() *multipart.FileHeader {
|
||||
return f.header
|
||||
}
|
||||
|
||||
// Bytes returns a slice with the file content.
|
||||
func (f *UploadedFile) Bytes() []byte {
|
||||
return f.bytes
|
||||
}
|
||||
|
||||
// FindUploadedFiles extracts all form files of `key` from a http request
|
||||
// and returns a slice with `UploadedFile` instances (if any).
|
||||
func FindUploadedFiles(r *http.Request, key string) ([]*UploadedFile, error) {
|
||||
@@ -56,26 +49,32 @@ func FindUploadedFiles(r *http.Request, key string) ([]*UploadedFile, error) {
|
||||
return nil, http.ErrMissingFile
|
||||
}
|
||||
|
||||
result := make([]*UploadedFile, len(r.MultipartForm.File[key]))
|
||||
result := make([]*UploadedFile, 0, len(r.MultipartForm.File[key]))
|
||||
|
||||
for i, fh := range r.MultipartForm.File[key] {
|
||||
for _, fh := range r.MultipartForm.File[key] {
|
||||
file, err := fh.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
if _, err := io.Copy(buf, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// extension
|
||||
// ---
|
||||
originalExt := filepath.Ext(fh.Filename)
|
||||
sanitizedExt := extensionInvalidCharsRegex.ReplaceAllString(originalExt, "")
|
||||
if sanitizedExt == "" {
|
||||
// try to detect the extension from the mime type
|
||||
mt, err := mimetype.DetectReader(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sanitizedExt = mt.Extension()
|
||||
}
|
||||
|
||||
// name
|
||||
// ---
|
||||
originalName := strings.TrimSuffix(fh.Filename, originalExt)
|
||||
sanitizedName := inflector.Snakecase(originalName)
|
||||
|
||||
if length := len(sanitizedName); length < 3 {
|
||||
// the name is too short so we concatenate an additional random part
|
||||
sanitizedName += ("_" + security.RandomString(10))
|
||||
@@ -91,11 +90,10 @@ func FindUploadedFiles(r *http.Request, key string) ([]*UploadedFile, error) {
|
||||
sanitizedExt,
|
||||
)
|
||||
|
||||
result[i] = &UploadedFile{
|
||||
result = append(result, &UploadedFile{
|
||||
name: uploadedFilename,
|
||||
header: fh,
|
||||
bytes: buf.Bytes(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
@@ -2,11 +2,10 @@ package rest_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -14,58 +13,51 @@ import (
|
||||
)
|
||||
|
||||
func TestFindUploadedFiles(t *testing.T) {
|
||||
// create a test temporary file (with very large prefix to test if it will be truncated)
|
||||
tmpFile, err := os.CreateTemp(os.TempDir(), strings.Repeat("a", 150)+"tmpfile-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := tmpFile.Write([]byte("test")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tmpFile.Seek(0, 0)
|
||||
defer tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
// ---
|
||||
|
||||
// stub multipart form file body
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
w, err := mp.CreateFormFile("test", tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := io.Copy(w, tmpFile); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mp.Close()
|
||||
// ---
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
result, err := rest.FindUploadedFiles(req, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
scenarios := []struct {
|
||||
filename string
|
||||
expectedPattern string
|
||||
}{
|
||||
{"ab.png", `^ab_\w{10}_\w{10}\.png$`},
|
||||
{"test", `^test_\w{10}\.txt$`},
|
||||
{"a b c d!@$.j!@$pg", `^a_b_c_d_\w{10}\.jpg$`},
|
||||
{strings.Repeat("a", 150), `^a{100}_\w{10}\.txt$`},
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("Expected 1 file, got %d", len(result))
|
||||
}
|
||||
for i, s := range scenarios {
|
||||
// create multipart form file body
|
||||
body := new(bytes.Buffer)
|
||||
mp := multipart.NewWriter(body)
|
||||
w, err := mp.CreateFormFile("test", s.filename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
w.Write([]byte("test"))
|
||||
mp.Close()
|
||||
// ---
|
||||
|
||||
if result[0].Header().Size != 4 {
|
||||
t.Fatalf("Expected the file size to be 4 bytes, got %d", result[0].Header().Size)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
req.Header.Add("Content-Type", mp.FormDataContentType())
|
||||
|
||||
if !strings.HasSuffix(result[0].Name(), ".txt") {
|
||||
t.Fatalf("Expected the file name to have suffix .txt, got %v", result[0].Name())
|
||||
}
|
||||
result, err := rest.FindUploadedFiles(req, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if length := len(result[0].Name()); length != 115 { // truncated + random part + ext
|
||||
t.Fatalf("Expected the file name to have length of 115, got %d\n%q", length, result[0].Name())
|
||||
}
|
||||
if len(result) != 1 {
|
||||
t.Errorf("[%d] Expected 1 file, got %d", i, len(result))
|
||||
}
|
||||
|
||||
if string(result[0].Bytes()) != "test" {
|
||||
t.Fatalf("Expected the file content to be %q, got %q", "test", string(result[0].Bytes()))
|
||||
if result[0].Header().Size != 4 {
|
||||
t.Errorf("[%d] Expected the file size to be 4 bytes, got %d", i, result[0].Header().Size)
|
||||
}
|
||||
|
||||
pattern, err := regexp.Compile(s.expectedPattern)
|
||||
if err != nil {
|
||||
t.Errorf("[%d] Invalid filename pattern %q: %v", i, s.expectedPattern, err)
|
||||
}
|
||||
if !pattern.MatchString(result[0].Name()) {
|
||||
t.Fatalf("Expected filename to match %s, got filename %s", s.expectedPattern, result[0].Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+38
-32
@@ -89,52 +89,39 @@ func (f FilterData) resolveTokenizedExpr(expr fexpr.Expr, fieldResolver FieldRes
|
||||
return nil, fmt.Errorf("Invalid right operand %q - %v.", expr.Right.Literal, rErr)
|
||||
}
|
||||
|
||||
// merge both operands parameters (if any)
|
||||
params := dbx.Params{}
|
||||
for k, v := range lParams {
|
||||
params[k] = v
|
||||
}
|
||||
for k, v := range rParams {
|
||||
params[k] = v
|
||||
}
|
||||
|
||||
switch expr.Op {
|
||||
case fexpr.SignEq:
|
||||
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') = COALESCE(%s, '')", lName, rName), params), nil
|
||||
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') = COALESCE(%s, '')", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignNeq:
|
||||
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') != COALESCE(%s, '')", lName, rName), params), nil
|
||||
return dbx.NewExp(fmt.Sprintf("COALESCE(%s, '') != COALESCE(%s, '')", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignLike:
|
||||
// both sides are columns and therefore wrap the right side with "%" for contains like behavior
|
||||
if len(params) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%')", lName, rName), params), nil
|
||||
// the right side is a column and therefor wrap it with "%" for contains like behavior
|
||||
if len(rParams) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE ('%%' || %s || '%%')", lName, rName), lParams), nil
|
||||
}
|
||||
|
||||
// normalize operands and switch sides if the left operand is a number or text
|
||||
if len(lParams) > 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s", rName, lName), f.normalizeLikeParams(params)), nil
|
||||
}
|
||||
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s", lName, rName), f.normalizeLikeParams(params)), nil
|
||||
return dbx.NewExp(fmt.Sprintf("%s LIKE %s", lName, rName), mergeParams(lParams, wrapLikeParams(rParams))), nil
|
||||
case fexpr.SignNlike:
|
||||
// both sides are columns and therefore wrap the right side with "%" for not-contains like behavior
|
||||
if len(params) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%')", lName, rName), params), nil
|
||||
// the right side is a column and therefor wrap it with "%" for not-contains like behavior
|
||||
if len(rParams) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE ('%%' || %s || '%%')", lName, rName), lParams), nil
|
||||
}
|
||||
|
||||
// normalize operands and switch sides if the left operand is a number or text
|
||||
if len(lParams) > 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", rName, lName), f.normalizeLikeParams(params)), nil
|
||||
// normalize operands and switch sides if the left operand is a number/text, but the right one is a column
|
||||
// (usually this shouldn't be needed, but it's kept for backward compatibility)
|
||||
if len(lParams) > 0 && len(rParams) == 0 {
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", rName, lName), wrapLikeParams(lParams)), nil
|
||||
}
|
||||
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", lName, rName), f.normalizeLikeParams(params)), nil
|
||||
return dbx.NewExp(fmt.Sprintf("%s NOT LIKE %s", lName, rName), mergeParams(lParams, wrapLikeParams(rParams))), nil
|
||||
case fexpr.SignLt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s < %s", lName, rName), params), nil
|
||||
return dbx.NewExp(fmt.Sprintf("%s < %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignLte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s <= %s", lName, rName), params), nil
|
||||
return dbx.NewExp(fmt.Sprintf("%s <= %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignGt:
|
||||
return dbx.NewExp(fmt.Sprintf("%s > %s", lName, rName), params), nil
|
||||
return dbx.NewExp(fmt.Sprintf("%s > %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
case fexpr.SignGte:
|
||||
return dbx.NewExp(fmt.Sprintf("%s >= %s", lName, rName), params), nil
|
||||
return dbx.NewExp(fmt.Sprintf("%s >= %s", lName, rName), mergeParams(lParams, rParams)), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown expression operator %q", expr.Op)
|
||||
@@ -190,12 +177,31 @@ func (f FilterData) resolveToken(token fexpr.Token, fieldResolver FieldResolver)
|
||||
return "", nil, errors.New("Unresolvable token type.")
|
||||
}
|
||||
|
||||
func (f FilterData) normalizeLikeParams(params dbx.Params) dbx.Params {
|
||||
// mergeParams returns new dbx.Params where each provided params item
|
||||
// is merged in the order they are specified.
|
||||
func mergeParams(params ...dbx.Params) dbx.Params {
|
||||
result := dbx.Params{}
|
||||
|
||||
for _, p := range params {
|
||||
for k, v := range p {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// wrapLikeParams wraps each provided param value string with `%`
|
||||
// if the string doesn't contains the `%` char (including its escape sequence).
|
||||
func wrapLikeParams(params dbx.Params) dbx.Params {
|
||||
result := dbx.Params{}
|
||||
|
||||
for k, v := range params {
|
||||
vStr := cast.ToString(v)
|
||||
if !strings.Contains(vStr, "%") {
|
||||
for i := 0; i < len(dbx.DefaultLikeEscape); i += 2 {
|
||||
vStr = strings.ReplaceAll(vStr, dbx.DefaultLikeEscape[i], dbx.DefaultLikeEscape[i+1])
|
||||
}
|
||||
vStr = "%" + vStr + "%"
|
||||
}
|
||||
result[k] = vStr
|
||||
|
||||
@@ -38,8 +38,16 @@ func TestFilterDataBuildExpr(t *testing.T) {
|
||||
regexp.QuoteMeta("[[test1]] LIKE ('%' || [[test2]] || '%')") +
|
||||
"$",
|
||||
},
|
||||
// reversed like with text
|
||||
// like with right column operand
|
||||
{"'lorem' ~ test1", false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("{:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("} LIKE ('%' || [[test1]] || '%')") +
|
||||
"$",
|
||||
},
|
||||
// like with left column operand and text as right operand
|
||||
{"test1 ~ 'lorem'", false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("[[test1]] LIKE {:") +
|
||||
".+" +
|
||||
@@ -52,8 +60,16 @@ func TestFilterDataBuildExpr(t *testing.T) {
|
||||
regexp.QuoteMeta("[[test1]] NOT LIKE ('%' || [[test2]] || '%')") +
|
||||
"$",
|
||||
},
|
||||
// reversed not like with text
|
||||
// not like with right column operand
|
||||
{"'lorem' !~ test1", false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("{:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("} NOT LIKE ('%' || [[test1]] || '%')") +
|
||||
"$",
|
||||
},
|
||||
// like with left column operand and text as right operand
|
||||
{"test1 !~ 'lorem'", false,
|
||||
"^" +
|
||||
regexp.QuoteMeta("[[test1]] NOT LIKE {:") +
|
||||
".+" +
|
||||
@@ -97,11 +113,11 @@ func TestFilterDataBuildExpr(t *testing.T) {
|
||||
".+" +
|
||||
regexp.QuoteMeta("}) OR ([[test2]] NOT LIKE {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("}))) AND ([[test1]] LIKE {:") +
|
||||
regexp.QuoteMeta("}))) AND ({:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test2]] NOT LIKE {:") +
|
||||
regexp.QuoteMeta("} LIKE ('%' || [[test1]] || '%'))) AND ({:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] > {:") +
|
||||
regexp.QuoteMeta("} NOT LIKE ('%' || [[test2]] || '%'))) AND ([[test3]] > {:") +
|
||||
".+" +
|
||||
regexp.QuoteMeta("})) AND ([[test3]] >= {:") +
|
||||
".+" +
|
||||
|
||||
+14
-16
@@ -5,6 +5,7 @@ import (
|
||||
"math"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
const DefaultPerPage int = 30
|
||||
|
||||
// MaxPerPage specifies the maximum allowed search result items returned in a single page.
|
||||
const MaxPerPage int = 400
|
||||
const MaxPerPage int = 500
|
||||
|
||||
// url search query params
|
||||
const (
|
||||
@@ -38,7 +39,6 @@ type Provider struct {
|
||||
query *dbx.SelectQuery
|
||||
page int
|
||||
perPage int
|
||||
countColumn string
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
}
|
||||
@@ -53,7 +53,7 @@ type Provider struct {
|
||||
//
|
||||
// result, err := search.NewProvider(fieldResolver).
|
||||
// Query(baseQuery).
|
||||
// ParseAndExec("page=2&filter=id>0&sort=-name", &models)
|
||||
// ParseAndExec("page=2&filter=id>0&sort=-email", &models)
|
||||
func NewProvider(fieldResolver FieldResolver) *Provider {
|
||||
return &Provider{
|
||||
fieldResolver: fieldResolver,
|
||||
@@ -70,13 +70,6 @@ func (s *Provider) Query(query *dbx.SelectQuery) *Provider {
|
||||
return s
|
||||
}
|
||||
|
||||
// CountColumn specifies an optional distinct column to use in the
|
||||
// SELECT COUNT query.
|
||||
func (s *Provider) CountColumn(countColumn string) *Provider {
|
||||
s.countColumn = countColumn
|
||||
return s
|
||||
}
|
||||
|
||||
// Page sets the `page` field of the current search provider.
|
||||
//
|
||||
// Normalization on the `page` value is done during `Exec()`.
|
||||
@@ -170,7 +163,7 @@ func (s *Provider) Exec(items any) (*Result, error) {
|
||||
// clone provider's query
|
||||
modelsQuery := *s.query
|
||||
|
||||
// apply filters
|
||||
// build filters
|
||||
for _, f := range s.filter {
|
||||
expr, err := f.BuildExpr(s.fieldResolver)
|
||||
if err != nil {
|
||||
@@ -197,14 +190,19 @@ func (s *Provider) Exec(items any) (*Result, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queryInfo := modelsQuery.Info()
|
||||
|
||||
// count
|
||||
var totalCount int64
|
||||
countQuery := modelsQuery
|
||||
countQuery.Distinct(false).Select("COUNT(*)").OrderBy() // unset ORDER BY statements
|
||||
if s.countColumn != "" {
|
||||
countQuery.Select("COUNT(DISTINCT(" + s.countColumn + "))")
|
||||
var baseTable string
|
||||
if len(queryInfo.From) > 0 {
|
||||
baseTable = queryInfo.From[0]
|
||||
}
|
||||
if err := countQuery.Row(&totalCount); err != nil {
|
||||
countQuery := modelsQuery
|
||||
rawCountQuery := countQuery.Select(strings.Join([]string{baseTable, "id"}, ".")).OrderBy().Build().SQL()
|
||||
wrappedCountQuery := queryInfo.Builder.NewQuery("SELECT COUNT(*) FROM (" + rawCountQuery + ")")
|
||||
wrappedCountQuery.Bind(countQuery.Build().Params())
|
||||
if err := wrappedCountQuery.Row(&totalCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -60,15 +60,6 @@ func TestProviderPerPage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderCountColumn(t *testing.T) {
|
||||
r := &testFieldResolver{}
|
||||
p := NewProvider(r).CountColumn("test")
|
||||
|
||||
if p.countColumn != "test" {
|
||||
t.Fatalf("Expected distinct count column %v, got %v", "test", p.countColumn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderSort(t *testing.T) {
|
||||
initialSort := []SortField{{"test1", SortAsc}, {"test2", SortAsc}}
|
||||
r := &testFieldResolver{}
|
||||
@@ -223,7 +214,6 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
perPage int
|
||||
sort []SortField
|
||||
filter []FilterData
|
||||
countColumn string
|
||||
expectError bool
|
||||
expectResult string
|
||||
expectQueries []string
|
||||
@@ -234,11 +224,10 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
"",
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":2,"totalPages":1,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT COUNT(*) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT COUNT(*) FROM (SELECT `test`.`id` FROM `test` WHERE NOT (`test1` IS NULL))",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
@@ -248,11 +237,10 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
0, // fallback to default
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
"",
|
||||
false,
|
||||
`{"page":1,"perPage":30,"totalItems":2,"totalPages":1,"items":[{"test1":1,"test2":"test2.1","test3":""},{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT COUNT(*) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT COUNT(*) FROM (SELECT `test`.`id` FROM `test` WHERE NOT (`test1` IS NULL))",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 30",
|
||||
},
|
||||
},
|
||||
@@ -262,7 +250,6 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
10,
|
||||
[]SortField{{"unknown", SortAsc}},
|
||||
[]FilterData{},
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
@@ -273,7 +260,6 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
10,
|
||||
[]SortField{},
|
||||
[]FilterData{"test2 = 'test2.1'", "invalid"},
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
nil,
|
||||
@@ -284,12 +270,11 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
5555, // will be limited by MaxPerPage
|
||||
[]SortField{{"test2", SortDesc}},
|
||||
[]FilterData{"test2 != null", "test1 >= 2"},
|
||||
"",
|
||||
false,
|
||||
`{"page":1,"perPage":` + fmt.Sprint(MaxPerPage) + `,"totalItems":1,"totalPages":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT COUNT(*) FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (COALESCE(test2, '') != COALESCE(null, ''))) AND (test1 >= 2)",
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (COALESCE(test2, '') != COALESCE(null, ''))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT 400",
|
||||
"SELECT COUNT(*) FROM (SELECT `test`.`id` FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (COALESCE(test2, '') != COALESCE(null, ''))) AND (test1 >= 2))",
|
||||
"SELECT * FROM `test` WHERE ((NOT (`test1` IS NULL)) AND (COALESCE(test2, '') != COALESCE(null, ''))) AND (test1 >= 2) ORDER BY `test1` ASC, `test2` DESC LIMIT 500",
|
||||
},
|
||||
},
|
||||
// valid sort and filter fields (zero results)
|
||||
@@ -298,11 +283,10 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
10,
|
||||
[]SortField{{"test3", SortAsc}},
|
||||
[]FilterData{"test3 != ''"},
|
||||
"",
|
||||
false,
|
||||
`{"page":1,"perPage":10,"totalItems":0,"totalPages":0,"items":[]}`,
|
||||
[]string{
|
||||
"SELECT COUNT(*) FROM `test` WHERE (NOT (`test1` IS NULL)) AND (COALESCE(test3, '') != COALESCE('', ''))",
|
||||
"SELECT COUNT(*) FROM (SELECT `test`.`id` FROM `test` WHERE (NOT (`test1` IS NULL)) AND (COALESCE(test3, '') != COALESCE('', '')))",
|
||||
"SELECT * FROM `test` WHERE (NOT (`test1` IS NULL)) AND (COALESCE(test3, '') != COALESCE('', '')) ORDER BY `test1` ASC, `test3` ASC LIMIT 10",
|
||||
},
|
||||
},
|
||||
@@ -312,25 +296,10 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
1,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
"",
|
||||
false,
|
||||
`{"page":2,"perPage":1,"totalItems":2,"totalPages":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT COUNT(*) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
},
|
||||
},
|
||||
// distinct count column
|
||||
{
|
||||
3,
|
||||
1,
|
||||
[]SortField{},
|
||||
[]FilterData{},
|
||||
"test.test1",
|
||||
false,
|
||||
`{"page":2,"perPage":1,"totalItems":2,"totalPages":2,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
[]string{
|
||||
"SELECT COUNT(DISTINCT(test.test1)) FROM `test` WHERE NOT (`test1` IS NULL)",
|
||||
"SELECT COUNT(*) FROM (SELECT `test`.`id` FROM `test` WHERE NOT (`test1` IS NULL))",
|
||||
"SELECT * FROM `test` WHERE NOT (`test1` IS NULL) ORDER BY `test1` ASC LIMIT 1 OFFSET 1",
|
||||
},
|
||||
},
|
||||
@@ -345,8 +314,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
Page(s.page).
|
||||
PerPage(s.perPage).
|
||||
Sort(s.sort).
|
||||
Filter(s.filter).
|
||||
CountColumn(s.countColumn)
|
||||
Filter(s.filter)
|
||||
|
||||
result, err := p.Exec(&[]testTableStruct{})
|
||||
|
||||
@@ -376,7 +344,7 @@ func TestProviderExecNonEmptyQuery(t *testing.T) {
|
||||
|
||||
for _, q := range testDB.CalledQueries {
|
||||
if !list.ExistInSliceWithRegex(q, s.expectQueries) {
|
||||
t.Errorf("(%d) Didn't expect query \n%v", i, q)
|
||||
t.Errorf("(%d) Didn't expect query \n%v in \n%v", i, q, testDB.CalledQueries)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -439,7 +407,7 @@ func TestProviderParseAndExec(t *testing.T) {
|
||||
{
|
||||
"page=3&perPage=9999&filter=test1>1&sort=-test2,test3",
|
||||
false,
|
||||
`{"page":1,"perPage":400,"totalItems":1,"totalPages":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
`{"page":1,"perPage":500,"totalItems":1,"totalPages":1,"items":[{"test1":2,"test2":"test2.2","test3":""}]}`,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -504,9 +472,9 @@ func createTestDB() (*testDB, error) {
|
||||
}
|
||||
|
||||
db := testDB{DB: dbx.NewFromDB(sqlDB, "sqlite")}
|
||||
db.CreateTable("test", map[string]string{"test1": "int default 0", "test2": "text default ''", "test3": "text default ''"}).Execute()
|
||||
db.Insert("test", dbx.Params{"test1": 1, "test2": "test2.1"}).Execute()
|
||||
db.Insert("test", dbx.Params{"test1": 2, "test2": "test2.2"}).Execute()
|
||||
db.CreateTable("test", map[string]string{"id": "int default 0", "test1": "int default 0", "test2": "text default ''", "test3": "text default ''"}).Execute()
|
||||
db.Insert("test", dbx.Params{"id": 1, "test1": 1, "test2": "test2.1"}).Execute()
|
||||
db.Insert("test", dbx.Params{"id": 2, "test1": 2, "test2": "test2.2"}).Execute()
|
||||
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
|
||||
db.CalledQueries = append(db.CalledQueries, sql)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
// ParseUnverifiedJWT parses JWT token and returns its claims
|
||||
// but DOES NOT verify the signature.
|
||||
//
|
||||
// It verifies only the exp, iat and nbf claims.
|
||||
func ParseUnverifiedJWT(token string) (jwt.MapClaims, error) {
|
||||
claims := jwt.MapClaims{}
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestRandomStringWithAlphabet(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
generated := make([]string, 100)
|
||||
generated := make([]string, 0, 100)
|
||||
length := 10
|
||||
|
||||
for j := 0; j < 100; j++ {
|
||||
|
||||
@@ -33,8 +33,8 @@ func (s *Store[T]) Remove(key string) {
|
||||
|
||||
// Has checks if element with the specified key exist or not.
|
||||
func (s *Store[T]) Has(key string) bool {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
|
||||
_, ok := s.data[key]
|
||||
|
||||
@@ -45,8 +45,8 @@ func (s *Store[T]) Has(key string) bool {
|
||||
//
|
||||
// If key is not set, the zero T value is returned.
|
||||
func (s *Store[T]) Get(key string) T {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
|
||||
return s.data[key]
|
||||
}
|
||||
|
||||
@@ -20,6 +20,9 @@ func NewBroker() *Broker {
|
||||
|
||||
// Clients returns all registered clients.
|
||||
func (b *Broker) Clients() map[string]Client {
|
||||
b.mux.RLock()
|
||||
defer b.mux.RUnlock()
|
||||
|
||||
return b.clients
|
||||
}
|
||||
|
||||
|
||||
@@ -61,25 +61,25 @@ func NewDefaultClient() *DefaultClient {
|
||||
}
|
||||
}
|
||||
|
||||
// Id implements the Client.Id interface method.
|
||||
// Id implements the [Client.Id] interface method.
|
||||
func (c *DefaultClient) Id() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
// Channel implements the Client.Channel interface method.
|
||||
// Channel implements the [Client.Channel] interface method.
|
||||
func (c *DefaultClient) Channel() chan Message {
|
||||
return c.channel
|
||||
}
|
||||
|
||||
// Subscriptions implements the Client.Subscriptions interface method.
|
||||
// Subscriptions implements the [Client.Subscriptions] interface method.
|
||||
func (c *DefaultClient) Subscriptions() map[string]struct{} {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
return c.subscriptions
|
||||
}
|
||||
|
||||
// Subscribe implements the Client.Subscribe interface method.
|
||||
// Subscribe implements the [Client.Subscribe] interface method.
|
||||
//
|
||||
// Empty subscriptions (aka. "") are ignored.
|
||||
func (c *DefaultClient) Subscribe(subs ...string) {
|
||||
@@ -95,7 +95,7 @@ func (c *DefaultClient) Subscribe(subs ...string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe implements the Client.Unsubscribe interface method.
|
||||
// Unsubscribe implements the [Client.Unsubscribe] interface method.
|
||||
//
|
||||
// If subs is not set, this method removes all registered client's subscriptions.
|
||||
func (c *DefaultClient) Unsubscribe(subs ...string) {
|
||||
@@ -114,25 +114,25 @@ func (c *DefaultClient) Unsubscribe(subs ...string) {
|
||||
}
|
||||
}
|
||||
|
||||
// HasSubscription implements the Client.HasSubscription interface method.
|
||||
// HasSubscription implements the [Client.HasSubscription] interface method.
|
||||
func (c *DefaultClient) HasSubscription(sub string) bool {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
_, ok := c.subscriptions[sub]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// Get implements the Client.Get interface method.
|
||||
// Get implements the [Client.Get] interface method.
|
||||
func (c *DefaultClient) Get(key string) any {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.mux.RLock()
|
||||
defer c.mux.RUnlock()
|
||||
|
||||
return c.store[key]
|
||||
}
|
||||
|
||||
// Set implements the Client.Set interface method.
|
||||
// Set implements the [Client.Set] interface method.
|
||||
func (c *DefaultClient) Set(key string, value any) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// DefaultDateLayout specifies the default app date strings layout.
|
||||
const DefaultDateLayout = "2006-01-02 15:04:05.000"
|
||||
const DefaultDateLayout = "2006-01-02 15:04:05.000Z"
|
||||
|
||||
// NowDateTime returns new DateTime instance with the current local time.
|
||||
func NowDateTime() DateTime {
|
||||
|
||||
@@ -31,8 +31,8 @@ func TestParseDateTime(t *testing.T) {
|
||||
{"invalid", ""},
|
||||
{nowDateTime, nowStr},
|
||||
{nowTime, nowStr},
|
||||
{1641024040, "2022-01-01 08:00:40.000"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678"},
|
||||
{1641024040, "2022-01-01 08:00:40.000Z"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
@@ -49,7 +49,7 @@ func TestParseDateTime(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDateTimeTime(t *testing.T) {
|
||||
str := "2022-01-01 11:23:45.678"
|
||||
str := "2022-01-01 11:23:45.678Z"
|
||||
|
||||
expected, err := time.Parse(types.DefaultDateLayout, str)
|
||||
if err != nil {
|
||||
@@ -86,7 +86,7 @@ func TestDateTimeString(t *testing.T) {
|
||||
t.Fatalf("Expected empty string for zer datetime, got %q", dt0.String())
|
||||
}
|
||||
|
||||
expected := "2022-01-01 11:23:45.678"
|
||||
expected := "2022-01-01 11:23:45.678Z"
|
||||
dt1, _ := types.ParseDateTime(expected)
|
||||
if dt1.String() != expected {
|
||||
t.Fatalf("Expected %q, got %v", expected, dt1)
|
||||
@@ -99,7 +99,7 @@ func TestDateTimeMarshalJSON(t *testing.T) {
|
||||
expected string
|
||||
}{
|
||||
{"", `""`},
|
||||
{"2022-01-01 11:23:45.678", `"2022-01-01 11:23:45.678"`},
|
||||
{"2022-01-01 11:23:45.678", `"2022-01-01 11:23:45.678Z"`},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
@@ -128,7 +128,7 @@ func TestDateTimeUnmarshalJSON(t *testing.T) {
|
||||
{"invalid_json", ""},
|
||||
{"'123'", ""},
|
||||
{"2022-01-01 11:23:45.678", ""},
|
||||
{`"2022-01-01 11:23:45.678"`, "2022-01-01 11:23:45.678"},
|
||||
{`"2022-01-01 11:23:45.678"`, "2022-01-01 11:23:45.678Z"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
@@ -148,8 +148,8 @@ func TestDateTimeValue(t *testing.T) {
|
||||
}{
|
||||
{"", ""},
|
||||
{"invalid", ""},
|
||||
{1641024040, "2022-01-01 08:00:40.000"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678"},
|
||||
{1641024040, "2022-01-01 08:00:40.000Z"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"},
|
||||
{types.NowDateTime(), types.NowDateTime().String()},
|
||||
}
|
||||
|
||||
@@ -179,8 +179,8 @@ func TestDateTimeScan(t *testing.T) {
|
||||
{"invalid", ""},
|
||||
{types.NowDateTime(), now},
|
||||
{time.Now(), now},
|
||||
{1641024040, "2022-01-01 08:00:40.000"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678"},
|
||||
{1641024040, "2022-01-01 08:00:40.000Z"},
|
||||
{"2022-01-01 11:23:45.678", "2022-01-01 11:23:45.678Z"},
|
||||
}
|
||||
|
||||
for i, s := range scenarios {
|
||||
|
||||
@@ -23,10 +23,6 @@ func (m JsonArray) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// Value implements the [driver.Valuer] interface.
|
||||
func (m JsonArray) Value() (driver.Value, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
|
||||
return string(data), err
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestJsonArrayValue(t *testing.T) {
|
||||
json types.JsonArray
|
||||
expected driver.Value
|
||||
}{
|
||||
{nil, nil},
|
||||
{nil, `[]`},
|
||||
{types.JsonArray{}, `[]`},
|
||||
{types.JsonArray{1, 2, 3}, `[1,2,3]`},
|
||||
{types.JsonArray{"test1", "test2", "test3"}, `["test1","test2","test3"]`},
|
||||
|
||||
@@ -23,10 +23,6 @@ func (m JsonMap) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// Value implements the [driver.Valuer] interface.
|
||||
func (m JsonMap) Value() (driver.Value, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
|
||||
return string(data), err
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestJsonMapValue(t *testing.T) {
|
||||
json types.JsonMap
|
||||
expected driver.Value
|
||||
}{
|
||||
{nil, nil},
|
||||
{nil, `{}`},
|
||||
{types.JsonMap{}, `{}`},
|
||||
{types.JsonMap{"test1": 123, "test2": "lorem"}, `{"test1":123,"test2":"lorem"}`},
|
||||
{types.JsonMap{"test": []int{1, 2, 3}}, `{"test":[1,2,3]}`},
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
// Package types implements some commonly used db serializable types
|
||||
// like datetime, json, etc.
|
||||
package types
|
||||
|
||||
// Pointer is a generic helper that returns val as *T.
|
||||
func Pointer[T any](val T) *T {
|
||||
return &val
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pocketbase/pocketbase/tools/types"
|
||||
)
|
||||
|
||||
func TestPointer(t *testing.T) {
|
||||
s1 := types.Pointer("")
|
||||
if s1 == nil || *s1 != "" {
|
||||
t.Fatalf("Expected empty string pointer, got %#v", s1)
|
||||
}
|
||||
|
||||
s2 := types.Pointer("test")
|
||||
if s2 == nil || *s2 != "test" {
|
||||
t.Fatalf("Expected 'test' string pointer, got %#v", s2)
|
||||
}
|
||||
|
||||
s3 := types.Pointer(123)
|
||||
if s3 == nil || *s3 != 123 {
|
||||
t.Fatalf("Expected 123 string pointer, got %#v", s3)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
const { exec } = require('node:child_process');
|
||||
|
||||
// you can use any other library for copying directories recursively
|
||||
const fse = require('fs-extra');
|
||||
|
||||
let controller; // this will be used to terminate the PocketBase process
|
||||
|
||||
const srcTestDirPath = "./test_pb_data";
|
||||
const tempTestDirPath = "./temp_test_pb_data";
|
||||
|
||||
beforeEach(() => {
|
||||
// copy test_pb_date to a temp location
|
||||
fse.copySync(srcTestDirPath, tempTestDirPath);
|
||||
|
||||
controller = new AbortController();
|
||||
|
||||
// start PocketBase with the test_pb_data
|
||||
exec('./pocketbase serve --dir=' + tempTestDirPath, { signal: controller.signal});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// stop the PocketBase process
|
||||
controller.abort();
|
||||
|
||||
// clean up the temp test directory
|
||||
fse.removeSync(tempTestDirPath);
|
||||
});
|
||||
Reference in New Issue
Block a user