diff --git a/cmd/registry/main.go b/cmd/registry/main.go index 88cccf1dcd..573122c539 100644 --- a/cmd/registry/main.go +++ b/cmd/registry/main.go @@ -14,6 +14,7 @@ import ( _ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" _ "github.com/distribution/distribution/v3/registry/storage/driver/middleware/cloudfront" _ "github.com/distribution/distribution/v3/registry/storage/driver/middleware/redirect" + _ "github.com/distribution/distribution/v3/registry/storage/driver/middleware/rewrite" _ "github.com/distribution/distribution/v3/registry/storage/driver/s3-aws" ) diff --git a/registry/storage/driver/middleware/rewrite/middleware.go b/registry/storage/driver/middleware/rewrite/middleware.go new file mode 100644 index 0000000000..73f60aaa4c --- /dev/null +++ b/registry/storage/driver/middleware/rewrite/middleware.go @@ -0,0 +1,82 @@ +package middleware + +import ( + "context" + "fmt" + "net/url" + "strings" + + storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" + storagemiddleware "github.com/distribution/distribution/v3/registry/storage/driver/middleware" +) + +func init() { + storagemiddleware.Register("rewrite", newRewriteStorageMiddleware) +} + +type rewriteStorageMiddleware struct { + storagedriver.StorageDriver + overrideScheme string + overrideHost string + trimPathPrefix string +} + +var _ storagedriver.StorageDriver = &rewriteStorageMiddleware{} + +func getStringOption(key string, options map[string]interface{}) (string, error) { + o, ok := options[key] + if !ok { + return "", nil + } + s, ok := o.(string) + if !ok { + return "", fmt.Errorf("%s must be a string", key) + } + return s, nil +} + +func newRewriteStorageMiddleware(ctx context.Context, sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { + var err error + + r := &rewriteStorageMiddleware{StorageDriver: sd} + + if r.overrideScheme, err = getStringOption("scheme", options); err != nil { + return nil, err + } + + if r.overrideHost, err = getStringOption("host", options); err != nil { + return nil, err + } + + if r.trimPathPrefix, err = getStringOption("trimpathprefix", options); err != nil { + return nil, err + } + + return r, nil +} + +func (r *rewriteStorageMiddleware) URLFor(ctx context.Context, urlPath string, options map[string]interface{}) (string, error) { + storagePath, err := r.StorageDriver.URLFor(ctx, urlPath, options) + if err != nil { + return "", err + } + + u, err := url.Parse(storagePath) + if err != nil { + return "", err + } + + if r.overrideScheme != "" { + u.Scheme = r.overrideScheme + } + + if r.overrideHost != "" { + u.Host = r.overrideHost + } + + if r.trimPathPrefix != "" { + u.Path = strings.TrimPrefix(u.Path, r.trimPathPrefix) + } + + return u.String(), nil +} diff --git a/registry/storage/driver/middleware/rewrite/middleware_test.go b/registry/storage/driver/middleware/rewrite/middleware_test.go new file mode 100644 index 0000000000..d7f1276fa4 --- /dev/null +++ b/registry/storage/driver/middleware/rewrite/middleware_test.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "context" + "testing" + + "github.com/distribution/distribution/v3/registry/storage/driver/base" + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { check.TestingT(t) } + +type MiddlewareSuite struct{} + +var _ = check.Suite(&MiddlewareSuite{}) + +type mockSD struct { + base.Base +} + +func (*mockSD) URLFor(ctx context.Context, urlPath string, options map[string]interface{}) (string, error) { + return "http://some.host/some/path/file", nil +} + +func (s *MiddlewareSuite) TestNoConfig(c *check.C) { + options := make(map[string]interface{}) + middleware, err := newRewriteStorageMiddleware(context.Background(), &mockSD{}, options) + c.Assert(err, check.Equals, nil) + + _, ok := middleware.(*rewriteStorageMiddleware) + c.Assert(ok, check.Equals, true) + + url, err := middleware.URLFor(context.Background(), "", nil) + c.Assert(err, check.Equals, nil) + + c.Assert(url, check.Equals, "http://some.host/some/path/file") +} + +func (s *MiddlewareSuite) TestWrongType(c *check.C) { + options := map[string]interface{}{ + "scheme": 1, + } + _, err := newRewriteStorageMiddleware(context.TODO(), nil, options) + c.Assert(err, check.ErrorMatches, "scheme must be a string") +} + +func (s *MiddlewareSuite) TestRewriteHostsScheme(c *check.C) { + options := map[string]interface{}{ + "scheme": "https", + "host": "example.com", + } + + middleware, err := newRewriteStorageMiddleware(context.TODO(), &mockSD{}, options) + c.Assert(err, check.Equals, nil) + + m, ok := middleware.(*rewriteStorageMiddleware) + c.Assert(ok, check.Equals, true) + c.Assert(m.overrideScheme, check.Equals, "https") + c.Assert(m.overrideHost, check.Equals, "example.com") + + url, err := middleware.URLFor(context.TODO(), "", nil) + c.Assert(err, check.Equals, nil) + c.Assert(url, check.Equals, "https://example.com/some/path/file") +} + +func (s *MiddlewareSuite) TestTrimPrefix(c *check.C) { + options := map[string]interface{}{ + "trimpathprefix": "/some/path", + } + + middleware, err := newRewriteStorageMiddleware(context.TODO(), &mockSD{}, options) + c.Assert(err, check.Equals, nil) + + m, ok := middleware.(*rewriteStorageMiddleware) + c.Assert(ok, check.Equals, true) + c.Assert(m.trimPathPrefix, check.Equals, "/some/path") + + url, err := middleware.URLFor(context.TODO(), "", nil) + c.Assert(err, check.Equals, nil) + c.Assert(url, check.Equals, "http://some.host/file") +}