diff --git a/storage/inmem/inmem.go b/storage/inmem/inmem.go index 7d91db72ca..818bd17a96 100644 --- a/storage/inmem/inmem.go +++ b/storage/inmem/inmem.go @@ -19,9 +19,13 @@ import ( "context" "fmt" "io" + "path/filepath" + "strings" "sync" "sync/atomic" + "github.com/open-policy-agent/opa/internal/merge" + "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/util" ) @@ -96,6 +100,7 @@ func (db *store) NewTransaction(_ context.Context, params ...storage.Transaction func (db *store) Truncate(ctx context.Context, txn storage.Transaction, _ storage.TransactionParams, it storage.Iterator) error { var update *storage.Update var err error + var mergedData map[string]interface{} underlying, err := db.underlying(txn) if err != nil { @@ -114,52 +119,29 @@ func (db *store) Truncate(ctx context.Context, txn storage.Transaction, _ storag return err } } else { - if len(update.Path) > 0 { - var obj interface{} - err = util.Unmarshal(update.Value, &obj) - if err != nil { - return err - } - - _, err := underlying.Read(update.Path[:len(update.Path)-1]) - if err != nil { - if !storage.IsNotFound(err) { - return err - } + var value interface{} + err = util.Unmarshal(update.Value, &value) + if err != nil { + return err + } - if err := storage.MakeDir(ctx, db, txn, update.Path[:len(update.Path)-1]); err != nil { - return err - } - } + var key []string + dirpath := strings.TrimLeft(update.Path.String(), "/") + if len(dirpath) > 0 { + key = strings.Split(dirpath, "/") + } - err = underlying.Write(storage.AddOp, update.Path, obj) + if value != nil { + obj, err := mktree(key, value) if err != nil { return err } - } else { - // write operation at root path - - var val map[string]interface{} - err := util.Unmarshal(update.Value, &val) - if err != nil { - return invalidPatchError(rootMustBeObjectMsg) - } - - for k := range val { - newPath, ok := storage.ParsePathEscaped("/" + k) - if !ok { - return fmt.Errorf("storage path invalid: %v", newPath) - } - - if err := storage.MakeDir(ctx, db, txn, newPath[:len(newPath)-1]); err != nil { - return err - } - err = underlying.Write(storage.AddOp, newPath, val[k]) - if err != nil { - return err - } + merged, ok := merge.InterfaceMaps(mergedData, obj) + if !ok { + return fmt.Errorf("failed to insert data file from path %s", filepath.Join(key...)) } + mergedData = merged } } } @@ -168,6 +150,24 @@ func (db *store) Truncate(ctx context.Context, txn storage.Transaction, _ storag return err } + // write merged data to store + for k := range mergedData { + newPath, ok := storage.ParsePathEscaped("/" + k) + if !ok { + return fmt.Errorf("storage path invalid: %v", newPath) + } + + if len(newPath) > 0 { + if err := storage.MakeDir(ctx, db, txn, newPath[:len(newPath)-1]); err != nil { + return err + } + } + + if err := underlying.Write(storage.AddOp, newPath, mergedData[k]); err != nil { + return err + } + } + return nil } @@ -327,3 +327,24 @@ func invalidPatchError(f string, a ...interface{}) *storage.Error { Message: fmt.Sprintf(f, a...), } } + +func mktree(path []string, value interface{}) (map[string]interface{}, error) { + if len(path) == 0 { + // For 0 length path the value is the full tree. + obj, ok := value.(map[string]interface{}) + if !ok { + return nil, invalidPatchError(rootMustBeObjectMsg) + } + return obj, nil + } + + dir := map[string]interface{}{} + for i := len(path) - 1; i > 0; i-- { + dir[path[i]] = value + value = dir + dir = map[string]interface{}{} + } + dir[path[0]] = value + + return dir, nil +} diff --git a/storage/inmem/inmem_test.go b/storage/inmem/inmem_test.go index cc3ed5d5c6..dd675b6c7f 100644 --- a/storage/inmem/inmem_test.go +++ b/storage/inmem/inmem_test.go @@ -476,6 +476,40 @@ func TestTruncate(t *testing.T) { } } +func TestTruncateDataMergeError(t *testing.T) { + ctx := context.Background() + store := NewFromObject(map[string]interface{}{}) + txn := storage.NewTransactionOrDie(ctx, store, storage.WriteParams) + + var archiveFiles = map[string]string{ + "/a/b/data.json": `{"c": "foo"}`, + "/data.json": `{"a": {"b": {"c": "bar"}}}`, + } + + var files [][2]string + for name, content := range archiveFiles { + files = append(files, [2]string{name, content}) + } + + buf := archive.MustWriteTarGz(files) + b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() + if err != nil { + t.Fatal(err) + } + + iterator := bundle.NewIterator(b.Raw) + + err = store.Truncate(ctx, txn, storage.WriteParams, iterator) + if err == nil { + t.Fatal("Expected truncate error but got nil") + } + + expected := "failed to insert data file from path a/b" + if err.Error() != expected { + t.Fatalf("Expected error %v but got %v", expected, err.Error()) + } +} + func TestTruncateBadRootWrite(t *testing.T) { ctx := context.Background() store := NewFromObject(map[string]interface{}{})