diff --git a/storage/disk/disk.go b/storage/disk/disk.go index 6869e1ce67..3141bf5d99 100644 --- a/storage/disk/disk.go +++ b/storage/disk/disk.go @@ -351,17 +351,20 @@ func (db *Store) Truncate(ctx context.Context, txn storage.Transaction, params s // update symlink to point to the active db symlink := filepath.Join(path.Dir(newDB.Opts().Dir), symlinkKey) + // "active" -> "backupXXXX" is what we want, not + // "active" -> "DIR/backupXXX", since that won't work when using a relative directory + target := filepath.Base(newDB.Opts().Dir) if _, err := os.Lstat(symlink); err == nil { if err := os.Remove(symlink); err != nil { return wrapError(err) } - err = os.Symlink(newDB.Opts().Dir, symlink) + err = os.Symlink(target, symlink) if err != nil { return wrapError(err) } } else if errors.Is(err, os.ErrNotExist) { - err = os.Symlink(newDB.Opts().Dir, symlink) + err = os.Symlink(target, symlink) if err != nil { return wrapError(err) } diff --git a/storage/disk/disk_test.go b/storage/disk/disk_test.go index 17b06db713..ae2f9fa52a 100644 --- a/storage/disk/disk_test.go +++ b/storage/disk/disk_test.go @@ -167,63 +167,83 @@ func TestPolicies(t *testing.T) { }) } -func TestTruncate(t *testing.T) { +func TestTruncateAbsoluteStoragePath(t *testing.T) { test.WithTempFS(map[string]string{}, func(dir string) { - ctx := context.Background() - s, err := New(ctx, logging.NewNoOpLogger(), nil, Options{Dir: dir, Partitions: nil}) - if err != nil { - t.Fatal(err) - } - defer s.Close(ctx) + runTruncateTest(t, dir) + }) +} - txn := storage.NewTransactionOrDie(ctx, s, storage.WriteParams) +func TestTruncateRelativeStoragePath(t *testing.T) { + dir := "foobar" + err := os.Mkdir(dir, 0700) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + runTruncateTest(t, dir) +} - var archiveFiles = map[string]string{ - "/a/b/c/data.json": "[1,2,3]", - "/a/b/d/data.json": `e: true`, - "/data.json": `{"x": {"y": true}, "a": {"b": {"z": true}}}}`, - "/a/b/y/data.yaml": `foo: 1`, - "/policy.rego": "package foo\n p = 1", - "/roles/policy.rego": "package bar\n p = 1", - } +func runTruncateTest(t *testing.T, dir string) { + ctx := context.Background() + s, err := New(ctx, logging.NewNoOpLogger(), nil, Options{Dir: dir, Partitions: nil}) + if err != nil { + t.Fatal(err) + } + defer s.Close(ctx) - var files [][2]string - for name, content := range archiveFiles { - files = append(files, [2]string{name, content}) - } + txn := storage.NewTransactionOrDie(ctx, s, storage.WriteParams) - buf := archive.MustWriteTarGz(files) - b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() - if err != nil { - t.Fatal(err) - } + var archiveFiles = map[string]string{ + "/a/b/c/data.json": "[1,2,3]", + "/a/b/d/data.json": `e: true`, + "/data.json": `{"x": {"y": true}, "a": {"b": {"z": true}}}}`, + "/a/b/y/data.yaml": `foo: 1`, + "/policy.rego": "package foo\n p = 1", + "/roles/policy.rego": "package bar\n p = 1", + } - iterator := bundle.NewIterator(b.Raw) + var files [][2]string + for name, content := range archiveFiles { + files = append(files, [2]string{name, content}) + } - err = s.Truncate(ctx, txn, storage.WriteParams, iterator) - if err != nil { - t.Fatalf("Unexpected truncate error: %v", err) - } + buf := archive.MustWriteTarGz(files) + b, err := bundle.NewReader(buf).WithLazyLoadingMode(true).Read() + if err != nil { + t.Fatal(err) + } - // check if symlink is created - symlink := filepath.Join(dir, symlinkKey) - _, err = os.Lstat(symlink) - if err != nil { - t.Fatal(err) - } + iterator := bundle.NewIterator(b.Raw) - if err := s.Commit(ctx, txn); err != nil { - t.Fatalf("Unexpected commit error: %v", err) - } + err = s.Truncate(ctx, txn, storage.WriteParams, iterator) + if err != nil { + t.Fatalf("Unexpected truncate error: %v", err) + } - txn = storage.NewTransactionOrDie(ctx, s) + // check if symlink is created + symlink := filepath.Join(dir, symlinkKey) + _, err = os.Lstat(symlink) + if err != nil { + t.Fatal(err) + } + // check symlink target + _, err = filepath.EvalSymlinks(symlink) + if err != nil { + t.Fatalf("eval symlinks: %v", err) + } - actual, err := s.Read(ctx, txn, storage.MustParsePath("/")) - if err != nil { - t.Fatal(err) - } + if err := s.Commit(ctx, txn); err != nil { + t.Fatalf("Unexpected commit error: %v", err) + } - expected := ` + txn = storage.NewTransactionOrDie(ctx, s) + + actual, err := s.Read(ctx, txn, storage.MustParsePath("/")) + if err != nil { + t.Fatal(err) + } + + expected := ` { "a": { "b": { @@ -242,40 +262,48 @@ func TestTruncate(t *testing.T) { } } ` - jsn := util.MustUnmarshalJSON([]byte(expected)) + jsn := util.MustUnmarshalJSON([]byte(expected)) - if !reflect.DeepEqual(jsn, actual) { - t.Fatalf("Expected reader's read to be %v but got: %v", jsn, actual) - } + if !reflect.DeepEqual(jsn, actual) { + t.Fatalf("Expected reader's read to be %v but got: %v", jsn, actual) + } - s.Abort(ctx, txn) + s.Abort(ctx, txn) - txn = storage.NewTransactionOrDie(ctx, s) - ids, err := s.ListPolicies(ctx, txn) - if err != nil { - t.Fatal(err) - } + txn = storage.NewTransactionOrDie(ctx, s) + ids, err := s.ListPolicies(ctx, txn) + if err != nil { + t.Fatal(err) + } - expectedIds := map[string]struct{}{"/policy.rego": {}, "/roles/policy.rego": {}} + expectedIds := map[string]struct{}{"/policy.rego": {}, "/roles/policy.rego": {}} - for _, id := range ids { - if _, ok := expectedIds[id]; !ok { - t.Fatalf("Expected list policies to contain %v but got: %v", expectedIds, id) - } + for _, id := range ids { + if _, ok := expectedIds[id]; !ok { + t.Fatalf("Expected list policies to contain %v but got: %v", expectedIds, id) } + } - bs, err := s.GetPolicy(ctx, txn, "/policy.rego") - expectedBytes := []byte("package foo\n p = 1") - if err != nil || !reflect.DeepEqual(expectedBytes, bs) { - t.Fatalf("Expected get policy to return %v but got: %v (err: %v)", expectedBytes, bs, err) - } + bs, err := s.GetPolicy(ctx, txn, "/policy.rego") + expectedBytes := []byte("package foo\n p = 1") + if err != nil || !reflect.DeepEqual(expectedBytes, bs) { + t.Fatalf("Expected get policy to return %v but got: %v (err: %v)", expectedBytes, bs, err) + } - bs, err = s.GetPolicy(ctx, txn, "/roles/policy.rego") - expectedBytes = []byte("package bar\n p = 1") - if err != nil || !reflect.DeepEqual(expectedBytes, bs) { - t.Fatalf("Expected get policy to return %v but got: %v (err: %v)", expectedBytes, bs, err) - } - }) + bs, err = s.GetPolicy(ctx, txn, "/roles/policy.rego") + expectedBytes = []byte("package bar\n p = 1") + if err != nil || !reflect.DeepEqual(expectedBytes, bs) { + t.Fatalf("Expected get policy to return %v but got: %v (err: %v)", expectedBytes, bs, err) + } + + // Close and re-open store + if err := s.Close(ctx); err != nil { + t.Fatalf("store close: %v", err) + } + + if _, err := New(ctx, logging.NewNoOpLogger(), nil, Options{Dir: dir, Partitions: nil}); err != nil { + t.Fatalf("store re-open: %v", err) + } } func TestTruncateMultipleTxn(t *testing.T) {