From f688304c9edb089a1b73d28ade59a860f56196c7 Mon Sep 17 00:00:00 2001 From: Tomasz Janiszewski Date: Thu, 13 Jul 2023 19:03:56 +0200 Subject: [PATCH] refactor(pg): store empty map as {} not null --- .../cluster/store/cluster/postgres/store.go | 2 +- central/deployment/store/postgres/store.go | 12 +++++----- central/namespace/store/postgres/store.go | 8 +++---- .../k8srole/internal/store/postgres/store.go | 8 +++---- .../internal/store/postgres/store.go | 8 +++---- .../internal/store/postgres/store.go | 8 +++---- pkg/postgres/pgutils/utils.go | 8 +++++++ .../multitest/postgres/store.go | 4 ++-- .../pg-table-bindings/store.go.tpl | 6 ++++- .../pg-table-bindings/test/postgres/store.go | 4 ++-- .../test/postgres/store_null_map_test.go | 23 ++++++++++++++++++- .../testuuidkey/postgres/store.go | 4 ++-- 12 files changed, 64 insertions(+), 31 deletions(-) diff --git a/central/cluster/store/cluster/postgres/store.go b/central/cluster/store/cluster/postgres/store.go index 4c5fb1e4107b3..2b62193bb3b98 100644 --- a/central/cluster/store/cluster/postgres/store.go +++ b/central/cluster/store/cluster/postgres/store.go @@ -125,7 +125,7 @@ func insertIntoClusters(_ context.Context, batch *pgx.Batch, obj *storage.Cluste // parent primary keys start pgutils.NilOrUUID(obj.GetId()), obj.GetName(), - obj.GetLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), serialized, } diff --git a/central/deployment/store/postgres/store.go b/central/deployment/store/postgres/store.go index da0a2ea910650..d666db56b3637 100644 --- a/central/deployment/store/postgres/store.go +++ b/central/deployment/store/postgres/store.go @@ -129,12 +129,12 @@ func insertIntoDeployments(ctx context.Context, batch *pgx.Batch, obj *storage.D obj.GetNamespace(), pgutils.NilOrUUID(obj.GetNamespaceId()), obj.GetOrchestratorComponent(), - obj.GetLabels(), - obj.GetPodLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetPodLabels()), pgutils.NilOrTime(obj.GetCreated()), pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetAnnotations()), obj.GetPriority(), obj.GetImagePullSecrets(), obj.GetServiceAccount(), @@ -369,12 +369,12 @@ func copyFromDeployments(ctx context.Context, s pgSearch.Deleter, tx *postgres.T obj.GetNamespace(), pgutils.NilOrUUID(obj.GetNamespaceId()), obj.GetOrchestratorComponent(), - obj.GetLabels(), - obj.GetPodLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetPodLabels()), pgutils.NilOrTime(obj.GetCreated()), pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetAnnotations()), obj.GetPriority(), obj.GetImagePullSecrets(), obj.GetServiceAccount(), diff --git a/central/namespace/store/postgres/store.go b/central/namespace/store/postgres/store.go index 9d0cbefc9cd5e..ffecd7c79f206 100644 --- a/central/namespace/store/postgres/store.go +++ b/central/namespace/store/postgres/store.go @@ -127,8 +127,8 @@ func insertIntoNamespaces(_ context.Context, batch *pgx.Batch, obj *storage.Name obj.GetName(), pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), serialized, } @@ -171,8 +171,8 @@ func copyFromNamespaces(ctx context.Context, s pgSearch.Deleter, tx *postgres.Tx obj.GetName(), pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), serialized, }) diff --git a/central/rbac/k8srole/internal/store/postgres/store.go b/central/rbac/k8srole/internal/store/postgres/store.go index 4507132f3d709..538a426b8ed7e 100644 --- a/central/rbac/k8srole/internal/store/postgres/store.go +++ b/central/rbac/k8srole/internal/store/postgres/store.go @@ -129,8 +129,8 @@ func insertIntoK8sRoles(_ context.Context, batch *pgx.Batch, obj *storage.K8SRol pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), obj.GetClusterRole(), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), serialized, } @@ -177,8 +177,8 @@ func copyFromK8sRoles(ctx context.Context, s pgSearch.Deleter, tx *postgres.Tx, pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), obj.GetClusterRole(), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), serialized, }) diff --git a/central/rbac/k8srolebinding/internal/store/postgres/store.go b/central/rbac/k8srolebinding/internal/store/postgres/store.go index fc623bdc63b82..7bad298af4a53 100644 --- a/central/rbac/k8srolebinding/internal/store/postgres/store.go +++ b/central/rbac/k8srolebinding/internal/store/postgres/store.go @@ -129,8 +129,8 @@ func insertIntoRoleBindings(ctx context.Context, batch *pgx.Batch, obj *storage. pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), obj.GetClusterRole(), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), pgutils.NilOrUUID(obj.GetRoleId()), serialized, } @@ -205,8 +205,8 @@ func copyFromRoleBindings(ctx context.Context, s pgSearch.Deleter, tx *postgres. pgutils.NilOrUUID(obj.GetClusterId()), obj.GetClusterName(), obj.GetClusterRole(), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), pgutils.NilOrUUID(obj.GetRoleId()), serialized, }) diff --git a/central/serviceaccount/internal/store/postgres/store.go b/central/serviceaccount/internal/store/postgres/store.go index e38592b8d2033..c6af603a87869 100644 --- a/central/serviceaccount/internal/store/postgres/store.go +++ b/central/serviceaccount/internal/store/postgres/store.go @@ -128,8 +128,8 @@ func insertIntoServiceAccounts(_ context.Context, batch *pgx.Batch, obj *storage obj.GetNamespace(), obj.GetClusterName(), pgutils.NilOrUUID(obj.GetClusterId()), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), serialized, } @@ -174,8 +174,8 @@ func copyFromServiceAccounts(ctx context.Context, s pgSearch.Deleter, tx *postgr obj.GetNamespace(), obj.GetClusterName(), pgutils.NilOrUUID(obj.GetClusterId()), - obj.GetLabels(), - obj.GetAnnotations(), + pgutils.EmptyOrMap(obj.GetLabels()), + pgutils.EmptyOrMap(obj.GetAnnotations()), serialized, }) diff --git a/pkg/postgres/pgutils/utils.go b/pkg/postgres/pgutils/utils.go index 0d5e7a8556a4f..231e1d7f98f34 100644 --- a/pkg/postgres/pgutils/utils.go +++ b/pkg/postgres/pgutils/utils.go @@ -76,6 +76,14 @@ func NilOrUUID(value string) *uuid.UUID { return &id } +// EmptyOrMap allows for map to be stored explicit as an empty object ({}) rather than null. +func EmptyOrMap[K comparable, V any, M map[K]V](m M) interface{} { + if m == nil { + return make(M) + } + return m +} + // CreateTableFromModel executes input create statement using the input connection. func CreateTableFromModel(ctx context.Context, db *gorm.DB, createStmt *postgres.CreateStmts) { // Partitioned tables are not supported by Gorm migration or models diff --git a/tools/generate-helpers/pg-table-bindings/multitest/postgres/store.go b/tools/generate-helpers/pg-table-bindings/multitest/postgres/store.go index 3f6e68976d424..29fc1172d7841 100644 --- a/tools/generate-helpers/pg-table-bindings/multitest/postgres/store.go +++ b/tools/generate-helpers/pg-table-bindings/multitest/postgres/store.go @@ -110,7 +110,7 @@ func insertIntoTestStructs(ctx context.Context, batch *pgx.Batch, obj *storage.T obj.GetUint64(), obj.GetInt64(), obj.GetFloat(), - obj.GetLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), pgutils.NilOrTime(obj.GetTimestamp()), obj.GetEnum(), obj.GetEnums(), @@ -200,7 +200,7 @@ func copyFromTestStructs(ctx context.Context, s pgSearch.Deleter, tx *postgres.T obj.GetUint64(), obj.GetInt64(), obj.GetFloat(), - obj.GetLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), pgutils.NilOrTime(obj.GetTimestamp()), obj.GetEnum(), obj.GetEnums(), diff --git a/tools/generate-helpers/pg-table-bindings/store.go.tpl b/tools/generate-helpers/pg-table-bindings/store.go.tpl index 435d78cea0c14..7ab2670df3168 100644 --- a/tools/generate-helpers/pg-table-bindings/store.go.tpl +++ b/tools/generate-helpers/pg-table-bindings/store.go.tpl @@ -205,6 +205,8 @@ func {{ template "insertFunctionName" $schema }}({{ if eq (len $schema.Children) pgutils.NilOrTime({{$field.Getter "obj"}}), {{- else if eq $field.SQLType "uuid" }} pgutils.NilOrUUID({{$field.Getter "obj"}}), + {{- else if eq $field.DataType "map" }} + pgutils.EmptyOrMap({{$field.Getter "obj"}}), {{- else }} {{$field.Getter "obj"}},{{end}} {{- end}} @@ -275,7 +277,9 @@ func {{ template "copyFunctionName" $schema }}(ctx context.Context, s pgSearch.D pgutils.NilOrTime({{$field.Getter "obj"}}), {{- else if eq $field.SQLType "uuid" }} pgutils.NilOrUUID({{$field.Getter "obj"}}), - {{- else}} + {{- else if eq $field.DataType "map" }} + pgutils.EmptyOrMap({{$field.Getter "obj"}}), + {{- else }} {{$field.Getter "obj"}},{{end}} {{- end}} }) diff --git a/tools/generate-helpers/pg-table-bindings/test/postgres/store.go b/tools/generate-helpers/pg-table-bindings/test/postgres/store.go index 4ec38051d8efa..1be2b5ca100e6 100644 --- a/tools/generate-helpers/pg-table-bindings/test/postgres/store.go +++ b/tools/generate-helpers/pg-table-bindings/test/postgres/store.go @@ -111,7 +111,7 @@ func insertIntoTestSingleKeyStructs(_ context.Context, batch *pgx.Batch, obj *st obj.GetUint64(), obj.GetInt64(), obj.GetFloat(), - obj.GetLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), pgutils.NilOrTime(obj.GetTimestamp()), obj.GetEnum(), obj.GetEnums(), @@ -165,7 +165,7 @@ func copyFromTestSingleKeyStructs(ctx context.Context, s pgSearch.Deleter, tx *p obj.GetUint64(), obj.GetInt64(), obj.GetFloat(), - obj.GetLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), pgutils.NilOrTime(obj.GetTimestamp()), obj.GetEnum(), obj.GetEnums(), diff --git a/tools/generate-helpers/pg-table-bindings/test/postgres/store_null_map_test.go b/tools/generate-helpers/pg-table-bindings/test/postgres/store_null_map_test.go index f24b8c949a4ed..1981e0e5846ee 100644 --- a/tools/generate-helpers/pg-table-bindings/test/postgres/store_null_map_test.go +++ b/tools/generate-helpers/pg-table-bindings/test/postgres/store_null_map_test.go @@ -4,6 +4,7 @@ package postgres import ( "context" + "fmt" "github.com/stackrox/rox/generated/storage" "github.com/stackrox/rox/pkg/sac" @@ -19,5 +20,25 @@ func (s *TestSingleKeyStructsStoreSuite) TestStoreNilMap() { row := s.testDB.QueryRow(ctx, "select labels from test_single_key_structs") err := row.Scan(&val) s.NoError(err) - s.Equal("null", val) + s.Equal("{}", val) +} + +func (s *TestSingleKeyStructsStoreSuite) TestStoreNilMapUpsertMany() { + ctx := sac.WithAllAccess(context.Background()) + + const batchSize = 10000 + testSingleKeyStructs := make([]*storage.TestSingleKeyStruct, batchSize) + for i := range testSingleKeyStructs { + testSingleKeyStructs[i] = &storage.TestSingleKeyStruct{ + Key: fmt.Sprintf("%d", i), + Name: fmt.Sprintf("%d", i), + } + } + s.NoError(s.store.UpsertMany(ctx, testSingleKeyStructs)) + + var val string + row := s.testDB.QueryRow(ctx, "select labels from test_single_key_structs limit 1") + err := row.Scan(&val) + s.NoError(err) + s.Equal("{}", val) } diff --git a/tools/generate-helpers/pg-table-bindings/testuuidkey/postgres/store.go b/tools/generate-helpers/pg-table-bindings/testuuidkey/postgres/store.go index f135e4abec7d8..f071ebf12c5e7 100644 --- a/tools/generate-helpers/pg-table-bindings/testuuidkey/postgres/store.go +++ b/tools/generate-helpers/pg-table-bindings/testuuidkey/postgres/store.go @@ -111,7 +111,7 @@ func insertIntoTestSingleUUIDKeyStructs(_ context.Context, batch *pgx.Batch, obj obj.GetUint64(), obj.GetInt64(), obj.GetFloat(), - obj.GetLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), pgutils.NilOrTime(obj.GetTimestamp()), obj.GetEnum(), obj.GetEnums(), @@ -165,7 +165,7 @@ func copyFromTestSingleUUIDKeyStructs(ctx context.Context, s pgSearch.Deleter, t obj.GetUint64(), obj.GetInt64(), obj.GetFloat(), - obj.GetLabels(), + pgutils.EmptyOrMap(obj.GetLabels()), pgutils.NilOrTime(obj.GetTimestamp()), obj.GetEnum(), obj.GetEnums(),