diff --git a/adapter/cockroachdb/database.go b/adapter/cockroachdb/database.go index ce182187..b2414f16 100644 --- a/adapter/cockroachdb/database.go +++ b/adapter/cockroachdb/database.go @@ -129,7 +129,7 @@ func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement, } query, args := sqlbuilder.Preprocess(compiled, args) - query = sqladapter.ReplaceWithDollarSign(query) + query = string(sqladapter.ReplaceWithDollarSign([]byte(query))) return query, args, nil } diff --git a/adapter/postgresql/database.go b/adapter/postgresql/database.go index 82e21772..cea7da2e 100644 --- a/adapter/postgresql/database.go +++ b/adapter/postgresql/database.go @@ -99,7 +99,7 @@ func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement, } query, args := sqlbuilder.Preprocess(compiled, args) - query = sqladapter.ReplaceWithDollarSign(query) + query = string(sqladapter.ReplaceWithDollarSign([]byte(query))) return query, args, nil } diff --git a/adapter/ql/database.go b/adapter/ql/database.go index afd5ab67..1ee6728b 100644 --- a/adapter/ql/database.go +++ b/adapter/ql/database.go @@ -81,7 +81,7 @@ func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement, } query, args := sqlbuilder.Preprocess(compiled, args) - query = sqladapter.ReplaceWithDollarSign(query) + query = string(sqladapter.ReplaceWithDollarSign([]byte(query))) return query, args, nil } diff --git a/go.mod b/go.mod index 28f3624a..09a6b7fa 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,8 @@ require ( github.com/jackc/pgx/v4 v4.15.0 github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.9 + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect + github.com/segmentio/fasthash v1.0.3 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70 // indirect diff --git a/go.sum b/go.sum index f06fdf0b..45519f00 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -112,6 +114,8 @@ github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= +github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= diff --git a/internal/cache/cache.go b/internal/cache/cache.go index fbf0353a..80dadac9 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -24,25 +24,21 @@ package cache import ( "container/list" "errors" - "fmt" - "strconv" "sync" - - "github.com/upper/db/v4/internal/cache/hashstructure" ) const defaultCapacity = 128 // Cache holds a map of volatile key -> values. type Cache struct { - cache map[string]*list.Element - li *list.List - capacity int + keys *list.List + items map[uint64]*list.Element mu sync.RWMutex + capacity int } -type item struct { - key string +type cacheItem struct { + key uint64 value interface{} } @@ -52,11 +48,11 @@ func NewCacheWithCapacity(capacity int) (*Cache, error) { if capacity < 1 { return nil, errors.New("Capacity must be greater than zero.") } - return &Cache{ - cache: make(map[string]*list.Element), - li: list.New(), + c := &Cache{ capacity: capacity, - }, nil + } + c.init() + return c, nil } // NewCache initializes a new caching space with default settings. @@ -68,6 +64,11 @@ func NewCache() *Cache { return c } +func (c *Cache) init() { + c.items = make(map[uint64]*list.Element) + c.keys = list.New() +} + // Read attempts to retrieve a cached value as a string, if the value does not // exists returns an empty string and false. func (c *Cache) Read(h Hashable) (string, bool) { @@ -84,33 +85,35 @@ func (c *Cache) Read(h Hashable) (string, bool) { func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) { c.mu.RLock() defer c.mu.RUnlock() - data, ok := c.cache[h.Hash()] + + item, ok := c.items[h.Hash()] if ok { - return data.Value.(*item).value, true + return item.Value.(*cacheItem).value, true } + return nil, false } // Write stores a value in memory. If the value already exists its overwritten. func (c *Cache) Write(h Hashable, value interface{}) { - key := h.Hash() - c.mu.Lock() defer c.mu.Unlock() - if el, ok := c.cache[key]; ok { - el.Value.(*item).value = value - c.li.MoveToFront(el) + key := h.Hash() + + if item, ok := c.items[key]; ok { + item.Value.(*cacheItem).value = value + c.keys.MoveToFront(item) return } - c.cache[key] = c.li.PushFront(&item{key, value}) + c.items[key] = c.keys.PushFront(&cacheItem{key, value}) - for c.li.Len() > c.capacity { - el := c.li.Remove(c.li.Back()) - delete(c.cache, el.(*item).key) - if p, ok := el.(*item).value.(HasOnPurge); ok { - p.OnPurge() + for c.keys.Len() > c.capacity { + item := c.keys.Remove(c.keys.Back()).(*cacheItem) + delete(c.items, item.key) + if p, ok := item.value.(HasOnEvict); ok { + p.OnEvict() } } } @@ -120,33 +123,12 @@ func (c *Cache) Write(h Hashable, value interface{}) { func (c *Cache) Clear() { c.mu.Lock() defer c.mu.Unlock() - for _, el := range c.cache { - if p, ok := el.Value.(*item).value.(HasOnPurge); ok { - p.OnPurge() - } - } - c.cache = make(map[string]*list.Element) - c.li.Init() -} -// Hash returns a hash of the given struct. -func Hash(v interface{}) string { - q, err := hashstructure.Hash(v, nil) - if err != nil { - panic(fmt.Sprintf("Could not hash struct: %v", err.Error())) + for _, item := range c.items { + if p, ok := item.Value.(*cacheItem).value.(HasOnEvict); ok { + p.OnEvict() + } } - return strconv.FormatUint(q, 10) -} - -type hash struct { - name string -} - -func (h *hash) Hash() string { - return h.name -} -// String returns a Hashable that produces a hash equal to the given string. -func String(s string) Hashable { - return &hash{s} + c.init() } diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index d0d16410..76634cce 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -23,6 +23,7 @@ package cache import ( "fmt" + "hash/fnv" "testing" ) @@ -32,8 +33,10 @@ type cacheableT struct { Name string } -func (ct *cacheableT) Hash() string { - return Hash(ct) +func (ct *cacheableT) Hash() uint64 { + s := fnv.New64() + s.Sum([]byte(ct.Name)) + return s.Sum64() } var ( @@ -77,6 +80,13 @@ func BenchmarkNewCache(b *testing.B) { } } +func BenchmarkNewCacheAndClear(b *testing.B) { + for i := 0; i < b.N; i++ { + c := NewCache() + c.Clear() + } +} + func BenchmarkReadNonExistentValue(b *testing.B) { z := NewCache() for i := 0; i < b.N; i++ { diff --git a/internal/cache/hash.go b/internal/cache/hash.go new file mode 100644 index 00000000..4b866a9d --- /dev/null +++ b/internal/cache/hash.go @@ -0,0 +1,109 @@ +package cache + +import ( + "fmt" + + "github.com/segmentio/fasthash/fnv1a" +) + +const ( + hashTypeInt uint64 = 1 << iota + hashTypeSignedInt + hashTypeBool + hashTypeString + hashTypeHashable + hashTypeNil +) + +type hasher struct { + t uint64 + v interface{} +} + +func (h *hasher) Hash() uint64 { + return NewHash(h.t, h.v) +} + +func NewHashable(t uint64, v interface{}) Hashable { + return &hasher{t: t, v: v} +} + +func InitHash(t uint64) uint64 { + return fnv1a.AddUint64(fnv1a.Init64, t) +} + +func NewHash(t uint64, in ...interface{}) uint64 { + return AddToHash(InitHash(t), in...) +} + +func AddToHash(h uint64, in ...interface{}) uint64 { + for i := range in { + if in[i] == nil { + continue + } + h = addToHash(h, in[i]) + } + return h +} + +func addToHash(h uint64, in interface{}) uint64 { + switch v := in.(type) { + case uint64: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), v) + case uint32: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint16: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint8: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case int64: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int32: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int16: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int8: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case bool: + if v { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 1) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 2) + } + case string: + return fnv1a.AddString64(fnv1a.AddUint64(h, hashTypeString), v) + case Hashable: + if in == nil { + panic(fmt.Sprintf("could not hash nil element %T", in)) + } + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeHashable), v.Hash()) + case nil: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeNil), 0) + default: + panic(fmt.Sprintf("unsupported value type %T", in)) + } +} diff --git a/internal/cache/hashstructure/LICENSE b/internal/cache/hashstructure/LICENSE deleted file mode 100644 index a3866a29..00000000 --- a/internal/cache/hashstructure/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2016 Mitchell Hashimoto - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/internal/cache/hashstructure/README.md b/internal/cache/hashstructure/README.md deleted file mode 100644 index 7d0de5bf..00000000 --- a/internal/cache/hashstructure/README.md +++ /dev/null @@ -1,61 +0,0 @@ -# hashstructure - -hashstructure is a Go library for creating a unique hash value -for arbitrary values in Go. - -This can be used to key values in a hash (for use in a map, set, etc.) -that are complex. The most common use case is comparing two values without -sending data across the network, caching values locally (de-dup), and so on. - -## Features - - * Hash any arbitrary Go value, including complex types. - - * Tag a struct field to ignore it and not affect the hash value. - - * Tag a slice type struct field to treat it as a set where ordering - doesn't affect the hash code but the field itself is still taken into - account to create the hash value. - - * Optionally specify a custom hash function to optimize for speed, collision - avoidance for your data set, etc. - -## Installation - -Standard `go get`: - -``` -$ go get github.com/mitchellh/hashstructure -``` - -## Usage & Example - -For usage and examples see the [Godoc](http://godoc.org/github.com/mitchellh/hashstructure). - -A quick code example is shown below: - - - type ComplexStruct struct { - Name string - Age uint - Metadata map[string]interface{} - } - - v := ComplexStruct{ - Name: "mitchellh", - Age: 64, - Metadata: map[string]interface{}{ - "car": true, - "location": "California", - "siblings": []string{"Bob", "John"}, - }, - } - - hash, err := hashstructure.Hash(v, nil) - if err != nil { - panic(err) - } - - fmt.Printf("%d", hash) - // Output: - // 2307517237273902113 diff --git a/internal/cache/hashstructure/hashstructure.go b/internal/cache/hashstructure/hashstructure.go deleted file mode 100644 index 9a2e9535..00000000 --- a/internal/cache/hashstructure/hashstructure.go +++ /dev/null @@ -1,325 +0,0 @@ -package hashstructure - -import ( - "encoding/binary" - "fmt" - "hash" - "hash/fnv" - "reflect" -) - -// HashOptions are options that are available for hashing. -type HashOptions struct { - // Hasher is the hash function to use. If this isn't set, it will - // default to FNV. - Hasher hash.Hash64 - - // TagName is the struct tag to look at when hashing the structure. - // By default this is "hash". - TagName string -} - -// Hash returns the hash value of an arbitrary value. -// -// If opts is nil, then default options will be used. See HashOptions -// for the default values. -// -// Notes on the value: -// -// * Unexported fields on structs are ignored and do not affect the -// hash value. -// -// * Adding an exported field to a struct with the zero value will change -// the hash value. -// -// For structs, the hashing can be controlled using tags. For example: -// -// struct { -// Name string -// UUID string `hash:"ignore"` -// } -// -// The available tag values are: -// -// * "ignore" - The field will be ignored and not affect the hash code. -// -// * "set" - The field will be treated as a set, where ordering doesn't -// affect the hash code. This only works for slices. -// -func Hash(v interface{}, opts *HashOptions) (uint64, error) { - // Create default options - if opts == nil { - opts = &HashOptions{} - } - if opts.Hasher == nil { - opts.Hasher = fnv.New64() - } - if opts.TagName == "" { - opts.TagName = "hash" - } - - // Reset the hash - opts.Hasher.Reset() - - // Create our walker and walk the structure - w := &walker{ - h: opts.Hasher, - tag: opts.TagName, - } - return w.visit(reflect.ValueOf(v), nil) -} - -type walker struct { - h hash.Hash64 - tag string -} - -type visitOpts struct { - // Flags are a bitmask of flags to affect behavior of this visit - Flags visitFlag - - // Information about the struct containing this field - Struct interface{} - StructField string -} - -func (w *walker) visit(v reflect.Value, opts *visitOpts) (uint64, error) { - // Loop since these can be wrapped in multiple layers of pointers - // and interfaces. - for { - // If we have an interface, dereference it. We have to do this up - // here because it might be a nil in there and the check below must - // catch that. - if v.Kind() == reflect.Interface { - v = v.Elem() - continue - } - - if v.Kind() == reflect.Ptr { - v = reflect.Indirect(v) - continue - } - - break - } - - // If it is nil, treat it like a zero. - if !v.IsValid() { - var tmp int8 - v = reflect.ValueOf(tmp) - } - - // Binary writing can use raw ints, we have to convert to - // a sized-int, we'll choose the largest... - switch v.Kind() { - case reflect.Int: - v = reflect.ValueOf(int64(v.Int())) - case reflect.Uint: - v = reflect.ValueOf(uint64(v.Uint())) - case reflect.Bool: - var tmp int8 - if v.Bool() { - tmp = 1 - } - v = reflect.ValueOf(tmp) - } - - k := v.Kind() - - // We can shortcut numeric values by directly binary writing them - if k >= reflect.Int && k <= reflect.Complex64 { - // A direct hash calculation - w.h.Reset() - err := binary.Write(w.h, binary.LittleEndian, v.Interface()) - return w.h.Sum64(), err - } - - switch k { - case reflect.Array: - var h uint64 - l := v.Len() - for i := 0; i < l; i++ { - current, err := w.visit(v.Index(i), nil) - if err != nil { - return 0, err - } - - h = hashUpdateOrdered(w.h, h, current) - } - - return h, nil - - case reflect.Map: - var includeMap IncludableMap - if opts != nil && opts.Struct != nil { - if v, ok := opts.Struct.(IncludableMap); ok { - includeMap = v - } - } - - // Build the hash for the map. We do this by XOR-ing all the key - // and value hashes. This makes it deterministic despite ordering. - var h uint64 - for _, k := range v.MapKeys() { - v := v.MapIndex(k) - if includeMap != nil { - incl, err := includeMap.HashIncludeMap( - opts.StructField, k.Interface(), v.Interface()) - if err != nil { - return 0, err - } - if !incl { - continue - } - } - - kh, err := w.visit(k, nil) - if err != nil { - return 0, err - } - vh, err := w.visit(v, nil) - if err != nil { - return 0, err - } - - fieldHash := hashUpdateOrdered(w.h, kh, vh) - h = hashUpdateUnordered(h, fieldHash) - } - - return h, nil - - case reflect.Struct: - var include Includable - parent := v.Interface() - if impl, ok := parent.(Includable); ok { - include = impl - } - - t := v.Type() - h, err := w.visit(reflect.ValueOf(t.Name()), nil) - if err != nil { - return 0, err - } - - l := v.NumField() - for i := 0; i < l; i++ { - if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" { - var f visitFlag - fieldType := t.Field(i) - if fieldType.PkgPath != "" { - // Unexported - continue - } - - tag := fieldType.Tag.Get(w.tag) - if tag == "ignore" { - // Ignore this field - continue - } - - // Check if we implement includable and check it - if include != nil { - incl, err := include.HashInclude(fieldType.Name, v) - if err != nil { - return 0, err - } - if !incl { - continue - } - } - - switch tag { - case "set": - f |= visitFlagSet - } - - kh, err := w.visit(reflect.ValueOf(fieldType.Name), nil) - if err != nil { - return 0, err - } - - vh, err := w.visit(v, &visitOpts{ - Flags: f, - Struct: parent, - StructField: fieldType.Name, - }) - if err != nil { - return 0, err - } - - fieldHash := hashUpdateOrdered(w.h, kh, vh) - h = hashUpdateUnordered(h, fieldHash) - } - } - - return h, nil - - case reflect.Slice: - // We have two behaviors here. If it isn't a set, then we just - // visit all the elements. If it is a set, then we do a deterministic - // hash code. - var h uint64 - var set bool - if opts != nil { - set = (opts.Flags & visitFlagSet) != 0 - } - l := v.Len() - for i := 0; i < l; i++ { - current, err := w.visit(v.Index(i), nil) - if err != nil { - return 0, err - } - - if set { - h = hashUpdateUnordered(h, current) - } else { - h = hashUpdateOrdered(w.h, h, current) - } - } - - return h, nil - - case reflect.String: - // Directly hash - w.h.Reset() - _, err := w.h.Write([]byte(v.String())) - return w.h.Sum64(), err - - default: - return 0, fmt.Errorf("unknown kind to hash: %s", k) - } -} - -func hashUpdateOrdered(h hash.Hash64, a, b uint64) uint64 { - // For ordered updates, use a real hash function - h.Reset() - - // We just panic if the binary writes fail because we are writing - // an int64 which should never be fail-able. - e1 := binary.Write(h, binary.LittleEndian, a) - e2 := binary.Write(h, binary.LittleEndian, b) - if e1 != nil { - panic(e1) - } - if e2 != nil { - panic(e2) - } - - return h.Sum64() -} - -func hashUpdateUnordered(a, b uint64) uint64 { - return a ^ b -} - -// visitFlag is used as a bitmask for affecting visit behavior -type visitFlag uint - -const ( - visitFlagInvalid visitFlag = iota - visitFlagSet = iota << 1 -) - -var ( - _ = visitFlagInvalid -) diff --git a/internal/cache/hashstructure/hashstructure_test.go b/internal/cache/hashstructure/hashstructure_test.go deleted file mode 100644 index 919f8966..00000000 --- a/internal/cache/hashstructure/hashstructure_test.go +++ /dev/null @@ -1,357 +0,0 @@ -package hashstructure - -import ( - "testing" -) - -func TestHash_identity(t *testing.T) { - cases := []interface{}{ - nil, - "foo", - 42, - true, - false, - []string{"foo", "bar"}, - []interface{}{1, nil, "foo"}, - map[string]string{"foo": "bar"}, - map[interface{}]string{"foo": "bar"}, - map[interface{}]interface{}{"foo": "bar", "bar": 0}, - struct { - Foo string - Bar []interface{} - }{ - Foo: "foo", - Bar: []interface{}{nil, nil, nil}, - }, - &struct { - Foo string - Bar []interface{} - }{ - Foo: "foo", - Bar: []interface{}{nil, nil, nil}, - }, - } - - for _, tc := range cases { - // We run the test 100 times to try to tease out variability - // in the runtime in terms of ordering. - valuelist := make([]uint64, 100) - for i := range valuelist { - v, err := Hash(tc, nil) - if err != nil { - t.Fatalf("Error: %s\n\n%#v", err, tc) - } - - valuelist[i] = v - } - - // Zero is always wrong - if valuelist[0] == 0 { - t.Fatalf("zero hash: %#v", tc) - } - - // Make sure all the values match - t.Logf("%#v: %d", tc, valuelist[0]) - for i := 1; i < len(valuelist); i++ { - if valuelist[i] != valuelist[0] { - t.Fatalf("non-matching: %d, %d\n\n%#v", i, 0, tc) - } - } - } -} - -func TestHash_equal(t *testing.T) { - type testFoo struct{ Name string } - type testBar struct{ Name string } - - cases := []struct { - One, Two interface{} - Match bool - }{ - { - map[string]string{"foo": "bar"}, - map[interface{}]string{"foo": "bar"}, - true, - }, - - { - map[string]interface{}{"1": "1"}, - map[string]interface{}{"1": "1", "2": "2"}, - false, - }, - - { - struct{ Fname, Lname string }{"foo", "bar"}, - struct{ Fname, Lname string }{"bar", "foo"}, - false, - }, - - { - struct{ Lname, Fname string }{"foo", "bar"}, - struct{ Fname, Lname string }{"foo", "bar"}, - false, - }, - - { - struct{ Lname, Fname string }{"foo", "bar"}, - struct{ Fname, Lname string }{"bar", "foo"}, - true, - }, - - { - testFoo{"foo"}, - testBar{"foo"}, - false, - }, - - { - struct { - Foo string - unexported string - }{ - Foo: "bar", - unexported: "baz", - }, - struct { - Foo string - unexported string - }{ - Foo: "bar", - unexported: "bang", - }, - true, - }, - } - - for _, tc := range cases { - t.Logf("Hashing: %#v", tc.One) - one, err := Hash(tc.One, nil) - t.Logf("Result: %d", one) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.One, err) - } - t.Logf("Hashing: %#v", tc.Two) - two, err := Hash(tc.Two, nil) - t.Logf("Result: %d", two) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.Two, err) - } - - // Zero is always wrong - if one == 0 { - t.Fatalf("zero hash: %#v", tc.One) - } - - // Compare - if (one == two) != tc.Match { - t.Fatalf("bad, expected: %#v\n\n%#v\n\n%#v", tc.Match, tc.One, tc.Two) - } - } -} - -func TestHash_equalIgnore(t *testing.T) { - type Test struct { - Name string - UUID string `hash:"ignore"` - } - - cases := []struct { - One, Two interface{} - Match bool - }{ - { - Test{Name: "foo", UUID: "foo"}, - Test{Name: "foo", UUID: "bar"}, - true, - }, - - { - Test{Name: "foo", UUID: "foo"}, - Test{Name: "foo", UUID: "foo"}, - true, - }, - } - - for _, tc := range cases { - one, err := Hash(tc.One, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.One, err) - } - two, err := Hash(tc.Two, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.Two, err) - } - - // Zero is always wrong - if one == 0 { - t.Fatalf("zero hash: %#v", tc.One) - } - - // Compare - if (one == two) != tc.Match { - t.Fatalf("bad, expected: %#v\n\n%#v\n\n%#v", tc.Match, tc.One, tc.Two) - } - } -} - -func TestHash_equalSet(t *testing.T) { - type Test struct { - Name string - Friends []string `hash:"set"` - } - - cases := []struct { - One, Two interface{} - Match bool - }{ - { - Test{Name: "foo", Friends: []string{"foo", "bar"}}, - Test{Name: "foo", Friends: []string{"bar", "foo"}}, - true, - }, - - { - Test{Name: "foo", Friends: []string{"foo", "bar"}}, - Test{Name: "foo", Friends: []string{"foo", "bar"}}, - true, - }, - } - - for _, tc := range cases { - one, err := Hash(tc.One, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.One, err) - } - two, err := Hash(tc.Two, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.Two, err) - } - - // Zero is always wrong - if one == 0 { - t.Fatalf("zero hash: %#v", tc.One) - } - - // Compare - if (one == two) != tc.Match { - t.Fatalf("bad, expected: %#v\n\n%#v\n\n%#v", tc.Match, tc.One, tc.Two) - } - } -} - -func TestHash_includable(t *testing.T) { - cases := []struct { - One, Two interface{} - Match bool - }{ - { - testIncludable{Value: "foo"}, - testIncludable{Value: "foo"}, - true, - }, - - { - testIncludable{Value: "foo", Ignore: "bar"}, - testIncludable{Value: "foo"}, - true, - }, - - { - testIncludable{Value: "foo", Ignore: "bar"}, - testIncludable{Value: "bar"}, - false, - }, - } - - for _, tc := range cases { - one, err := Hash(tc.One, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.One, err) - } - two, err := Hash(tc.Two, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.Two, err) - } - - // Zero is always wrong - if one == 0 { - t.Fatalf("zero hash: %#v", tc.One) - } - - // Compare - if (one == two) != tc.Match { - t.Fatalf("bad, expected: %#v\n\n%#v\n\n%#v", tc.Match, tc.One, tc.Two) - } - } -} - -func TestHash_includableMap(t *testing.T) { - cases := []struct { - One, Two interface{} - Match bool - }{ - { - testIncludableMap{Map: map[string]string{"foo": "bar"}}, - testIncludableMap{Map: map[string]string{"foo": "bar"}}, - true, - }, - - { - testIncludableMap{Map: map[string]string{"foo": "bar", "ignore": "true"}}, - testIncludableMap{Map: map[string]string{"foo": "bar"}}, - true, - }, - - { - testIncludableMap{Map: map[string]string{"foo": "bar", "ignore": "true"}}, - testIncludableMap{Map: map[string]string{"bar": "baz"}}, - false, - }, - } - - for _, tc := range cases { - one, err := Hash(tc.One, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.One, err) - } - two, err := Hash(tc.Two, nil) - if err != nil { - t.Fatalf("Failed to hash %#v: %s", tc.Two, err) - } - - // Zero is always wrong - if one == 0 { - t.Fatalf("zero hash: %#v", tc.One) - } - - // Compare - if (one == two) != tc.Match { - t.Fatalf("bad, expected: %#v\n\n%#v\n\n%#v", tc.Match, tc.One, tc.Two) - } - } -} - -type testIncludable struct { - Value string - Ignore string -} - -func (t testIncludable) HashInclude(field string, v interface{}) (bool, error) { - return field != "Ignore", nil -} - -type testIncludableMap struct { - Map map[string]string -} - -func (t testIncludableMap) HashIncludeMap(field string, k, v interface{}) (bool, error) { - if field != "Map" { - return true, nil - } - - if s, ok := k.(string); ok && s == "ignore" { - return false, nil - } - - return true, nil -} diff --git a/internal/cache/hashstructure/include.go b/internal/cache/hashstructure/include.go deleted file mode 100644 index b6289c0b..00000000 --- a/internal/cache/hashstructure/include.go +++ /dev/null @@ -1,15 +0,0 @@ -package hashstructure - -// Includable is an interface that can optionally be implemented by -// a struct. It will be called for each field in the struct to check whether -// it should be included in the hash. -type Includable interface { - HashInclude(field string, v interface{}) (bool, error) -} - -// IncludableMap is an interface that can optionally be implemented by -// a struct. It will be called when a map-type field is found to ask the -// struct if the map item should be included in the hash. -type IncludableMap interface { - HashIncludeMap(field string, k, v interface{}) (bool, error) -} diff --git a/internal/cache/interface.go b/internal/cache/interface.go index 489d6452..c63246af 100644 --- a/internal/cache/interface.go +++ b/internal/cache/interface.go @@ -24,11 +24,11 @@ package cache // Hashable types must implement a method that returns a key. This key will be // associated with a cached value. type Hashable interface { - Hash() string + Hash() uint64 } -// HasOnPurge type is (optionally) implemented by cache objects to clean after +// HasOnEvict type is (optionally) implemented by cache objects to clean after // themselves. -type HasOnPurge interface { - OnPurge() +type HasOnEvict interface { + OnEvict() } diff --git a/internal/sqladapter/exql/column.go b/internal/sqladapter/exql/column.go index 4140a442..5789317b 100644 --- a/internal/sqladapter/exql/column.go +++ b/internal/sqladapter/exql/column.go @@ -3,18 +3,18 @@ package exql import ( "fmt" "strings" + + "github.com/upper/db/v4/internal/cache" ) -type columnT struct { +type columnWithAlias struct { Name string Alias string } // Column represents a SQL column. type Column struct { - Name interface{} - Alias string - hash hash + Name interface{} } var _ = Fragment(&Column{}) @@ -25,8 +25,11 @@ func ColumnWithName(name string) *Column { } // Hash returns a unique identifier for the struct. -func (c *Column) Hash() string { - return c.hash.Hash(c) +func (c *Column) Hash() uint64 { + if c == nil { + return cache.NewHash(FragmentType_Column, nil) + } + return cache.NewHash(FragmentType_Column, c.Name) } // Compile transforms the ColumnValue into an equivalent SQL representation. @@ -35,20 +38,17 @@ func (c *Column) Compile(layout *Template) (compiled string, err error) { return z, nil } - alias := c.Alias - + var alias string switch value := c.Name.(type) { case string: - input := trimString(value) - - chunks := separateByAS(input) + value = trimString(value) + chunks := separateByAS(value) if len(chunks) == 1 { - chunks = separateBySpace(input) + chunks = separateBySpace(value) } name := chunks[0] - nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2) for i := range nameChunks { @@ -65,17 +65,19 @@ func (c *Column) Compile(layout *Template) (compiled string, err error) { alias = trimString(chunks[1]) alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias}) } - case Raw: - compiled = value.String() + case compilable: + compiled, err = value.Compile(layout) + if err != nil { + return "", err + } default: - compiled = fmt.Sprintf("%v", c.Name) + return "", fmt.Errorf(errExpectingHashableFmt, c.Name) } if alias != "" { - compiled = layout.MustCompile(layout.ColumnAliasLayout, columnT{compiled, alias}) + compiled = layout.MustCompile(layout.ColumnAliasLayout, columnWithAlias{compiled, alias}) } layout.Write(c, compiled) - return } diff --git a/internal/sqladapter/exql/column_test.go b/internal/sqladapter/exql/column_test.go index f999960e..7706538b 100644 --- a/internal/sqladapter/exql/column_test.go +++ b/internal/sqladapter/exql/column_test.go @@ -2,76 +2,36 @@ package exql import ( "testing" -) - -func TestColumnHash(t *testing.T) { - var s, e string - - column := Column{Name: "role.name"} - s = column.Hash() - e = "*exql.Column:5663680925324531495" - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} + "github.com/stretchr/testify/assert" +) func TestColumnString(t *testing.T) { - column := Column{Name: "role.name"} - s, err := column.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `"role"."name"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `"role"."name"`, s) } func TestColumnAs(t *testing.T) { column := Column{Name: "role.name as foo"} - s, err := column.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `"role"."name" AS "foo"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `"role"."name" AS "foo"`, s) } func TestColumnImplicitAs(t *testing.T) { column := Column{Name: "role.name foo"} - s, err := column.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `"role"."name" AS "foo"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `"role"."name" AS "foo"`, s) } func TestColumnRaw(t *testing.T) { - column := Column{Name: Raw{Value: "role.name As foo"}} - + column := Column{Name: &Raw{Value: "role.name As foo"}} s, err := column.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `role.name As foo` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `role.name As foo`, s) } func BenchmarkColumnWithName(b *testing.B) { @@ -82,6 +42,7 @@ func BenchmarkColumnWithName(b *testing.B) { func BenchmarkColumnHash(b *testing.B) { c := Column{Name: "name"} + b.ResetTimer() for i := 0; i < b.N; i++ { c.Hash() } @@ -89,6 +50,7 @@ func BenchmarkColumnHash(b *testing.B) { func BenchmarkColumnCompile(b *testing.B) { c := Column{Name: "name"} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = c.Compile(defaultTemplate) } @@ -103,6 +65,7 @@ func BenchmarkColumnCompileNoCache(b *testing.B) { func BenchmarkColumnWithDotCompile(b *testing.B) { c := Column{Name: "role.name"} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = c.Compile(defaultTemplate) } @@ -110,6 +73,7 @@ func BenchmarkColumnWithDotCompile(b *testing.B) { func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) { c := Column{Name: "role.name foo"} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = c.Compile(defaultTemplate) } @@ -117,6 +81,7 @@ func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) { func BenchmarkColumnWithAsKeywordCompile(b *testing.B) { c := Column{Name: "role.name AS foo"} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = c.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/column_value.go b/internal/sqladapter/exql/column_value.go index 018faa45..49296114 100644 --- a/internal/sqladapter/exql/column_value.go +++ b/internal/sqladapter/exql/column_value.go @@ -1,6 +1,7 @@ package exql import ( + "github.com/upper/db/v4/internal/cache" "strings" ) @@ -9,7 +10,6 @@ type ColumnValue struct { Column Fragment Operator string Value Fragment - hash hash } var _ = Fragment(&ColumnValue{}) @@ -21,8 +21,11 @@ type columnValueT struct { } // Hash returns a unique identifier for the struct. -func (c *ColumnValue) Hash() string { - return c.hash.Hash(c) +func (c *ColumnValue) Hash() uint64 { + if c == nil { + return cache.NewHash(FragmentType_ColumnValue, nil) + } + return cache.NewHash(FragmentType_ColumnValue, c.Column, c.Operator, c.Value) } // Compile transforms the ColumnValue into an equivalent SQL representation. @@ -58,7 +61,6 @@ func (c *ColumnValue) Compile(layout *Template) (compiled string, err error) { // ColumnValues represents an array of ColumnValue type ColumnValues struct { ColumnValues []Fragment - hash hash } var _ = Fragment(&ColumnValues{}) @@ -71,13 +73,16 @@ func JoinColumnValues(values ...Fragment) *ColumnValues { // Insert adds a column to the columns array. func (c *ColumnValues) Insert(values ...Fragment) *ColumnValues { c.ColumnValues = append(c.ColumnValues, values...) - c.hash.Reset() return c } // Hash returns a unique identifier for the struct. -func (c *ColumnValues) Hash() string { - return c.hash.Hash(c) +func (c *ColumnValues) Hash() uint64 { + h := cache.InitHash(FragmentType_ColumnValues) + for i := range c.ColumnValues { + h = cache.AddToHash(h, c.ColumnValues[i]) + } + return h } // Compile transforms the ColumnValues into its SQL representation. diff --git a/internal/sqladapter/exql/column_value_test.go b/internal/sqladapter/exql/column_value_test.go index 9569fcac..33ecc36d 100644 --- a/internal/sqladapter/exql/column_value_test.go +++ b/internal/sqladapter/exql/column_value_test.go @@ -2,61 +2,20 @@ package exql import ( "testing" -) - -func TestColumnValueHash(t *testing.T) { - var s, e string - - c := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} - s = c.Hash() - e = `*exql.ColumnValue:4950005282640920683` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestColumnValuesHash(t *testing.T) { - var s, e string - - c := JoinColumnValues( - &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)}, - &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(2)}, - ) - - s = c.Hash() - e = `*exql.ColumnValues:8728513848368010747` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} + "github.com/stretchr/testify/assert" +) func TestColumnValue(t *testing.T) { cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} - s, err := cv.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `"id" = '1'` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } - - cv = &ColumnValue{Column: ColumnWithName("date"), Operator: "=", Value: NewValue(RawValue("NOW()"))} + assert.NoError(t, err) + assert.Equal(t, `"id" = '1'`, s) + cv = &ColumnValue{Column: ColumnWithName("date"), Operator: "=", Value: &Raw{Value: "NOW()"}} s, err = cv.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e = `"date" = NOW()` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `"date" = NOW()`, s) } func TestColumnValues(t *testing.T) { @@ -69,14 +28,8 @@ func TestColumnValues(t *testing.T) { ) s, err := cvs.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `"id" > '8', "other"."id" < 100, "name" = 'Haruki Murakami', "created" >= NOW(), "modified" <= NOW()` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `"id" > '8', "other"."id" < 100, "name" = 'Haruki Murakami', "created" >= NOW(), "modified" <= NOW()`, s) } func BenchmarkNewColumnValue(b *testing.B) { @@ -87,6 +40,7 @@ func BenchmarkNewColumnValue(b *testing.B) { func BenchmarkColumnValueHash(b *testing.B) { cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + b.ResetTimer() for i := 0; i < b.N; i++ { cv.Hash() } @@ -94,6 +48,7 @@ func BenchmarkColumnValueHash(b *testing.B) { func BenchmarkColumnValueCompile(b *testing.B) { cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = cv.Compile(defaultTemplate) } @@ -121,11 +76,12 @@ func BenchmarkJoinColumnValues(b *testing.B) { func BenchmarkColumnValuesHash(b *testing.B) { cvs := JoinColumnValues( &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, - &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, - &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, - &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { cvs.Hash() } @@ -134,11 +90,12 @@ func BenchmarkColumnValuesHash(b *testing.B) { func BenchmarkColumnValuesCompile(b *testing.B) { cvs := JoinColumnValues( &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, - &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, - &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, - &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = cvs.Compile(defaultTemplate) } @@ -148,10 +105,10 @@ func BenchmarkColumnValuesCompileNoCache(b *testing.B) { for i := 0; i < b.N; i++ { cvs := JoinColumnValues( &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, - &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, - &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, - &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, ) _, _ = cvs.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/columns.go b/internal/sqladapter/exql/columns.go index d85e8e4a..c59f73bf 100644 --- a/internal/sqladapter/exql/columns.go +++ b/internal/sqladapter/exql/columns.go @@ -2,19 +2,27 @@ package exql import ( "strings" + + "github.com/upper/db/v4/internal/cache" ) // Columns represents an array of Column. type Columns struct { Columns []Fragment - hash hash } var _ = Fragment(&Columns{}) // Hash returns a unique identifier. -func (c *Columns) Hash() string { - return c.hash.Hash(c) +func (c *Columns) Hash() uint64 { + if c == nil { + return cache.NewHash(FragmentType_Columns, nil) + } + h := cache.InitHash(FragmentType_Columns) + for i := range c.Columns { + h = cache.AddToHash(h, c.Columns[i]) + } + return h } // JoinColumns creates and returns an array of Column. @@ -48,7 +56,6 @@ func (c *Columns) IsEmpty() bool { // Compile transforms the Columns into an equivalent SQL representation. func (c *Columns) Compile(layout *Template) (compiled string, err error) { - if z, ok := layout.Read(c); ok { return z, nil } diff --git a/internal/sqladapter/exql/columns_test.go b/internal/sqladapter/exql/columns_test.go index 4c56f852..39cbd5ce 100644 --- a/internal/sqladapter/exql/columns_test.go +++ b/internal/sqladapter/exql/columns_test.go @@ -2,6 +2,8 @@ package exql import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestColumns(t *testing.T) { @@ -14,14 +16,8 @@ func TestColumns(t *testing.T) { ) s, err := columns.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `"id", "customer", "service_id", "role"."name", "role"."id"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `"id", "customer", "service_id", "role"."name", "role"."id"`, s) } func BenchmarkJoinColumns(b *testing.B) { @@ -42,6 +38,7 @@ func BenchmarkColumnsHash(b *testing.B) { &Column{Name: "role.name"}, &Column{Name: "role.id"}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { c.Hash() } @@ -55,6 +52,7 @@ func BenchmarkColumnsCompile(b *testing.B) { &Column{Name: "role.name"}, &Column{Name: "role.id"}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = c.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/database.go b/internal/sqladapter/exql/database.go index 1603607e..abdfc1d1 100644 --- a/internal/sqladapter/exql/database.go +++ b/internal/sqladapter/exql/database.go @@ -1,9 +1,12 @@ package exql +import ( + "github.com/upper/db/v4/internal/cache" +) + // Database represents a SQL database. type Database struct { Name string - hash hash } var _ = Fragment(&Database{}) @@ -14,8 +17,11 @@ func DatabaseWithName(name string) *Database { } // Hash returns a unique identifier for the struct. -func (d *Database) Hash() string { - return d.hash.Hash(d) +func (d *Database) Hash() uint64 { + if d == nil { + return cache.NewHash(FragmentType_Database, nil) + } + return cache.NewHash(FragmentType_Database, d.Name) } // Compile transforms the Database into an equivalent SQL representation. diff --git a/internal/sqladapter/exql/database_test.go b/internal/sqladapter/exql/database_test.go index 657c5b40..aba55be2 100644 --- a/internal/sqladapter/exql/database_test.go +++ b/internal/sqladapter/exql/database_test.go @@ -1,39 +1,22 @@ package exql import ( - "fmt" + "strconv" "testing" -) - -func TestDatabaseHash(t *testing.T) { - var s, e string - - column := Database{Name: "users"} - s = column.Hash() - e = `*exql.Database:16777957551305673389` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} + "github.com/stretchr/testify/assert" +) func TestDatabaseCompile(t *testing.T) { column := Database{Name: "name"} - s, err := column.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `"name"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `"name"`, s) } func BenchmarkDatabaseHash(b *testing.B) { c := Database{Name: "name"} + b.ResetTimer() for i := 0; i < b.N; i++ { c.Hash() } @@ -41,6 +24,7 @@ func BenchmarkDatabaseHash(b *testing.B) { func BenchmarkDatabaseCompile(b *testing.B) { c := Database{Name: "name"} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = c.Compile(defaultTemplate) } @@ -55,7 +39,7 @@ func BenchmarkDatabaseCompileNoCache(b *testing.B) { func BenchmarkDatabaseCompileNoCache2(b *testing.B) { for i := 0; i < b.N; i++ { - c := Database{Name: fmt.Sprintf("name: %v", i)} + c := Database{Name: strconv.Itoa(i)} _, _ = c.Compile(defaultTemplate) } } diff --git a/internal/sqladapter/exql/errors.go b/internal/sqladapter/exql/errors.go new file mode 100644 index 00000000..b9c8b85e --- /dev/null +++ b/internal/sqladapter/exql/errors.go @@ -0,0 +1,5 @@ +package exql + +const ( + errExpectingHashableFmt = "expecting hashable value, got %T" +) diff --git a/internal/sqladapter/exql/group_by.go b/internal/sqladapter/exql/group_by.go index 4f0132c7..0cb09245 100644 --- a/internal/sqladapter/exql/group_by.go +++ b/internal/sqladapter/exql/group_by.go @@ -1,9 +1,12 @@ package exql +import ( + "github.com/upper/db/v4/internal/cache" +) + // GroupBy represents a SQL's "group by" statement. type GroupBy struct { Columns Fragment - hash hash } var _ = Fragment(&GroupBy{}) @@ -13,8 +16,11 @@ type groupByT struct { } // Hash returns a unique identifier. -func (g *GroupBy) Hash() string { - return g.hash.Hash(g) +func (g *GroupBy) Hash() uint64 { + if g == nil { + return cache.NewHash(FragmentType_GroupBy, nil) + } + return cache.NewHash(FragmentType_GroupBy, g.Columns) } // GroupByColumns creates and returns a GroupBy with the given column. diff --git a/internal/sqladapter/exql/group_by_test.go b/internal/sqladapter/exql/group_by_test.go index f42de617..cdc1e6f6 100644 --- a/internal/sqladapter/exql/group_by_test.go +++ b/internal/sqladapter/exql/group_by_test.go @@ -2,6 +2,8 @@ package exql import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestGroupBy(t *testing.T) { @@ -14,10 +16,7 @@ func TestGroupBy(t *testing.T) { ) s := mustTrim(columns.Compile(defaultTemplate)) - e := `GROUP BY "id", "customer", "service_id", "role"."name", "role"."id"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `GROUP BY "id", "customer", "service_id", "role"."name", "role"."id"`, s) } func BenchmarkGroupByColumns(b *testing.B) { @@ -38,6 +37,7 @@ func BenchmarkGroupByHash(b *testing.B) { &Column{Name: "role.name"}, &Column{Name: "role.id"}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { c.Hash() } @@ -51,6 +51,7 @@ func BenchmarkGroupByCompile(b *testing.B) { &Column{Name: "role.name"}, &Column{Name: "role.id"}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = c.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/hash.go b/internal/sqladapter/exql/hash.go deleted file mode 100644 index 7459023c..00000000 --- a/internal/sqladapter/exql/hash.go +++ /dev/null @@ -1,26 +0,0 @@ -package exql - -import ( - "reflect" - "sync/atomic" - - "github.com/upper/db/v4/internal/cache" -) - -type hash struct { - v atomic.Value -} - -func (h *hash) Hash(i interface{}) string { - v := h.v.Load() - if r, ok := v.(string); ok && r != "" { - return r - } - s := reflect.TypeOf(i).String() + ":" + cache.Hash(i) - h.v.Store(s) - return s -} - -func (h *hash) Reset() { - h.v.Store("") -} diff --git a/internal/sqladapter/exql/join.go b/internal/sqladapter/exql/join.go index ba982f7c..c09005a9 100644 --- a/internal/sqladapter/exql/join.go +++ b/internal/sqladapter/exql/join.go @@ -2,6 +2,8 @@ package exql import ( "strings" + + "github.com/upper/db/v4/internal/cache" ) type innerJoinT struct { @@ -14,14 +16,20 @@ type innerJoinT struct { // Joins represents the union of different join conditions. type Joins struct { Conditions []Fragment - hash hash } var _ = Fragment(&Joins{}) // Hash returns a unique identifier for the struct. -func (j *Joins) Hash() string { - return j.hash.Hash(j) +func (j *Joins) Hash() uint64 { + if j == nil { + return cache.NewHash(FragmentType_Joins, nil) + } + h := cache.InitHash(FragmentType_Joins) + for i := range j.Conditions { + h = cache.AddToHash(h, j.Conditions[i]) + } + return h } // Compile transforms the Where into an equivalent SQL representation. @@ -66,14 +74,16 @@ type Join struct { Table Fragment On Fragment Using Fragment - hash hash } var _ = Fragment(&Join{}) // Hash returns a unique identifier for the struct. -func (j *Join) Hash() string { - return j.hash.Hash(j) +func (j *Join) Hash() uint64 { + if j == nil { + return cache.NewHash(FragmentType_Join, nil) + } + return cache.NewHash(FragmentType_Join, j.Type, j.Table, j.On, j.Using) } // Compile transforms the Join into its equivalent SQL representation. @@ -118,9 +128,11 @@ type On Where var _ = Fragment(&On{}) -// Hash returns a unique identifier. -func (o *On) Hash() string { - return o.hash.Hash(o) +func (o *On) Hash() uint64 { + if o == nil { + return cache.NewHash(FragmentType_On, nil) + } + return cache.NewHash(FragmentType_On, (*Where)(o)) } // Compile transforms the On into an equivalent SQL representation. @@ -151,9 +163,11 @@ type usingT struct { Columns string } -// Hash returns a unique identifier. -func (u *Using) Hash() string { - return u.hash.Hash(u) +func (u *Using) Hash() uint64 { + if u == nil { + return cache.NewHash(FragmentType_Using, nil) + } + return cache.NewHash(FragmentType_Using, (*Columns)(u)) } // Compile transforms the Using into an equivalent SQL representation. diff --git a/internal/sqladapter/exql/join_test.go b/internal/sqladapter/exql/join_test.go index 68c603b8..65ce6aa6 100644 --- a/internal/sqladapter/exql/join_test.go +++ b/internal/sqladapter/exql/join_test.go @@ -3,11 +3,11 @@ package exql import ( "fmt" "testing" + + "github.com/stretchr/testify/assert" ) func TestOnAndRawOrAnd(t *testing.T) { - var s, e string - on := OnConditions( JoinWithAnd( &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, @@ -25,33 +25,21 @@ func TestOnAndRawOrAnd(t *testing.T) { ), ) - s = mustTrim(on.Compile(defaultTemplate)) - e = `ON (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(on.Compile(defaultTemplate)) + assert.Equal(t, `ON (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, s) } func TestUsing(t *testing.T) { - var s, e string - using := UsingColumns( &Column{Name: "country"}, &Column{Name: "state"}, ) - s = mustTrim(using.Compile(defaultTemplate)) - e = `USING ("country", "state")` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(using.Compile(defaultTemplate)) + assert.Equal(t, `USING ("country", "state")`, s) } func TestJoinOn(t *testing.T) { - var s, e string - join := JoinConditions( &Join{ Table: TableWithName("countries c"), @@ -70,17 +58,11 @@ func TestJoinOn(t *testing.T) { }, ) - s = mustTrim(join.Compile(defaultTemplate)) - e = `JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")`, s) } func TestInnerJoinOn(t *testing.T) { - var s, e string - join := JoinConditions(&Join{ Type: "INNER", Table: TableWithName("countries c"), @@ -98,81 +80,51 @@ func TestInnerJoinOn(t *testing.T) { ), }) - s = mustTrim(join.Compile(defaultTemplate)) - e = `INNER JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `INNER JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")`, s) } func TestLeftJoinUsing(t *testing.T) { - var s, e string - join := JoinConditions(&Join{ Type: "LEFT", Table: TableWithName("countries"), Using: UsingColumns(ColumnWithName("name")), }) - s = mustTrim(join.Compile(defaultTemplate)) - e = `LEFT JOIN "countries" USING ("name")` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `LEFT JOIN "countries" USING ("name")`, s) } func TestNaturalJoinOn(t *testing.T) { - var s, e string - join := JoinConditions(&Join{ Table: TableWithName("countries"), }) - s = mustTrim(join.Compile(defaultTemplate)) - e = `NATURAL JOIN "countries"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL JOIN "countries"`, s) } func TestNaturalInnerJoinOn(t *testing.T) { - var s, e string - join := JoinConditions(&Join{ Type: "INNER", Table: TableWithName("countries"), }) - s = mustTrim(join.Compile(defaultTemplate)) - e = `NATURAL INNER JOIN "countries"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL INNER JOIN "countries"`, s) } func TestCrossJoin(t *testing.T) { - var s, e string - join := JoinConditions(&Join{ Type: "CROSS", Table: TableWithName("countries"), }) - s = mustTrim(join.Compile(defaultTemplate)) - e = `CROSS JOIN "countries"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `CROSS JOIN "countries"`, s) } func TestMultipleJoins(t *testing.T) { - var s, e string - join := JoinConditions(&Join{ Type: "LEFT", Table: TableWithName("countries"), @@ -180,12 +132,8 @@ func TestMultipleJoins(t *testing.T) { Table: TableWithName("cities"), }) - s = mustTrim(join.Compile(defaultTemplate)) - e = `NATURAL LEFT JOIN "countries" NATURAL JOIN "cities"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL LEFT JOIN "countries" NATURAL JOIN "cities"`, s) } func BenchmarkJoin(b *testing.B) { @@ -224,6 +172,7 @@ func BenchmarkCompileJoin(b *testing.B) { }, ), }) + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = j.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/order_by.go b/internal/sqladapter/exql/order_by.go index 8ee9f464..ab35507f 100644 --- a/internal/sqladapter/exql/order_by.go +++ b/internal/sqladapter/exql/order_by.go @@ -1,8 +1,9 @@ package exql import ( - "fmt" "strings" + + "github.com/upper/db/v4/internal/cache" ) // Order represents the order in which SQL results are sorted. @@ -10,16 +11,20 @@ type Order uint8 // Possible values for Order const ( - DefaultOrder = Order(iota) - Ascendent - Descendent + Order_Default Order = iota + + Order_Ascendent + Order_Descendent ) +func (o Order) Hash() uint64 { + return cache.NewHash(FragmentType_Order, uint8(o)) +} + // SortColumn represents the column-order relation in an ORDER BY clause. type SortColumn struct { Column Fragment Order - hash hash } var _ = Fragment(&SortColumn{}) @@ -34,7 +39,6 @@ var _ = Fragment(&SortColumn{}) // SortColumns represents the columns in an ORDER BY clause. type SortColumns struct { Columns []Fragment - hash hash } var _ = Fragment(&SortColumns{}) @@ -42,7 +46,6 @@ var _ = Fragment(&SortColumns{}) // OrderBy represents an ORDER BY clause. type OrderBy struct { SortColumns Fragment - hash hash } var _ = Fragment(&OrderBy{}) @@ -62,8 +65,11 @@ func JoinWithOrderBy(sc *SortColumns) *OrderBy { } // Hash returns a unique identifier for the struct. -func (s *SortColumn) Hash() string { - return s.hash.Hash(s) +func (s *SortColumn) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_SortColumn, nil) + } + return cache.NewHash(FragmentType_SortColumn, s.Column, s.Order) } // Compile transforms the SortColumn into an equivalent SQL representation. @@ -93,8 +99,15 @@ func (s *SortColumn) Compile(layout *Template) (compiled string, err error) { } // Hash returns a unique identifier for the struct. -func (s *SortColumns) Hash() string { - return s.hash.Hash(s) +func (s *SortColumns) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_SortColumns, nil) + } + h := cache.InitHash(FragmentType_SortColumns) + for i := range s.Columns { + h = cache.AddToHash(h, s.Columns[i]) + } + return h } // Compile transforms the SortColumns into an equivalent SQL representation. @@ -120,8 +133,11 @@ func (s *SortColumns) Compile(layout *Template) (compiled string, err error) { } // Hash returns a unique identifier for the struct. -func (s *OrderBy) Hash() string { - return s.hash.Hash(s) +func (s *OrderBy) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_OrderBy, nil) + } + return cache.NewHash(FragmentType_OrderBy, s.SortColumns) } // Compile transforms the SortColumn into an equivalent SQL representation. @@ -147,17 +163,12 @@ func (s *OrderBy) Compile(layout *Template) (compiled string, err error) { return } -// Hash returns a unique identifier. -func (s *Order) Hash() string { - return fmt.Sprintf("%T.%d", s, uint8(*s)) -} - // Compile transforms the SortColumn into an equivalent SQL representation. func (s Order) Compile(layout *Template) (string, error) { switch s { - case Ascendent: + case Order_Ascendent: return layout.AscKeyword, nil - case Descendent: + case Order_Descendent: return layout.DescKeyword, nil } return "", nil diff --git a/internal/sqladapter/exql/order_by_test.go b/internal/sqladapter/exql/order_by_test.go index 1202fef8..f214f868 100644 --- a/internal/sqladapter/exql/order_by_test.go +++ b/internal/sqladapter/exql/order_by_test.go @@ -2,6 +2,8 @@ package exql import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestOrderBy(t *testing.T) { @@ -12,38 +14,29 @@ func TestOrderBy(t *testing.T) { ) s := mustTrim(o.Compile(defaultTemplate)) - e := `ORDER BY "foo"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `ORDER BY "foo"`, s) } func TestOrderByRaw(t *testing.T) { o := JoinWithOrderBy( JoinSortColumns( - &SortColumn{Column: RawValue("CASE WHEN id IN ? THEN 0 ELSE 1 END")}, + &SortColumn{Column: &Raw{Value: "CASE WHEN id IN ? THEN 0 ELSE 1 END"}}, ), ) s := mustTrim(o.Compile(defaultTemplate)) - e := `ORDER BY CASE WHEN id IN ? THEN 0 ELSE 1 END` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `ORDER BY CASE WHEN id IN ? THEN 0 ELSE 1 END`, s) } func TestOrderByDesc(t *testing.T) { o := JoinWithOrderBy( JoinSortColumns( - &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, ), ) s := mustTrim(o.Compile(defaultTemplate)) - e := `ORDER BY "foo" DESC` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `ORDER BY "foo" DESC`, s) } func BenchmarkOrderBy(b *testing.B) { @@ -62,6 +55,7 @@ func BenchmarkOrderByHash(b *testing.B) { &SortColumn{Column: &Column{Name: "foo"}}, ), } + b.ResetTimer() for i := 0; i < b.N; i++ { o.Hash() } @@ -73,6 +67,7 @@ func BenchmarkCompileOrderByCompile(b *testing.B) { &SortColumn{Column: &Column{Name: "foo"}}, ), } + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = o.Compile(defaultTemplate) } @@ -90,7 +85,7 @@ func BenchmarkCompileOrderByCompileNoCache(b *testing.B) { } func BenchmarkCompileOrderCompile(b *testing.B) { - o := Descendent + o := Order_Descendent for i := 0; i < b.N; i++ { _, _ = o.Compile(defaultTemplate) } @@ -98,13 +93,14 @@ func BenchmarkCompileOrderCompile(b *testing.B) { func BenchmarkCompileOrderCompileNoCache(b *testing.B) { for i := 0; i < b.N; i++ { - o := Descendent + o := Order_Descendent _, _ = o.Compile(defaultTemplate) } } func BenchmarkSortColumnHash(b *testing.B) { s := &SortColumn{Column: &Column{Name: "foo"}} + b.ResetTimer() for i := 0; i < b.N; i++ { s.Hash() } @@ -112,6 +108,7 @@ func BenchmarkSortColumnHash(b *testing.B) { func BenchmarkSortColumnCompile(b *testing.B) { s := &SortColumn{Column: &Column{Name: "foo"}} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = s.Compile(defaultTemplate) } @@ -129,6 +126,7 @@ func BenchmarkSortColumnsHash(b *testing.B) { &SortColumn{Column: &Column{Name: "foo"}}, &SortColumn{Column: &Column{Name: "bar"}}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { s.Hash() } @@ -139,6 +137,7 @@ func BenchmarkSortColumnsCompile(b *testing.B) { &SortColumn{Column: &Column{Name: "foo"}}, &SortColumn{Column: &Column{Name: "bar"}}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = s.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/raw.go b/internal/sqladapter/exql/raw.go index 2936c879..54dc97a1 100644 --- a/internal/sqladapter/exql/raw.go +++ b/internal/sqladapter/exql/raw.go @@ -2,7 +2,8 @@ package exql import ( "fmt" - "strings" + + "github.com/upper/db/v4/internal/cache" ) var ( @@ -11,18 +12,27 @@ var ( // Raw represents a value that is meant to be used in a query without escaping. type Raw struct { - Value string // Value should not be modified after assigned. - hash hash + Value string } -// RawValue creates and returns a new raw value. -func RawValue(v string) *Raw { - return &Raw{Value: strings.TrimSpace(v)} +func NewRawValue(v interface{}) (*Raw, error) { + switch t := v.(type) { + case string: + return &Raw{Value: t}, nil + case int, uint, int64, uint64, int32, uint32, int16, uint16: + return &Raw{Value: fmt.Sprintf("%d", t)}, nil + case fmt.Stringer: + return &Raw{Value: t.String()}, nil + } + return nil, fmt.Errorf("unexpected type: %T", v) } // Hash returns a unique identifier for the struct. -func (r *Raw) Hash() string { - return r.hash.Hash(r) +func (r *Raw) Hash() uint64 { + if r == nil { + return cache.NewHash(FragmentType_Raw, nil) + } + return cache.NewHash(FragmentType_Raw, r.Value) } // Compile returns the raw value. diff --git a/internal/sqladapter/exql/raw_test.go b/internal/sqladapter/exql/raw_test.go index 076d574b..66e38b1d 100644 --- a/internal/sqladapter/exql/raw_test.go +++ b/internal/sqladapter/exql/raw_test.go @@ -2,47 +2,22 @@ package exql import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestRawString(t *testing.T) { raw := &Raw{Value: "foo"} - s, err := raw.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `foo` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `foo`, s) } func TestRawCompile(t *testing.T) { raw := &Raw{Value: "foo"} - s, err := raw.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `foo` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } -} - -func TestRawHash(t *testing.T) { - var s, e string - - raw := &Raw{Value: "foo"} - - s = raw.Hash() - e = `*exql.Raw:5772950988983410957` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `foo`, s) } func BenchmarkRawCreate(b *testing.B) { @@ -53,6 +28,7 @@ func BenchmarkRawCreate(b *testing.B) { func BenchmarkRawString(b *testing.B) { raw := &Raw{Value: "foo"} + b.ResetTimer() for i := 0; i < b.N; i++ { _ = raw.String() } @@ -60,6 +36,7 @@ func BenchmarkRawString(b *testing.B) { func BenchmarkRawCompile(b *testing.B) { raw := &Raw{Value: "foo"} + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = raw.Compile(defaultTemplate) } @@ -67,6 +44,7 @@ func BenchmarkRawCompile(b *testing.B) { func BenchmarkRawHash(b *testing.B) { raw := &Raw{Value: "foo"} + b.ResetTimer() for i := 0; i < b.N; i++ { raw.Hash() } diff --git a/internal/sqladapter/exql/returning.go b/internal/sqladapter/exql/returning.go index ef392bf5..6e28f0a5 100644 --- a/internal/sqladapter/exql/returning.go +++ b/internal/sqladapter/exql/returning.go @@ -1,14 +1,20 @@ package exql +import ( + "github.com/upper/db/v4/internal/cache" +) + // Returning represents a RETURNING clause. type Returning struct { *Columns - hash hash } // Hash returns a unique identifier for the struct. -func (r *Returning) Hash() string { - return r.hash.Hash(r) +func (r *Returning) Hash() uint64 { + if r == nil { + return cache.NewHash(FragmentType_Returning, nil) + } + return cache.NewHash(FragmentType_Returning, r.Columns) } var _ = Fragment(&Returning{}) diff --git a/internal/sqladapter/exql/statement.go b/internal/sqladapter/exql/statement.go index 032466e2..9b9fd480 100644 --- a/internal/sqladapter/exql/statement.go +++ b/internal/sqladapter/exql/statement.go @@ -4,6 +4,8 @@ import ( "errors" "reflect" "strings" + + "github.com/upper/db/v4/internal/cache" ) var errUnknownTemplateType = errors.New("Unknown template type") @@ -28,7 +30,6 @@ type Statement struct { SQL string - hash hash amendFn func(string) string } @@ -40,8 +41,28 @@ func (layout *Template) doCompile(c Fragment) (string, error) { } // Hash returns a unique identifier for the struct. -func (s *Statement) Hash() string { - return s.hash.Hash(s) +func (s *Statement) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_Statement, nil) + } + return cache.NewHash( + FragmentType_Statement, + s.Type, + s.Table, + s.Database, + s.Columns, + s.Values, + s.Distinct, + s.ColumnValues, + s.OrderBy, + s.GroupBy, + s.Joins, + s.Where, + s.Returning, + s.Limit, + s.Offset, + s.SQL, + ) } func (s *Statement) SetAmendment(amendFn func(string) string) { diff --git a/internal/sqladapter/exql/statement_test.go b/internal/sqladapter/exql/statement_test.go index 33fe21b7..28e726a3 100644 --- a/internal/sqladapter/exql/statement_test.go +++ b/internal/sqladapter/exql/statement_test.go @@ -1,211 +1,134 @@ package exql import ( - "regexp" - "strings" "testing" -) -var ( - reInvisible = regexp.MustCompile(`[\t\n\r]`) - reSpace = regexp.MustCompile(`\s+`) + "github.com/stretchr/testify/assert" ) -func mustTrim(a string, err error) string { - if err != nil { - panic(err.Error()) - } - a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") - a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") - return a -} - func TestTruncateTable(t *testing.T) { - var s, e string - stmt := Statement{ Type: Truncate, Table: TableWithName("table_name"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `TRUNCATE TABLE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `TRUNCATE TABLE "table_name"`, s) } func TestDropTable(t *testing.T) { - var s, e string - stmt := Statement{ Type: DropTable, Table: TableWithName("table_name"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `DROP TABLE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DROP TABLE "table_name"`, s) } func TestDropDatabase(t *testing.T) { - var s, e string - stmt := Statement{ Type: DropDatabase, Database: &Database{Name: "table_name"}, } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `DROP DATABASE "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DROP DATABASE "table_name"`, s) } func TestCount(t *testing.T) { - var s, e string - stmt := Statement{ Type: Count, Table: TableWithName("table_name"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "table_name"`, s) } func TestCountRelation(t *testing.T) { - var s, e string - stmt := Statement{ Type: Count, Table: TableWithName("information_schema.tables"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "information_schema"."tables"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "information_schema"."tables"`, s) } func TestCountWhere(t *testing.T) { - var s, e string - stmt := Statement{ Type: Count, Table: TableWithName("table_name"), Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(RawValue("7"))}, + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: &Raw{Value: "7"}}, ), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)`, s) } func TestSelectStarFrom(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName("table_name"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table_name"`, s) } func TestSelectStarFromAlias(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName("table.name AS foo"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo"`, s) } func TestSelectStarFromRawWhere(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: Select, - Table: TableWithName("table.name AS foo"), - Where: WhereConditions( - &Raw{Value: "foo.id = bar.foo_id"}, - ), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + ), + } - stmt = Statement{ - Type: Select, - Table: TableWithName("table.name AS foo"), - Where: WhereConditions( - &Raw{Value: "foo.id = bar.foo_id"}, - &Raw{Value: "baz.id = exp.baz_id"}, - ), + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)`, s) } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)` + { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + &Raw{Value: "baz.id = exp.baz_id"}, + ), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)`, s) } } func TestSelectStarFromMany(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"`, s) } func TestSelectTableStarFromMany(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Columns: JoinColumns( @@ -216,17 +139,11 @@ func TestSelectTableStarFromMany(t *testing.T) { Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo"."name", "BAR".*, "baz"."last_name" FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo"."name", "BAR".*, "baz"."last_name" FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"`, s) } func TestSelectArtistNameFrom(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName("artist"), @@ -235,17 +152,11 @@ func TestSelectArtistNameFrom(t *testing.T) { ), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "artist"."name" FROM "artist"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "artist"."name" FROM "artist"`, s) } func TestSelectJoin(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName("artist a"), @@ -264,17 +175,11 @@ func TestSelectJoin(t *testing.T) { }), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" ON ("b"."author_id" = "a"."id")` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" ON ("b"."author_id" = "a"."id")`, s) } func TestSelectJoinUsing(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName("artist a"), @@ -290,12 +195,8 @@ func TestSelectJoinUsing(t *testing.T) { }), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" USING ("artist_id", "country")` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" USING ("artist_id", "country")`, s) } func TestSelectUnfinishedJoin(t *testing.T) { @@ -309,15 +210,10 @@ func TestSelectUnfinishedJoin(t *testing.T) { } s := mustTrim(stmt.Compile(defaultTemplate)) - e := `SELECT "a"."name" FROM "artist" AS "a"` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a"`, s) } func TestSelectNaturalJoin(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName("artist"), @@ -326,37 +222,25 @@ func TestSelectNaturalJoin(t *testing.T) { }), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT * FROM "artist" NATURAL JOIN "books"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "artist" NATURAL JOIN "books"`, s) } func TestSelectRawFrom(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Table: TableWithName(`artist`), Columns: JoinColumns( &Column{Name: `artist.name`}, - &Column{Name: Raw{Value: `CONCAT(artist.name, " ", artist.last_name)`}}, + &Column{Name: &Raw{Value: `CONCAT(artist.name, " ", artist.last_name)`}}, ), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"`, s) } func TestSelectFieldsFrom(t *testing.T) { - var s, e string - stmt := Statement{ Type: Select, Columns: JoinColumns( @@ -367,300 +251,248 @@ func TestSelectFieldsFrom(t *testing.T) { Table: TableWithName("table_name"), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name"`, s) } func TestSelectFieldsFromWithLimitOffset(t *testing.T) { - var s, e string - var stmt Statement - - // LIMIT only. - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - Limit: 42, - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Table: TableWithName("table_name"), + } - // OFFSET only. - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - Offset: 17, - Table: TableWithName("table_name"), + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42`, s) } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17` + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Offset: 17, + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17`, s) } - // LIMIT AND OFFSET. - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - Limit: 42, - Offset: 17, - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17` + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Offset: 17, + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17`, s) } } func TestStatementGroupBy(t *testing.T) { - var s, e string - var stmt Statement - - // Simple GROUP BY - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - GroupBy: GroupByColumns( - &Column{Name: "foo"}, - ), - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + ), + Table: TableWithName("table_name"), + } - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - GroupBy: GroupByColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - ), - Table: TableWithName("table_name"), + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"`, s) } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"` + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + ), + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"`, s) } } func TestSelectFieldsFromWithOrderBy(t *testing.T) { - var s, e string - var stmt Statement - - // Simple ORDER BY - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - OrderBy: JoinWithOrderBy( - JoinSortColumns( - &SortColumn{Column: &Column{Name: "foo"}}, + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, ), - ), - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"` + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ), + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"`, s) } - // ORDER BY field ASC - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - OrderBy: JoinWithOrderBy( - JoinSortColumns( - &SortColumn{Column: &Column{Name: "foo"}, Order: Ascendent}, + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, ), - ), - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC` + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC`, s) } - // ORDER BY field DESC - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - OrderBy: JoinWithOrderBy( - JoinSortColumns( - &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, ), - ), - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC` + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + ), + ), + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC`, s) } - // ORDER BY many fields - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - OrderBy: JoinWithOrderBy( - JoinSortColumns( - &SortColumn{Column: &Column{Name: "foo"}, Order: Descendent}, - &SortColumn{Column: &Column{Name: "bar"}, Order: Ascendent}, - &SortColumn{Column: &Column{Name: "baz"}, Order: Descendent}, + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, ), - ), - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC` + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + &SortColumn{Column: &Column{Name: "bar"}, Order: Order_Ascendent}, + &SortColumn{Column: &Column{Name: "baz"}, Order: Order_Descendent}, + ), + ), + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC`, s) } - // ORDER BY function - stmt = Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - OrderBy: JoinWithOrderBy( - JoinSortColumns( - &SortColumn{Column: &Column{Name: Raw{Value: "FOO()"}}, Order: Descendent}, - &SortColumn{Column: &Column{Name: Raw{Value: "BAR()"}}, Order: Ascendent}, + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, ), - ), - Table: TableWithName("table_name"), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC` + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: &Raw{Value: "FOO()"}}, Order: Order_Descendent}, + &SortColumn{Column: &Column{Name: &Raw{Value: "BAR()"}}, Order: Order_Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC`, s) } } func TestSelectFieldsFromWhere(t *testing.T) { - var s, e string - - stmt := Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - Table: TableWithName("table_name"), - Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, - ), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')` + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')`, s) } } func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { - var s, e string - - stmt := Statement{ - Type: Select, - Columns: JoinColumns( - &Column{Name: "foo"}, - &Column{Name: "bar"}, - &Column{Name: "baz"}, - ), - Table: TableWithName("table_name"), - Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, - ), - Limit: 10, - Offset: 23, - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23` + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + Limit: 10, + Offset: 23, + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23`, s) } } func TestDelete(t *testing.T) { - var s, e string - stmt := Statement{ Type: Delete, Table: TableWithName("table_name"), @@ -669,59 +501,46 @@ func TestDelete(t *testing.T) { ), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `DELETE FROM "table_name" WHERE ("baz" = '99')` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DELETE FROM "table_name" WHERE ("baz" = '99')`, s) } func TestUpdate(t *testing.T) { - var s, e string - var stmt Statement - - stmt = Statement{ - Type: Update, - Table: TableWithName("table_name"), - ColumnValues: JoinColumnValues( - &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, - ), - Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, - ), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')` + { + stmt := Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')`, s) } - stmt = Statement{ - Type: Update, - Table: TableWithName("table_name"), - ColumnValues: JoinColumnValues( - &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, - &ColumnValue{Column: &Column{Name: "bar"}, Operator: "=", Value: NewValue(Raw{Value: "88"})}, - ), - Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, - ), - } - - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')` + { + stmt := Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + &ColumnValue{Column: &Column{Name: "bar"}, Operator: "=", Value: NewValue(&Raw{Value: "88"})}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')`, s) } } func TestInsert(t *testing.T) { - var s, e string - stmt := Statement{ Type: Insert, Table: TableWithName("table_name"), @@ -733,21 +552,15 @@ func TestInsert(t *testing.T) { Values: NewValueGroup( &Value{V: "1"}, &Value{V: 2}, - &Value{V: Raw{Value: "3"}}, + &Value{V: &Raw{Value: "3"}}, ), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)`, s) } func TestInsertMultiple(t *testing.T) { - var s, e string - stmt := Statement{ Type: Insert, Table: TableWithName("table_name"), @@ -760,27 +573,21 @@ func TestInsertMultiple(t *testing.T) { NewValueGroup( NewValue("1"), NewValue("2"), - NewValue(RawValue("3")), + NewValue(&Raw{Value: "3"}), ), NewValueGroup( - NewValue(RawValue("4")), - NewValue(RawValue("5")), - NewValue(RawValue("6")), + NewValue(&Raw{Value: "4"}), + NewValue(&Raw{Value: "5"}), + NewValue(&Raw{Value: "6"}), ), ), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3), (4, 5, 6)` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3), (4, 5, 6)`, s) } func TestInsertReturning(t *testing.T) { - var s, e string - stmt := Statement{ Type: Insert, Table: TableWithName("table_name"), @@ -795,27 +602,19 @@ func TestInsertReturning(t *testing.T) { Values: NewValueGroup( &Value{V: "1"}, &Value{V: 2}, - &Value{V: Raw{Value: "3"}}, + &Value{V: &Raw{Value: "3"}}, ), } - s = mustTrim(stmt.Compile(defaultTemplate)) - e = `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING "id"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING "id"`, s) } func TestRawSQLStatement(t *testing.T) { stmt := RawSQL(`SELECT * FROM "foo" ORDER BY "bar"`) s := mustTrim(stmt.Compile(defaultTemplate)) - e := `SELECT * FROM "foo" ORDER BY "bar"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `SELECT * FROM "foo" ORDER BY "bar"`, s) } func BenchmarkStatementSimpleQuery(b *testing.B) { @@ -823,10 +622,11 @@ func BenchmarkStatementSimpleQuery(b *testing.B) { Type: Count, Table: TableWithName("table_name"), Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, ), } + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = stmt.Compile(defaultTemplate) } @@ -837,10 +637,11 @@ func BenchmarkStatementSimpleQueryHash(b *testing.B) { Type: Count, Table: TableWithName("table_name"), Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, ), } + b.ResetTimer() for i := 0; i < b.N; i++ { _ = stmt.Hash() } @@ -852,7 +653,7 @@ func BenchmarkStatementSimpleQueryNoCache(b *testing.B) { Type: Count, Table: TableWithName("table_name"), Where: WhereConditions( - &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(Raw{Value: "7"})}, + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, ), } _, _ = stmt.Compile(defaultTemplate) @@ -871,10 +672,11 @@ func BenchmarkStatementComplexQuery(b *testing.B) { Values: NewValueGroup( &Value{V: "1"}, &Value{V: 2}, - &Value{V: Raw{Value: "3"}}, + &Value{V: &Raw{Value: "3"}}, ), } + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = stmt.Compile(defaultTemplate) } @@ -893,7 +695,7 @@ func BenchmarkStatementComplexQueryNoCache(b *testing.B) { Values: NewValueGroup( &Value{V: "1"}, &Value{V: 2}, - &Value{V: Raw{Value: "3"}}, + &Value{V: &Raw{Value: "3"}}, ), } _, _ = stmt.Compile(defaultTemplate) diff --git a/internal/sqladapter/exql/table.go b/internal/sqladapter/exql/table.go index 5c0c8f83..8b5f9edc 100644 --- a/internal/sqladapter/exql/table.go +++ b/internal/sqladapter/exql/table.go @@ -2,6 +2,8 @@ package exql import ( "strings" + + "github.com/upper/db/v4/internal/cache" ) type tableT struct { @@ -12,7 +14,6 @@ type tableT struct { // Table struct represents a SQL table. type Table struct { Name interface{} - hash hash } var _ = Fragment(&Table{}) @@ -57,8 +58,11 @@ func TableWithName(name string) *Table { } // Hash returns a string hash of the table value. -func (t *Table) Hash() string { - return t.hash.Hash(t) +func (t *Table) Hash() uint64 { + if t == nil { + return cache.NewHash(FragmentType_Table, nil) + } + return cache.NewHash(FragmentType_Table, t.Name) } // Compile transforms a table struct into a SQL chunk. diff --git a/internal/sqladapter/exql/table_test.go b/internal/sqladapter/exql/table_test.go index 8f374b4a..08bc8250 100644 --- a/internal/sqladapter/exql/table_test.go +++ b/internal/sqladapter/exql/table_test.go @@ -1,111 +1,55 @@ package exql import ( + "github.com/stretchr/testify/assert" + "testing" ) func TestTableSimple(t *testing.T) { - var s, e string - table := TableWithName("artist") - - s = mustTrim(table.Compile(defaultTemplate)) - e = `"artist"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `"artist"`, mustTrim(table.Compile(defaultTemplate))) } func TestTableCompound(t *testing.T) { - var s, e string - table := TableWithName("artist.foo") - - s = mustTrim(table.Compile(defaultTemplate)) - e = `"artist"."foo"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `"artist"."foo"`, mustTrim(table.Compile(defaultTemplate))) } func TestTableCompoundAlias(t *testing.T) { - var s, e string - table := TableWithName("artist.foo AS baz") - s = mustTrim(table.Compile(defaultTemplate)) - e = `"artist"."foo" AS "baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `"artist"."foo" AS "baz"`, mustTrim(table.Compile(defaultTemplate))) } func TestTableImplicitAlias(t *testing.T) { - var s, e string - table := TableWithName("artist.foo baz") - s = mustTrim(table.Compile(defaultTemplate)) - e = `"artist"."foo" AS "baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `"artist"."foo" AS "baz"`, mustTrim(table.Compile(defaultTemplate))) } func TestTableMultiple(t *testing.T) { - var s, e string - table := TableWithName("artist.foo, artist.bar, artist.baz") - s = mustTrim(table.Compile(defaultTemplate)) - e = `"artist"."foo", "artist"."bar", "artist"."baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `"artist"."foo", "artist"."bar", "artist"."baz"`, mustTrim(table.Compile(defaultTemplate))) } func TestTableMultipleAlias(t *testing.T) { - var s, e string - table := TableWithName("artist.foo AS foo, artist.bar as bar, artist.baz As baz") - s = mustTrim(table.Compile(defaultTemplate)) - e = `"artist"."foo" AS "foo", "artist"."bar" AS "bar", "artist"."baz" AS "baz"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `"artist"."foo" AS "foo", "artist"."bar" AS "bar", "artist"."baz" AS "baz"`, mustTrim(table.Compile(defaultTemplate))) } func TestTableMinimal(t *testing.T) { - var s, e string - table := TableWithName("a") - s = mustTrim(table.Compile(defaultTemplate)) - e = `"a"` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `"a"`, mustTrim(table.Compile(defaultTemplate))) } func TestTableEmpty(t *testing.T) { - var s, e string - table := TableWithName("") - s = mustTrim(table.Compile(defaultTemplate)) - e = `` - - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, "", mustTrim(table.Compile(defaultTemplate))) } func BenchmarkTableWithName(b *testing.B) { @@ -116,6 +60,7 @@ func BenchmarkTableWithName(b *testing.B) { func BenchmarkTableHash(b *testing.B) { t := TableWithName("name") + b.ResetTimer() for i := 0; i < b.N; i++ { t.Hash() } @@ -123,6 +68,7 @@ func BenchmarkTableHash(b *testing.B) { func BenchmarkTableCompile(b *testing.B) { t := TableWithName("name") + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = t.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/template.go b/internal/sqladapter/exql/template.go index c32580a0..9aef852a 100644 --- a/internal/sqladapter/exql/template.go +++ b/internal/sqladapter/exql/template.go @@ -11,11 +11,11 @@ import ( ) // Type is the type of SQL query the statement represents. -type Type uint +type Type uint8 // Values for Type. const ( - NoOp = Type(iota) + NoOp Type = iota Truncate DropTable @@ -29,13 +29,25 @@ const ( SQL ) +func (t Type) Hash() uint64 { + return cache.NewHash(FragmentType_StatementType, uint8(t)) +} + type ( // Limit represents the SQL limit in a query. - Limit int + Limit int64 // Offset represents the SQL offset in a query. - Offset int + Offset int64 ) +func (t Limit) Hash() uint64 { + return cache.NewHash(FragmentType_Limit, uint64(t)) +} + +func (t Offset) Hash() uint64 { + return cache.NewHash(FragmentType_Offset, uint64(t)) +} + // Template is an SQL template. type Template struct { AndKeyword string diff --git a/internal/sqladapter/exql/types.go b/internal/sqladapter/exql/types.go new file mode 100644 index 00000000..d6ecca96 --- /dev/null +++ b/internal/sqladapter/exql/types.go @@ -0,0 +1,35 @@ +package exql + +const ( + FragmentType_None uint64 = iota + 713910251627 + + FragmentType_And + FragmentType_Column + FragmentType_ColumnValue + FragmentType_ColumnValues + FragmentType_Columns + FragmentType_Database + FragmentType_GroupBy + FragmentType_Join + FragmentType_Joins + FragmentType_Nil + FragmentType_Or + FragmentType_Limit + FragmentType_Offset + FragmentType_OrderBy + FragmentType_Order + FragmentType_Raw + FragmentType_Returning + FragmentType_SortBy + FragmentType_SortColumn + FragmentType_SortColumns + FragmentType_Statement + FragmentType_StatementType + FragmentType_Table + FragmentType_Value + FragmentType_On + FragmentType_Using + FragmentType_ValueGroups + FragmentType_Values + FragmentType_Where +) diff --git a/internal/sqladapter/exql/utilities_test.go b/internal/sqladapter/exql/utilities_test.go index 114be907..9dcbde35 100644 --- a/internal/sqladapter/exql/utilities_test.go +++ b/internal/sqladapter/exql/utilities_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" "unicode" + + "github.com/stretchr/testify/assert" ) const ( @@ -20,90 +22,64 @@ var ( stringWithLeadingBlanks = string(bytesWithLeadingBlanks) ) -func TestUtilIsBlankSymbol(t *testing.T) { - if isBlankSymbol(' ') == false { - t.Fail() - } - if isBlankSymbol('\n') == false { - t.Fail() - } - if isBlankSymbol('\t') == false { - t.Fail() - } - if isBlankSymbol('\r') == false { - t.Fail() - } - if isBlankSymbol('x') == true { - t.Fail() +var ( + reInvisible = regexp.MustCompile(`[\t\n\r]`) + reSpace = regexp.MustCompile(`\s+`) +) + +func mustTrim(a string, err error) string { + if err != nil { + panic(err.Error()) } + a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") + a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") + return a +} + +func TestUtilIsBlankSymbol(t *testing.T) { + assert.True(t, isBlankSymbol(' ')) + assert.True(t, isBlankSymbol('\n')) + assert.True(t, isBlankSymbol('\t')) + assert.True(t, isBlankSymbol('\r')) + assert.False(t, isBlankSymbol('x')) } func TestUtilTrimBytes(t *testing.T) { var trimmed []byte trimmed = trimBytes([]byte(" \t\nHello World! \n")) - if string(trimmed) != "Hello World!" { - t.Fatalf("Got: %s\n", string(trimmed)) - } + assert.Equal(t, "Hello World!", string(trimmed)) trimmed = trimBytes([]byte("Nope")) - if string(trimmed) != "Nope" { - t.Fatalf("Got: %s\n", string(trimmed)) - } + assert.Equal(t, "Nope", string(trimmed)) trimmed = trimBytes([]byte("")) - if string(trimmed) != "" { - t.Fatalf("Got: %s\n", string(trimmed)) - } + assert.Equal(t, "", string(trimmed)) trimmed = trimBytes([]byte(" ")) - if string(trimmed) != "" { - t.Fatalf("Got: %s\n", string(trimmed)) - } + assert.Equal(t, "", string(trimmed)) trimmed = trimBytes(nil) - if string(trimmed) != "" { - t.Fatalf("Got: %s\n", string(trimmed)) - } + assert.Equal(t, "", string(trimmed)) } func TestUtilSeparateByComma(t *testing.T) { chunks := separateByComma("Hello,,World!,Enjoy") + assert.Equal(t, 4, len(chunks)) - if len(chunks) != 4 { - t.Fatal() - } - - if chunks[0] != "Hello" { - t.Fatal() - } - if chunks[1] != "" { - t.Fatal() - } - if chunks[2] != "World!" { - t.Fatal() - } - if chunks[3] != "Enjoy" { - t.Fatal() - } + assert.Equal(t, "Hello", chunks[0]) + assert.Equal(t, "", chunks[1]) + assert.Equal(t, "World!", chunks[2]) + assert.Equal(t, "Enjoy", chunks[3]) } func TestUtilSeparateBySpace(t *testing.T) { chunks := separateBySpace(" Hello World! Enjoy") + assert.Equal(t, 3, len(chunks)) - if len(chunks) != 3 { - t.Fatal() - } - - if chunks[0] != "Hello" { - t.Fatal() - } - if chunks[1] != "World!" { - t.Fatal() - } - if chunks[2] != "Enjoy" { - t.Fatal() - } + assert.Equal(t, "Hello", chunks[0]) + assert.Equal(t, "World!", chunks[1]) + assert.Equal(t, "Enjoy", chunks[2]) } func TestUtilSeparateByAS(t *testing.T) { @@ -117,96 +93,44 @@ func TestUtilSeparateByAS(t *testing.T) { for _, test := range tests { chunks = separateByAS(test) + assert.Len(t, chunks, 2) - if len(chunks) != 2 { - t.Fatalf(`Expecting 2 results.`) - } - - if chunks[0] != "table.Name" { - t.Fatal(`Expecting first result to be "table.Name".`) - } - if chunks[1] != "myTableAlias" { - t.Fatal(`Expecting second result to be myTableAlias.`) - } + assert.Equal(t, "table.Name", chunks[0]) + assert.Equal(t, "myTableAlias", chunks[1]) } // Single character. chunks = separateByAS("a") - - if len(chunks) != 1 { - t.Fatalf(`Expecting 1 results.`) - } - - if chunks[0] != "a" { - t.Fatal(`Expecting first result to be "a".`) - } + assert.Len(t, chunks, 1) + assert.Equal(t, "a", chunks[0]) // Empty name chunks = separateByAS("") - - if len(chunks) != 1 { - t.Fatalf(`Expecting 1 results.`) - } - - if chunks[0] != "" { - t.Fatal(`Expecting first result to be "".`) - } + assert.Len(t, chunks, 1) + assert.Equal(t, "", chunks[0]) // Single name chunks = separateByAS(" A Single Table ") - - if len(chunks) != 1 { - t.Fatalf(`Expecting 1 results.`) - } - - if chunks[0] != "A Single Table" { - t.Fatal(`Expecting first result to be "ASingleTable".`) - } + assert.Len(t, chunks, 1) + assert.Equal(t, "A Single Table", chunks[0]) // Minimal expression. chunks = separateByAS("a AS b") - - if len(chunks) != 2 { - t.Fatalf(`Expecting 2 results.`) - } - - if chunks[0] != "a" { - t.Fatal(`Expecting first result to be "a".`) - } - - if chunks[1] != "b" { - t.Fatal(`Expecting first result to be "b".`) - } + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "b", chunks[1]) // Minimal expression with spaces. chunks = separateByAS(" a AS b ") - - if len(chunks) != 2 { - t.Fatalf(`Expecting 2 results.`) - } - - if chunks[0] != "a" { - t.Fatal(`Expecting first result to be "a".`) - } - - if chunks[1] != "b" { - t.Fatal(`Expecting first result to be "b".`) - } + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "b", chunks[1]) // Minimal expression + 1 with spaces. chunks = separateByAS(" a AS bb ") - - if len(chunks) != 2 { - t.Fatalf(`Expecting 2 results.`) - } - - if chunks[0] != "a" { - t.Fatal(`Expecting first result to be "a".`) - } - - if chunks[1] != "bb" { - t.Fatal(`Expecting first result to be "bb".`) - } + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "bb", chunks[1]) } func BenchmarkUtilIsBlankSymbol(b *testing.B) { @@ -252,6 +176,7 @@ func BenchmarkUtilSeparateByComma(b *testing.B) { func BenchmarkUtilRegExpSeparateByComma(b *testing.B) { sep := regexp.MustCompile(`\s*?,\s*?`) + b.ResetTimer() for i := 0; i < b.N; i++ { _ = sep.Split(stringWithCommas, -1) } @@ -265,6 +190,7 @@ func BenchmarkUtilSeparateBySpace(b *testing.B) { func BenchmarkUtilRegExpSeparateBySpace(b *testing.B) { sep := regexp.MustCompile(`\s+`) + b.ResetTimer() for i := 0; i < b.N; i++ { _ = sep.Split(stringWithSpaces, -1) } @@ -278,6 +204,7 @@ func BenchmarkUtilSeparateByAS(b *testing.B) { func BenchmarkUtilRegExpSeparateByAS(b *testing.B) { sep := regexp.MustCompile(`(?i:\s+AS\s+)`) + b.ResetTimer() for i := 0; i < b.N; i++ { _ = sep.Split(stringWithASKeyword, -1) } diff --git a/internal/sqladapter/exql/value.go b/internal/sqladapter/exql/value.go index 49b22aa9..6c628287 100644 --- a/internal/sqladapter/exql/value.go +++ b/internal/sqladapter/exql/value.go @@ -1,14 +1,14 @@ package exql import ( - "fmt" "strings" + + "github.com/upper/db/v4/internal/cache" ) // ValueGroups represents an array of value groups. type ValueGroups struct { Values []*Values - hash hash } func (vg *ValueGroups) IsEmpty() bool { @@ -28,7 +28,6 @@ var _ = Fragment(&ValueGroups{}) // Values represents an array of Value. type Values struct { Values []Fragment - hash hash } func (vs *Values) IsEmpty() bool { @@ -38,12 +37,16 @@ func (vs *Values) IsEmpty() bool { return false } +// NewValueGroup creates and returns an array of values. +func NewValueGroup(v ...Fragment) *Values { + return &Values{Values: v} +} + var _ = Fragment(&Values{}) // Value represents an escaped SQL value. type Value struct { - V interface{} - hash hash + V interface{} } var _ = Fragment(&Value{}) @@ -53,50 +56,51 @@ func NewValue(v interface{}) *Value { return &Value{V: v} } -// NewValueGroup creates and returns an array of values. -func NewValueGroup(v ...Fragment) *Values { - return &Values{Values: v} -} - // Hash returns a unique identifier for the struct. -func (v *Value) Hash() string { - return v.hash.Hash(v) -} - -func (v *Value) IsEmpty() bool { - return false +func (v *Value) Hash() uint64 { + if v == nil { + return cache.NewHash(FragmentType_Value, nil) + } + return cache.NewHash(FragmentType_Value, v.V) } // Compile transforms the Value into an equivalent SQL representation. func (v *Value) Compile(layout *Template) (compiled string, err error) { - if z, ok := layout.Read(v); ok { return z, nil } - switch t := v.V.(type) { - case Raw: - compiled, err = t.Compile(layout) + switch value := v.V.(type) { + case compilable: + compiled, err = value.Compile(layout) if err != nil { return "", err } - case Fragment: - compiled, err = t.Compile(layout) + default: + value, err := NewRawValue(v.V) if err != nil { return "", err } - default: - compiled = layout.MustCompile(layout.ValueQuote, RawValue(fmt.Sprintf(`%v`, v.V))) + compiled = layout.MustCompile( + layout.ValueQuote, + value, + ) } layout.Write(v, compiled) - return } // Hash returns a unique identifier for the struct. -func (vs *Values) Hash() string { - return vs.hash.Hash(vs) +func (vs *Values) Hash() uint64 { + if vs == nil { + return cache.NewHash(FragmentType_Values, nil) + } + h := cache.InitHash(FragmentType_Values) + for i := range vs.Values { + h = cache.AddToHash(h, vs.Values[i]) + } + return h } // Compile transforms the Values into an equivalent SQL representation. @@ -122,8 +126,15 @@ func (vs *Values) Compile(layout *Template) (compiled string, err error) { } // Hash returns a unique identifier for the struct. -func (vg *ValueGroups) Hash() string { - return vg.hash.Hash(vg) +func (vg *ValueGroups) Hash() uint64 { + if vg == nil { + return cache.NewHash(FragmentType_ValueGroups, nil) + } + h := cache.InitHash(FragmentType_ValueGroups) + for i := range vg.Values { + h = cache.AddToHash(h, vg.Values[i]) + } + return h } // Compile transforms the ValueGroups into an equivalent SQL representation. diff --git a/internal/sqladapter/exql/value_test.go b/internal/sqladapter/exql/value_test.go index cf013e34..a0269b60 100644 --- a/internal/sqladapter/exql/value_test.go +++ b/internal/sqladapter/exql/value_test.go @@ -2,31 +2,59 @@ package exql import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestValue(t *testing.T) { val := NewValue(1) s, err := val.Compile(defaultTemplate) - if err != nil { - t.Fatal() + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) + + val = NewValue(&Raw{Value: "NOW()"}) + + s, err = val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `NOW()`, s) +} + +func TestSameRawValue(t *testing.T) { + { + val := NewValue(&Raw{Value: `"1"`}) + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"1"`, s) } + { + val := NewValue(&Raw{Value: `'1'`}) - e := `'1'` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) } + { + val := NewValue(&Raw{Value: `1`}) - val = NewValue(&Raw{Value: "NOW()"}) + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `1`, s) + } + { + val := NewValue("1") - s, err = val.Compile(defaultTemplate) - if err != nil { - t.Fatal() + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) } + { + val := NewValue(1) - e = `NOW()` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) } } @@ -38,14 +66,9 @@ func TestValues(t *testing.T) { ) s, err := val.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } + assert.NoError(t, err) - e := `(1, 2, '3')` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `(1, 2, '3')`, s) } func BenchmarkValue(b *testing.B) { @@ -56,6 +79,7 @@ func BenchmarkValue(b *testing.B) { func BenchmarkValueHash(b *testing.B) { v := NewValue("a") + b.ResetTimer() for i := 0; i < b.N; i++ { _ = v.Hash() } @@ -63,6 +87,7 @@ func BenchmarkValueHash(b *testing.B) { func BenchmarkValueCompile(b *testing.B) { v := NewValue("a") + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = v.Compile(defaultTemplate) } @@ -83,6 +108,7 @@ func BenchmarkValues(b *testing.B) { func BenchmarkValuesHash(b *testing.B) { vs := NewValueGroup(NewValue("a"), NewValue("b")) + b.ResetTimer() for i := 0; i < b.N; i++ { _ = vs.Hash() } @@ -90,6 +116,7 @@ func BenchmarkValuesHash(b *testing.B) { func BenchmarkValuesCompile(b *testing.B) { vs := NewValueGroup(NewValue("a"), NewValue("b")) + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = vs.Compile(defaultTemplate) } diff --git a/internal/sqladapter/exql/where.go b/internal/sqladapter/exql/where.go index 3e77e005..37b51d6f 100644 --- a/internal/sqladapter/exql/where.go +++ b/internal/sqladapter/exql/where.go @@ -2,6 +2,8 @@ package exql import ( "strings" + + "github.com/upper/db/v4/internal/cache" ) // Or represents an SQL OR operator. @@ -13,7 +15,6 @@ type And Where // Where represents an SQL WHERE clause. type Where struct { Conditions []Fragment - hash hash } var _ = Fragment(&Where{}) @@ -38,8 +39,15 @@ func JoinWithAnd(conditions ...Fragment) *And { } // Hash returns a unique identifier for the struct. -func (w *Where) Hash() string { - return w.hash.Hash(w) +func (w *Where) Hash() uint64 { + if w == nil { + return cache.NewHash(FragmentType_Where, nil) + } + h := cache.InitHash(FragmentType_Where) + for i := range w.Conditions { + h = cache.AddToHash(h, w.Conditions[i]) + } + return h } // Appends adds the conditions to the ones that already exist. @@ -51,15 +59,19 @@ func (w *Where) Append(a *Where) *Where { } // Hash returns a unique identifier. -func (o *Or) Hash() string { - w := Where(*o) - return `Or(` + w.Hash() + `)` +func (o *Or) Hash() uint64 { + if o == nil { + return cache.NewHash(FragmentType_Or, nil) + } + return cache.NewHash(FragmentType_Or, (*Where)(o)) } // Hash returns a unique identifier. -func (a *And) Hash() string { - w := Where(*a) - return `And(` + w.Hash() + `)` +func (a *And) Hash() uint64 { + if a == nil { + return cache.NewHash(FragmentType_And, nil) + } + return cache.NewHash(FragmentType_And, (*Where)(a)) } // Compile transforms the Or into an equivalent SQL representation. diff --git a/internal/sqladapter/exql/where_test.go b/internal/sqladapter/exql/where_test.go index ec823bb6..16e301ed 100644 --- a/internal/sqladapter/exql/where_test.go +++ b/internal/sqladapter/exql/where_test.go @@ -2,6 +2,8 @@ package exql import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestWhereAnd(t *testing.T) { @@ -12,14 +14,8 @@ func TestWhereAnd(t *testing.T) { ) s, err := and.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `("id" > 8 AND "id" < 99 AND "name" = 'John')` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `("id" > 8 AND "id" < 99 AND "name" = 'John')`, s) } func TestWhereOr(t *testing.T) { @@ -29,14 +25,8 @@ func TestWhereOr(t *testing.T) { ) s, err := or.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } - - e := `("id" = 8 OR "id" = 99)` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.NoError(t, err) + assert.Equal(t, `("id" = 8 OR "id" = 99)`, s) } func TestWhereAndOr(t *testing.T) { @@ -51,39 +41,61 @@ func TestWhereAndOr(t *testing.T) { ) s, err := and.Compile(defaultTemplate) - if err != nil { - t.Fatal() - } + assert.NoError(t, err) - e := `("id" > 8 AND "id" < 99 AND "name" = 'John' AND ("last_name" = 'Smith' OR "last_name" = 'Reyes'))` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) - } + assert.Equal(t, `("id" > 8 AND "id" < 99 AND "name" = 'John' AND ("last_name" = 'Smith' OR "last_name" = 'Reyes'))`, s) } func TestWhereAndRawOrAnd(t *testing.T) { - where := WhereConditions( - JoinWithAnd( - &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, - &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, - ), - &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, - &Raw{Value: "city_id = 728"}, - JoinWithOr( - &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, - &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, - ), - JoinWithAnd( - &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, - &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, - ), - ) + { + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(2)}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "77"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(1)}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) - s := mustTrim(where.Compile(defaultTemplate)) + assert.Equal(t, + `WHERE (("id" > '2' AND "id" < 77 AND "id" < '1') AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, + mustTrim(where.Compile(defaultTemplate)), + ) + } - e := `WHERE (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))` - if s != e { - t.Fatalf("Got: %s, Expecting: %s", s, e) + { + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(1)}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) + + assert.Equal(t, + `WHERE (("id" > 8 AND "id" > 8 AND "id" < 99 AND "id" < '1') AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, + mustTrim(where.Compile(defaultTemplate)), + ) } } @@ -99,6 +111,7 @@ func BenchmarkCompileWhere(b *testing.B) { w := WhereConditions( &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, ) + b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = w.Compile(defaultTemplate) } diff --git a/internal/sqladapter/hash.go b/internal/sqladapter/hash.go new file mode 100644 index 00000000..4d754914 --- /dev/null +++ b/internal/sqladapter/hash.go @@ -0,0 +1,8 @@ +package sqladapter + +const ( + hashTypeNone = iota + 345065139389 + + hashTypeCollection + hashTypePrimaryKeys +) diff --git a/internal/sqladapter/session.go b/internal/sqladapter/session.go index b9a6921d..0978205a 100644 --- a/internal/sqladapter/session.go +++ b/internal/sqladapter/session.go @@ -1,6 +1,7 @@ package sqladapter import ( + "bytes" "context" "database/sql" "database/sql/driver" @@ -286,7 +287,8 @@ func (sess *sessionWithContext) Err(errIn error) (errOur error) { } func (sess *sessionWithContext) PrimaryKeys(tableName string) ([]string, error) { - h := cache.String(tableName) + h := cache.NewHashable(hashTypePrimaryKeys, tableName) + cachedPK, ok := sess.cachedPKs.ReadRaw(h) if ok { return cachedPK.([]string), nil @@ -652,7 +654,8 @@ func (sess *sessionWithContext) Collection(name string) db.Collection { sess.cacheMu.Lock() defer sess.cacheMu.Unlock() - h := cache.String(name) + h := cache.NewHashable(hashTypeCollection, name) + col, ok := sess.cachedCollections.ReadRaw(h) if !ok { col = newCollection(name, sess.adapter.NewCollection()) @@ -1001,29 +1004,37 @@ func (sess *sessionWithContext) WaitForConnection(connectFn func() error) error // ReplaceWithDollarSign turns a SQL statament with '?' placeholders into // dollar placeholders, like $1, $2, ..., $n -func ReplaceWithDollarSign(in string) string { - buf := []byte(in) - out := make([]byte, 0, len(buf)) - - i, j, k, t := 0, 1, 0, len(buf) - - for i < t { +func ReplaceWithDollarSign(buf []byte) []byte { + z := bytes.Count(buf, []byte{'?'}) + // the capacity is a quick estimation of the total memory required, this + // reduces reallocations + out := make([]byte, 0, len(buf)+z*3) + + var i, k = 0, 1 + for i < len(buf) { if buf[i] == '?' { - out = append(out, buf[k:i]...) - k = i + 1 - - if k < t && buf[k] == '?' { - i = k - } else { - out = append(out, []byte("$"+strconv.Itoa(j))...) - j++ + out = append(out, buf[:i]...) + buf = buf[i+1:] + i = 0 + + if len(buf) > 0 && buf[0] == '?' { + out = append(out, '?') + buf = buf[1:] + continue } + + out = append(out, '$') + out = append(out, []byte(strconv.Itoa(k))...) + k = k + 1 + continue } - i++ + i = i + 1 } - out = append(out, buf[k:i]...) - return string(out) + out = append(out, buf[:len(buf)]...) + buf = nil + + return out } func copySettings(from Session, into Session) { diff --git a/internal/sqladapter/sqladapter_test.go b/internal/sqladapter/sqladapter_test.go index 1e1ab5be..a95275af 100644 --- a/internal/sqladapter/sqladapter_test.go +++ b/internal/sqladapter/sqladapter_test.go @@ -40,6 +40,6 @@ func TestReplaceWithDollarSign(t *testing.T) { } for _, test := range tests { - assert.Equal(t, test.out, ReplaceWithDollarSign(test.in)) + assert.Equal(t, []byte(test.out), ReplaceWithDollarSign([]byte(test.in))) } } diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go index 17e7c6d7..0b18ebd1 100644 --- a/internal/sqladapter/statement.go +++ b/internal/sqladapter/statement.go @@ -12,7 +12,7 @@ var ( ) // Stmt represents a *sql.Stmt that is cached and provides the -// OnPurge method to allow it to clean after itself. +// OnEvict method to allow it to clean after itself. type Stmt struct { *sql.Stmt @@ -69,8 +69,8 @@ func (c *Stmt) checkClose() error { return nil } -// OnPurge marks the statement as ready to be cleaned up. -func (c *Stmt) OnPurge() { +// OnEvict marks the statement as ready to be cleaned up. +func (c *Stmt) OnEvict() { c.mu.Lock() defer c.mu.Unlock() diff --git a/internal/sqlbuilder/builder.go b/internal/sqlbuilder/builder.go index b912e6cc..b9cf5799 100644 --- a/internal/sqlbuilder/builder.go +++ b/internal/sqlbuilder/builder.go @@ -75,7 +75,7 @@ type fieldValue struct { } var ( - sqlPlaceholder = exql.RawValue(`?`) + sqlPlaceholder = &exql.Raw{Value: `?`} ) var ( @@ -358,7 +358,7 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err q, a := Preprocess(p.String(), p.Arguments()) - f[i] = exql.RawValue("(" + q + ")") + f[i] = &exql.Raw{Value: "(" + q + ")"} args = append(args, a...) case isCompilable: c, err := v.Compile() @@ -369,7 +369,7 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err if _, ok := v.(db.Selector); ok { q = "(" + q + ")" } - f[i] = exql.RawValue(q) + f[i] = &exql.Raw{Value: q} args = append(args, a...) case *adapter.FuncExpr: fnName, fnArgs := v.Name(), v.Arguments() @@ -379,22 +379,24 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } fnName, fnArgs = Preprocess(fnName, fnArgs) - f[i] = exql.RawValue(fnName) + f[i] = &exql.Raw{Value: fnName} args = append(args, fnArgs...) case *adapter.RawExpr: q, a := Preprocess(v.Raw(), v.Arguments()) - f[i] = exql.RawValue(q) + f[i] = &exql.Raw{Value: q} args = append(args, a...) case exql.Fragment: f[i] = v case string: f[i] = exql.ColumnWithName(v) - case int: - f[i] = exql.RawValue(fmt.Sprintf("%v", v)) - case interface{}: - f[i] = exql.ColumnWithName(fmt.Sprintf("%v", v)) + case fmt.Stringer: + f[i] = exql.ColumnWithName(v.String()) default: - return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument", v) + var err error + f[i], err = exql.NewRawValue(columns[i]) + if err != nil { + return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument: %w", v, err) + } } } return f, args, nil diff --git a/internal/sqlbuilder/builder_test.go b/internal/sqlbuilder/builder_test.go index 971998e8..ce123909 100644 --- a/internal/sqlbuilder/builder_test.go +++ b/internal/sqlbuilder/builder_test.go @@ -1428,9 +1428,10 @@ func BenchmarkSelect4(b *testing.B) { } func BenchmarkSelect5(b *testing.B) { - bt := WithTemplate(&testTemplate) + t := WithTemplate(&testTemplate) + b.ResetTimer() for n := 0; n < b.N; n++ { - _ = bt.SelectFrom("artist a"). + _ = t.SelectFrom("artist a"). LeftJoin("publication p1").On("p1.id = a.id"). RightJoin("publication p2").On("p2.id = a.id"). String() diff --git a/internal/sqlbuilder/convert.go b/internal/sqlbuilder/convert.go index 37901f61..21e161fb 100644 --- a/internal/sqlbuilder/convert.go +++ b/internal/sqlbuilder/convert.go @@ -1,43 +1,93 @@ package sqlbuilder import ( + "bytes" "database/sql/driver" "reflect" - "strings" "github.com/upper/db/v4/internal/adapter" "github.com/upper/db/v4/internal/sqladapter/exql" ) var ( - sqlDefault = exql.RawValue(`DEFAULT`) + sqlDefault = &exql.Raw{Value: "DEFAULT"} ) -func expandQuery(in string, args []interface{}, fn func(interface{}) (string, []interface{})) (string, []interface{}) { - argn := 0 - argx := make([]interface{}, 0, len(args)) - for i := 0; i < len(in); i++ { - if in[i] != '?' { +func expandQuery(in []byte, inArgs []interface{}) ([]byte, []interface{}) { + out := make([]byte, 0, len(in)) + outArgs := make([]interface{}, 0, len(inArgs)) + + i := 0 + for i < len(in) && len(inArgs) > 0 { + if in[i] == '?' { + out = append(out, in[:i]...) + in = in[i+1:] + i = 0 + + replace, replaceArgs := expandArgument(inArgs[0]) + inArgs = inArgs[1:] + + if len(replace) > 0 { + replace, replaceArgs = expandQuery(replace, replaceArgs) + out = append(out, replace...) + } else { + out = append(out, '?') + } + + outArgs = append(outArgs, replaceArgs...) continue } - if len(args) > argn { - k, values := fn(args[argn]) - k, values = expandQuery(k, values, fn) + i = i + 1 + } + + if len(out) < 1 { + return in, inArgs + } + + out = append(out, in[:len(in)]...) + in = nil + + outArgs = append(outArgs, inArgs[:len(inArgs)]...) + inArgs = nil - if k != "" { - in = in[:i] + k + in[i+1:] - i += len(k) - 1 + return out, outArgs +} + +func expandArgument(arg interface{}) ([]byte, []interface{}) { + values, isSlice := toInterfaceArguments(arg) + + if isSlice { + if len(values) == 0 { + return []byte("(NULL)"), nil + } + buf := bytes.Repeat([]byte(" ?,"), len(values)) + buf[0] = '(' + buf[len(buf)-1] = ')' + return buf, values + } + + if len(values) == 1 { + switch t := arg.(type) { + case *adapter.RawExpr: + return expandQuery([]byte(t.Raw()), t.Arguments()) + case hasPaginator: + p, err := t.Paginator() + if err == nil { + return append([]byte{'('}, append([]byte(p.String()), ')')...), p.Arguments() } - if len(values) > 0 { - argx = append(argx, values...) + panic(err.Error()) + case isCompilable: + s, err := t.Compile() + if err == nil { + return append([]byte{'('}, append([]byte(s), ')')...), t.Arguments() } - argn++ + panic(err.Error()) } + } else if len(values) == 0 { + return []byte("NULL"), nil } - if len(argx) < len(args) { - argx = append(argx, args[argn:]...) - } - return in, argx + + return nil, []interface{}{arg} } // toInterfaceArguments converts the given value into an array of interfaces. @@ -108,42 +158,9 @@ func toColumnsValuesAndArguments(columnNames []string, columnValues []interface{ return columns, values, arguments, nil } -func preprocessFn(arg interface{}) (string, []interface{}) { - values, isSlice := toInterfaceArguments(arg) - - if isSlice { - if len(values) == 0 { - return `(NULL)`, nil - } - return `(?` + strings.Repeat(`, ?`, len(values)-1) + `)`, values - } - - if len(values) == 1 { - switch t := arg.(type) { - case *adapter.RawExpr: - return Preprocess(t.Raw(), t.Arguments()) - case hasPaginator: - p, err := t.Paginator() - if err == nil { - return `(` + p.String() + `)`, p.Arguments() - } - panic(err.Error()) - case isCompilable: - c, err := t.Compile() - if err == nil { - return `(` + c + `)`, t.Arguments() - } - panic(err.Error()) - } - } else if len(values) == 0 { - return `NULL`, nil - } - - return "", []interface{}{arg} -} - // Preprocess expands arguments that needs to be expanded and compiles a query // into a single string. func Preprocess(in string, args []interface{}) (string, []interface{}) { - return expandQuery(in, args, preprocessFn) + b, args := expandQuery([]byte(in), args) + return string(b), args } diff --git a/internal/sqlbuilder/insert.go b/internal/sqlbuilder/insert.go index d4c95946..80e26d4c 100644 --- a/internal/sqlbuilder/insert.go +++ b/internal/sqlbuilder/insert.go @@ -60,7 +60,7 @@ func (iq *inserterQuery) processValues() ([]*exql.Values, []interface{}, error) l := len(enqueuedValue) placeholders := make([]exql.Fragment, l) for i := 0; i < l; i++ { - placeholders[i] = exql.RawValue(`?`) + placeholders[i] = sqlPlaceholder } values = append(values, exql.NewValueGroup(placeholders...)) } diff --git a/internal/sqlbuilder/select.go b/internal/sqlbuilder/select.go index 73a2fb89..93772405 100644 --- a/internal/sqlbuilder/select.go +++ b/internal/sqlbuilder/select.go @@ -257,7 +257,7 @@ func (sel *selector) OrderBy(columns ...interface{}) db.Selector { case *adapter.RawExpr: query, args := Preprocess(value.Raw(), value.Arguments()) sort = &exql.SortColumn{ - Column: exql.RawValue(query), + Column: &exql.Raw{Value: query}, } sq.orderByArgs = append(sq.orderByArgs, args...) case *adapter.FuncExpr: @@ -269,21 +269,21 @@ func (sel *selector) OrderBy(columns ...interface{}) db.Selector { } fnName, fnArgs = Preprocess(fnName, fnArgs) sort = &exql.SortColumn{ - Column: exql.RawValue(fnName), + Column: &exql.Raw{Value: fnName}, } sq.orderByArgs = append(sq.orderByArgs, fnArgs...) case string: if strings.HasPrefix(value, "-") { sort = &exql.SortColumn{ Column: exql.ColumnWithName(value[1:]), - Order: exql.Descendent, + Order: exql.Order_Descendent, } } else { chunks := strings.SplitN(value, " ", 2) - order := exql.Ascendent + order := exql.Order_Ascendent if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { - order = exql.Descendent + order = exql.Order_Descendent } sort = &exql.SortColumn{ @@ -418,7 +418,7 @@ func (sel *selector) As(alias string) db.Selector { if err != nil { return err } - sq.table.Columns[last] = exql.RawValue(raw.Value + " AS " + compiled) + sq.table.Columns[last] = &exql.Raw{Value: raw.Value + " AS " + compiled} } return nil }) diff --git a/internal/sqlbuilder/template.go b/internal/sqlbuilder/template.go index 0c9c7917..eca2382d 100644 --- a/internal/sqlbuilder/template.go +++ b/internal/sqlbuilder/template.go @@ -21,7 +21,7 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils { func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) { switch t := in.(type) { case *adapter.RawExpr: - return exql.RawValue(t.Raw()), t.Arguments() + return &exql.Raw{Value: t.Raw()}, t.Arguments() case *adapter.FuncExpr: fnName := t.Name() fnArgs := []interface{}{} @@ -35,7 +35,7 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] fnArgs = append(fnArgs, args...) } } - return exql.RawValue(fnName + `(` + strings.Join(fragments, `, `) + `)`), fnArgs + return &exql.Raw{Value: fnName + `(` + strings.Join(fragments, `, `) + `)`}, fnArgs default: return sqlPlaceholder, []interface{}{in} } @@ -51,7 +51,7 @@ func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql. if s, ok := t[0].(string); ok { if strings.ContainsAny(s, "?") || len(t) == 1 { s, args = Preprocess(s, t[1:]) - where.Conditions = []exql.Fragment{exql.RawValue(s)} + where.Conditions = []exql.Fragment{&exql.Raw{Value: s}} } else { var val interface{} key := s @@ -80,7 +80,7 @@ func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql. return case *adapter.RawExpr: r, v := Preprocess(t.Raw(), t.Arguments()) - where.Conditions = []exql.Fragment{exql.RawValue(r)} + where.Conditions = []exql.Fragment{&exql.Raw{Value: r}} args = append(args, v...) return case adapter.Constraints: @@ -172,10 +172,10 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal } } else { if rawValue, ok := t.Key().(*adapter.RawExpr); ok { - columnValue.Column = exql.RawValue(rawValue.Raw()) + columnValue.Column = &exql.Raw{Value: rawValue.Raw()} args = append(args, rawValue.Arguments()...) } else { - columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key())) + columnValue.Column = &exql.Raw{Value: fmt.Sprintf("%v", t.Key())} } } @@ -190,14 +190,14 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } fnName, fnArgs = Preprocess(fnName, fnArgs) - columnValue.Value = exql.RawValue(fnName) + columnValue.Value = &exql.Raw{Value: fnName} args = append(args, fnArgs...) case *db.RawExpr: q, a := Preprocess(value.Raw(), value.Arguments()) - columnValue.Value = exql.RawValue(q) + columnValue.Value = &exql.Raw{Value: q} args = append(args, a...) case driver.Valuer: - columnValue.Value = exql.RawValue("?") + columnValue.Value = sqlPlaceholder args = append(args, value) case *db.Comparison: wrapper := &operatorWrapper{ @@ -210,7 +210,7 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal q, a = Preprocess(q, a) columnValue = exql.ColumnValue{ - Column: exql.RawValue(q), + Column: &exql.Raw{Value: q}, } if a != nil { args = append(args, a...) @@ -229,7 +229,7 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal q, a = Preprocess(q, a) columnValue = exql.ColumnValue{ - Column: exql.RawValue(q), + Column: &exql.Raw{Value: q}, } if a != nil { args = append(args, a...) @@ -249,7 +249,7 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal case *adapter.RawExpr: columnValue := exql.ColumnValue{} p, q := Preprocess(t.Raw(), t.Arguments()) - columnValue.Column = exql.RawValue(p) + columnValue.Column = &exql.Raw{Value: p} cv.ColumnValues = append(cv.ColumnValues, &columnValue) args = append(args, q...) return cv, args @@ -294,7 +294,7 @@ func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnVa columnValue := exql.ColumnValue{ Column: exql.ColumnWithName(column), Operator: tu.AssignmentOperator, - Value: exql.RawValue(format), + Value: &exql.Raw{Value: format}, } ps := strings.Count(format, "?") @@ -313,7 +313,7 @@ func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnVa case *adapter.RawExpr: columnValue := exql.ColumnValue{} p, q := Preprocess(t.Raw(), t.Arguments()) - columnValue.Column = exql.RawValue(p) + columnValue.Column = &exql.Raw{Value: p} cv.ColumnValues = append(cv.ColumnValues, &columnValue) args = append(args, q...) return cv, args