diff --git a/pagination/keysetpagination/header.go b/pagination/keysetpagination/header.go index 30480a6c..48214148 100644 --- a/pagination/keysetpagination/header.go +++ b/pagination/keysetpagination/header.go @@ -83,18 +83,26 @@ func header(u *url.URL, rel, token string, size int) string { // It contains links to the first and next page, if one exists. func Header(w http.ResponseWriter, u *url.URL, p *Paginator) { size := p.Size() - w.Header().Set("Link", header(u, "first", p.defaultToken, size)) + w.Header().Set("Link", header(u, "first", p.defaultToken.Encode(), size)) if !p.IsLast() { - w.Header().Add("Link", header(u, "next", p.Token(), size)) + w.Header().Add("Link", header(u, "next", p.Token().Encode(), size)) } } // Parse returns the pagination options from the URL query. -func Parse(q url.Values) ([]Option, error) { +func Parse(q url.Values, p PageTokenConstructor) ([]Option, error) { var opts []Option if q.Has("page_token") { - opts = append(opts, WithToken(q.Get("page_token"))) + pageToken, err := url.QueryUnescape(q.Get("page_token")) + if err != nil { + return nil, errors.WithStack(err) + } + parsed, err := p(pageToken) + if err != nil { + return nil, errors.WithStack(err) + } + opts = append(opts, WithToken(parsed)) } if q.Has("page_size") { size, err := strconv.Atoi(q.Get("page_size")) diff --git a/pagination/keysetpagination/header_test.go b/pagination/keysetpagination/header_test.go new file mode 100644 index 00000000..df4503fb --- /dev/null +++ b/pagination/keysetpagination/header_test.go @@ -0,0 +1,44 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package keysetpagination + +import ( + "net/http/httptest" + "net/url" + "testing" + + "github.com/instana/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeader(t *testing.T) { + p := &Paginator{ + defaultToken: StringPageToken("default"), + token: StringPageToken("next"), + size: 2, + } + + u, err := url.Parse("http://ory.sh/") + require.NoError(t, err) + + r := httptest.NewRecorder() + + Header(r, u, p) + + links := r.HeaderMap["Link"] + require.Len(t, links, 2) + assert.Contains(t, links[0], "page_token=default") + assert.Contains(t, links[1], "page_token=next") + + t.Run("with isLast", func(t *testing.T) { + p.isLast = true + + Header(r, u, p) + + links := r.HeaderMap["Link"] + require.Len(t, links, 1) + assert.Contains(t, links[0], "page_token=default") + }) + +} diff --git a/pagination/keysetpagination/page_token.go b/pagination/keysetpagination/page_token.go new file mode 100644 index 00000000..1beb3755 --- /dev/null +++ b/pagination/keysetpagination/page_token.go @@ -0,0 +1,74 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package keysetpagination + +import ( + "encoding/base64" + "fmt" + "strings" +) + +type PageToken interface { + Parse(string) map[string]string + Encode() string +} + +var _ PageToken = new(StringPageToken) +var _ PageToken = new(MapPageToken) + +type StringPageToken string + +func (s StringPageToken) Parse(idField string) map[string]string { + return map[string]string{idField: string(s)} +} + +func (s StringPageToken) Encode() string { + return string(s) +} + +func NewStringPageToken(s string) (PageToken, error) { + return StringPageToken(s), nil +} + +type MapPageToken map[string]string + +func (m MapPageToken) Parse(_ string) map[string]string { + return map[string]string(m) +} + +const pageTokenColumnDelim = "/" + +func (m MapPageToken) Encode() string { + elems := make([]string, 0, len(m)) + for k, v := range m { + elems = append(elems, fmt.Sprintf("%s=%s", k, v)) + } + + // For now: use Base64 instead of URL escaping, as the Timestamp format we need to use can contain a `+` sign, + // which represents a space in URLs, so it's not properly encoded by the Go library. + return base64.RawStdEncoding.EncodeToString([]byte(strings.Join(elems, pageTokenColumnDelim))) +} + +func NewMapPageToken(s string) (PageToken, error) { + b, err := base64.RawStdEncoding.DecodeString(s) + if err != nil { + return nil, err + } + tokens := strings.Split(string(b), pageTokenColumnDelim) + + r := map[string]string{} + + for _, p := range tokens { + if columnName, value, found := strings.Cut(p, "="); found { + r[columnName] = value + } + } + + return MapPageToken(r), nil +} + +var _ PageTokenConstructor = NewMapPageToken +var _ PageTokenConstructor = NewStringPageToken + +type PageTokenConstructor = func(string) (PageToken, error) diff --git a/pagination/keysetpagination/paginator.go b/pagination/keysetpagination/paginator.go index d8483871..20675ba7 100644 --- a/pagination/keysetpagination/paginator.go +++ b/pagination/keysetpagination/paginator.go @@ -4,23 +4,51 @@ package keysetpagination import ( + "errors" "fmt" "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/pop/v6/columns" ) type ( - Item interface{ PageToken() string } + Item interface{ PageToken() PageToken } + + Order string + + columnOrdering struct { + name string + order Order + } Paginator struct { - token, defaultToken string + token, defaultToken PageToken size, defaultSize, maxSize int isLast bool + additionalColumn columnOrdering } Option func(*Paginator) *Paginator ) -func (p *Paginator) Token() string { - if p.token == "" { +var ErrUnknownOrder = errors.New("unknown order") + +const ( + OrderDescending Order = "DESC" + OrderAscending Order = "ASC" +) + +func (o Order) extract() (string, string, error) { + switch o { + case OrderAscending: + return ">", string(o), nil + case OrderDescending: + return "<", string(o), nil + default: + return "", "", ErrUnknownOrder + } +} + +func (p *Paginator) Token() PageToken { + if p.token == nil { return p.defaultToken } return p.token @@ -51,22 +79,58 @@ func (p *Paginator) ToOptions() []Option { WithDefaultToken(p.defaultToken), WithDefaultSize(p.defaultSize), WithMaxSize(p.maxSize), + WithColumn(p.additionalColumn.name, p.additionalColumn.order), withIsLast(p.isLast), } } +func (p *Paginator) multipleOrderFieldsQuery(q *pop.Query, idField string, cols map[string]*columns.Column, quote func(string) string) { + tokenParts := p.Token().Parse(idField) + idValue := tokenParts[idField] + + column, ok := cols[p.additionalColumn.name] + if !ok { + q.Where(fmt.Sprintf(`%s > ?`, quote(idField)), idValue) + return + } + + quoteName := quote(column.Name) + + value, ok := tokenParts[column.Name] + + if !ok { + q.Where(fmt.Sprintf(`%s > ?`, quote(idField)), idValue) + return + } + + sign, keyword, err := p.additionalColumn.order.extract() + if err != nil { + q.Where(fmt.Sprintf(`%s > ?`, quote(idField)), idValue) + return + } + + q. + Where(fmt.Sprintf("%s %s ? OR (%s = ? AND %s > ?)", quoteName, sign, quoteName, quote(idField)), value, value, idValue). + Order(fmt.Sprintf("%s %s", quoteName, keyword)) + +} + // Paginate returns a function that paginates a pop.Query. // Usage: // // q := c.Where("foo = ?", foo).Scope(keysetpagination.Paginate[Item](paginator)) func Paginate[I Item](p *Paginator) pop.ScopeFunc { var item I - id := (&pop.Model{Value: item}).IDField() + model := &pop.Model{Value: item} + id := model.IDField() return func(q *pop.Query) *pop.Query { eid := q.Connection.Dialect.Quote(id) + + p.multipleOrderFieldsQuery(q, id, model.Columns().Cols, q.Connection.Dialect.Quote) + return q. - Limit(p.Size()+1). - Where(fmt.Sprintf(`%s > ?`, eid), p.Token()). + Limit(p.Size() + 1). + // we always need to order by the id field last Order(fmt.Sprintf(`%s ASC`, eid)) } } @@ -92,7 +156,7 @@ func Result[I Item](items []I, p *Paginator) ([]I, *Paginator) { } } -func WithDefaultToken(t string) Option { +func WithDefaultToken(t PageToken) Option { return func(opts *Paginator) *Paginator { opts.defaultToken = t return opts @@ -113,7 +177,7 @@ func WithMaxSize(size int) Option { } } -func WithToken(t string) Option { +func WithToken(t PageToken) Option { return func(opts *Paginator) *Paginator { opts.token = t return opts @@ -127,6 +191,16 @@ func WithSize(size int) Option { } } +func WithColumn(name string, order Order) Option { + return func(opts *Paginator) *Paginator { + opts.additionalColumn = columnOrdering{ + name: name, + order: order, + } + return opts + } +} + func withIsLast(isLast bool) Option { return func(opts *Paginator) *Paginator { opts.isLast = isLast diff --git a/pagination/keysetpagination/paginator_test.go b/pagination/keysetpagination/paginator_test.go index 3b63af43..e57f8740 100644 --- a/pagination/keysetpagination/paginator_test.go +++ b/pagination/keysetpagination/paginator_test.go @@ -14,11 +14,12 @@ import ( ) type testItem struct { - ID string `db:"pk"` + ID string `db:"pk"` + CreatedAt string `db:"created_at"` } -func (t testItem) PageToken() string { - return t.ID +func (t testItem) PageToken() PageToken { + return StringPageToken(t.ID) } func TestPaginator(t *testing.T) { @@ -28,11 +29,11 @@ func TestPaginator(t *testing.T) { }) require.NoError(t, err) q := pop.Q(c) - paginator := GetPaginator(WithSize(10), WithToken("token")) + paginator := GetPaginator(WithSize(10), WithToken(StringPageToken("token"))) q = q.Scope(Paginate[testItem](paginator)) sql, args := q.ToSQL(&pop.Model{Value: new(testItem)}) - assert.Equal(t, "SELECT test_items.pk FROM test_items AS test_items WHERE \"pk\" > $1 ORDER BY \"pk\" ASC LIMIT 11", sql) + assert.Equal(t, "SELECT test_items.created_at, test_items.pk FROM test_items AS test_items WHERE \"pk\" > $1 ORDER BY \"pk\" ASC LIMIT 11", sql) assert.Equal(t, []interface{}{"token"}, args) }) @@ -42,11 +43,11 @@ func TestPaginator(t *testing.T) { }) require.NoError(t, err) q := pop.Q(c) - paginator := GetPaginator(WithSize(10), WithToken("token")) + paginator := GetPaginator(WithSize(10), WithToken(StringPageToken("token"))) q = q.Scope(Paginate[testItem](paginator)) sql, args := q.ToSQL(&pop.Model{Value: new(testItem)}) - assert.Equal(t, "SELECT test_items.pk FROM test_items AS test_items WHERE `pk` > ? ORDER BY `pk` ASC LIMIT 11", sql) + assert.Equal(t, "SELECT test_items.created_at, test_items.pk FROM test_items AS test_items WHERE `pk` > ? ORDER BY `pk` ASC LIMIT 11", sql) assert.Equal(t, []interface{}{"token"}, args) }) @@ -64,10 +65,10 @@ func TestPaginator(t *testing.T) { {ID: "10"}, {ID: "11"}, } - paginator := GetPaginator(WithDefaultSize(10), WithToken("token")) + paginator := GetPaginator(WithDefaultSize(10), WithToken(StringPageToken("token"))) items, nextPage := Result(items, paginator) assert.Len(t, items, 10) - assert.Equal(t, "10", nextPage.Token()) + assert.Equal(t, StringPageToken("10"), nextPage.Token()) assert.Equal(t, 10, nextPage.Size()) }) @@ -76,31 +77,30 @@ func TestPaginator(t *testing.T) { name string opts []Option expectedSize int - expectedToken string + expectedToken PageToken }{ { - name: "default", - opts: nil, - expectedSize: 100, - expectedToken: "", + name: "default", + opts: nil, + expectedSize: 100, }, { name: "with size and token", - opts: []Option{WithSize(10), WithToken("token")}, + opts: []Option{WithSize(10), WithToken(StringPageToken("token"))}, expectedSize: 10, - expectedToken: "token", + expectedToken: StringPageToken("token"), }, { name: "with custom defaults", - opts: []Option{WithDefaultSize(10), WithDefaultToken("token")}, + opts: []Option{WithDefaultSize(10), WithDefaultToken(StringPageToken("token"))}, expectedSize: 10, - expectedToken: "token", + expectedToken: StringPageToken("token"), }, { name: "with custom defaults and size and token", - opts: []Option{WithDefaultSize(10), WithDefaultToken("token"), WithSize(20), WithToken("token2")}, + opts: []Option{WithDefaultSize(10), WithDefaultToken(StringPageToken("token")), WithSize(20), WithToken(StringPageToken("token2"))}, expectedSize: 20, - expectedToken: "token2", + expectedToken: StringPageToken("token2"), }, { name: "with size and custom default and max size", @@ -122,28 +122,39 @@ func TestParse(t *testing.T) { name string q url.Values expectedSize int - expectedToken string + expectedToken PageToken + f PageTokenConstructor }{ { name: "with page token", q: url.Values{"page_token": {"token3"}}, expectedSize: 100, - expectedToken: "token3", + expectedToken: StringPageToken("token3"), + f: NewStringPageToken, }, { name: "with page size", q: url.Values{"page_size": {"123"}}, expectedSize: 123, + f: NewStringPageToken, }, { name: "with page size and page token", q: url.Values{"page_size": {"123"}, "page_token": {"token5"}}, expectedSize: 123, - expectedToken: "token5", + expectedToken: StringPageToken("token5"), + f: NewStringPageToken, + }, + { + name: "with page size and page token", + q: url.Values{"page_size": {"123"}, "page_token": {"cGs9dG9rZW41"}}, + expectedSize: 123, + expectedToken: MapPageToken{"pk": "token5"}, + f: NewMapPageToken, }, } { t.Run(tc.name, func(t *testing.T) { - opts, err := Parse(tc.q) + opts, err := Parse(tc.q, tc.f) require.NoError(t, err) paginator := GetPaginator(opts...) assert.Equal(t, tc.expectedSize, paginator.Size()) @@ -152,7 +163,63 @@ func TestParse(t *testing.T) { } t.Run("invalid page size leads to err", func(t *testing.T) { - _, err := Parse(url.Values{"page_size": {"invalid-int"}}) + _, err := Parse(url.Values{"page_size": {"invalid-int"}}, NewStringPageToken) require.ErrorIs(t, err, strconv.ErrSyntax) }) } + +func TestPaginateWithAdditionalColumn(t *testing.T) { + c, err := pop.NewConnection(&pop.ConnectionDetails{ + URL: "postgres://foo.bar", + }) + require.NoError(t, err) + + for _, tc := range []struct { + d string + opts []Option + e string + args []interface{} + }{ + { + d: "with sort by created_at DESC", + opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("created_at", "DESC")}, + e: `WHERE "created_at" < $1 OR ("created_at" = $2 AND "pk" > $3) ORDER BY "created_at" DESC, "pk" ASC`, + args: []interface{}{"timestamp", "timestamp", "token_value"}, + }, + { + d: "with sort by created_at ASC", + opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("created_at", "ASC")}, + e: `WHERE "created_at" > $1 OR ("created_at" = $2 AND "pk" > $3) ORDER BY "created_at" ASC, "pk" ASC`, + args: []interface{}{"timestamp", "timestamp", "token_value"}, + }, + { + d: "with unknown column", + opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("unknown_column", "ASC")}, + e: `WHERE "pk" > $1 ORDER BY "pk"`, + args: []interface{}{"token_value"}, + }, + { + d: "with no token value", + opts: []Option{WithToken(MapPageToken{"pk": "token_value"}), WithColumn("created_at", "ASC")}, + e: `WHERE "pk" > $1 ORDER BY "pk"`, + args: []interface{}{"token_value"}, + }, + { + d: "with unknown order", + opts: []Option{WithToken(MapPageToken{"pk": "token_value", "created_at": "timestamp"}), WithColumn("created_at", Order("unknown order"))}, + e: `WHERE "pk" > $1 ORDER BY "pk"`, + args: []interface{}{"token_value"}, + }, + } { + t.Run("case="+tc.d, func(t *testing.T) { + opts := append(tc.opts, WithSize(10)) + paginator := GetPaginator(opts...) + sql, args := pop.Q(c). + Scope(Paginate[testItem](paginator)). + ToSQL(&pop.Model{Value: new(testItem)}) + assert.Contains(t, sql, tc.e) + assert.Contains(t, sql, "LIMIT 11") + assert.Equal(t, tc.args, args) + }) + } +}