diff --git a/go.mod b/go.mod index a485429593..779e384fd5 100644 --- a/go.mod +++ b/go.mod @@ -50,7 +50,6 @@ require ( github.com/aws/aws-sdk-go-v2 v1.16.16 github.com/aws/aws-sdk-go-v2/config v1.17.8 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.34 github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.15.20 github.com/aws/aws-sdk-go-v2/service/ec2 v1.32.0 github.com/aws/aws-sdk-go-v2/service/kms v1.18.12 diff --git a/go.sum b/go.sum index 2e3a4c5735..9735ec7ed7 100644 --- a/go.sum +++ b/go.sum @@ -269,8 +269,6 @@ github.com/aws/aws-sdk-go-v2/credentials v1.12.21 h1:4tjlyCD0hRGNQivh5dN8hbP30qQ github.com/aws/aws-sdk-go-v2/credentials v1.12.21/go.mod h1:O+4XyAt4e+oBAoIwNUYkRg3CVMscaIJdmZBOcPgJ8D8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 h1:r08j4sbZu/RVi+BNxkBJwPMUYY3P8mgSDuKkZ/ZN1lE= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17/go.mod h1:yIkQcCDYNsZfXpd5UX2Cy+sWA1jPgIhGTw9cOBzfVnQ= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.34 h1:1PNtaCM+2ruo1dfYL2RweUdtbuPvinjAejjNcPa/RQY= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.34/go.mod h1:+Six+CXNHYllXam32j+YW8ixk82+am345ei89kEz8p4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.6/go.mod h1:SSPEdf9spsFgJyhjrXvawfpyzrXHBCUe+2eQ1CjC1Ak= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 h1:s4g/wnzMf+qepSNgTvaQQHNxyMLKSawNhKCPNy++2xY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23/go.mod h1:2DFxAQ9pfIRy0imBCJv+vZ2X6RKxves6fbnEuSry6b4= diff --git a/hack/go.mod b/hack/go.mod index cf85169d0e..4111f95c4c 100644 --- a/hack/go.mod +++ b/hack/go.mod @@ -79,7 +79,6 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.17.8 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.12.21 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 // indirect - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.34 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.17 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.3.24 // indirect @@ -155,7 +154,6 @@ require ( github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kevinburke/ssh_config v0.0.0-20201106050909-4977a11b4351 // indirect diff --git a/hack/go.sum b/hack/go.sum index a9ff5defbd..921e5c2625 100644 --- a/hack/go.sum +++ b/hack/go.sum @@ -230,8 +230,6 @@ github.com/aws/aws-sdk-go-v2/credentials v1.12.21 h1:4tjlyCD0hRGNQivh5dN8hbP30qQ github.com/aws/aws-sdk-go-v2/credentials v1.12.21/go.mod h1:O+4XyAt4e+oBAoIwNUYkRg3CVMscaIJdmZBOcPgJ8D8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 h1:r08j4sbZu/RVi+BNxkBJwPMUYY3P8mgSDuKkZ/ZN1lE= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17/go.mod h1:yIkQcCDYNsZfXpd5UX2Cy+sWA1jPgIhGTw9cOBzfVnQ= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.34 h1:1PNtaCM+2ruo1dfYL2RweUdtbuPvinjAejjNcPa/RQY= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.34/go.mod h1:+Six+CXNHYllXam32j+YW8ixk82+am345ei89kEz8p4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.6/go.mod h1:SSPEdf9spsFgJyhjrXvawfpyzrXHBCUe+2eQ1CjC1Ak= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 h1:s4g/wnzMf+qepSNgTvaQQHNxyMLKSawNhKCPNy++2xY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23/go.mod h1:2DFxAQ9pfIRy0imBCJv+vZ2X6RKxves6fbnEuSry6b4= @@ -882,9 +880,7 @@ github.com/jingyugao/rowserrcheck v1.1.1/go.mod h1:4yvlZSDb3IyDTUZJUmpZfm2Hwok+D github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af/go.mod h1:HEWGJkRDzjJY2sqdDwxccsGicWEf9BQOZsq2tV+xzM0= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jmhodges/clock v0.0.0-20160418191101-880ee4c33548 h1:dYTbLf4m0a5u0KLmPfB6mgxbcV7588bOCx79hxa5Sr4= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= diff --git a/kms/internal/storage/azurestorage.go b/kms/internal/storage/azurestorage.go index 2cbd9fc58c..37574506cb 100644 --- a/kms/internal/storage/azurestorage.go +++ b/kms/internal/storage/azurestorage.go @@ -9,36 +9,25 @@ package storage import ( "bytes" "context" - "errors" "fmt" "io" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" - "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" "github.com/edgelesssys/constellation/v2/kms/internal/config" ) -type azureContainerAPI interface { - Create(ctx context.Context, options *azblob.ContainerCreateOptions) (azblob.ContainerCreateResponse, error) - NewBlockBlobClient(blobName string) (azureBlobAPI, error) -} - type azureBlobAPI interface { - DownloadToWriterAt(ctx context.Context, offset int64, count int64, writer io.WriterAt, options azblob.DownloadOptions) error - Upload(ctx context.Context, body io.ReadSeekCloser, options *azblob.BlockBlobUploadOptions) (azblob.BlockBlobUploadResponse, error) -} - -type wrappedAzureClient struct { - azblob.ContainerClient -} - -func (c wrappedAzureClient) NewBlockBlobClient(blobName string) (azureBlobAPI, error) { - return c.ContainerClient.NewBlockBlobClient(blobName) + CreateContainer(context.Context, string, *container.CreateOptions) (azblob.CreateContainerResponse, error) + DownloadStream(context.Context, string, string, *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error) + UploadStream(context.Context, string, string, io.Reader, *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) } // AzureStorage is an implementation of the Storage interface, storing keys in the Azure Blob Store. type AzureStorage struct { - newClient func(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error) + client azureBlobAPI connectionString string containerName string opts *AzureOpts @@ -46,8 +35,9 @@ type AzureStorage struct { // AzureOpts are additional options to be used when interacting with the Azure API. type AzureOpts struct { - upload *azblob.BlockBlobUploadOptions - service *azblob.ClientOptions + service *azblob.ClientOptions + download *azblob.DownloadStreamOptions + upload *azblob.UploadStreamOptions } // NewAzureStorage initializes a storage client using Azure's Blob Storage: https://azure.microsoft.com/en-us/services/storage/blobs/ @@ -60,8 +50,13 @@ func NewAzureStorage(ctx context.Context, connectionString, containerName string opts = &AzureOpts{} } + client, err := azblob.NewClientFromConnectionString(connectionString, opts.service) + if err != nil { + return nil, fmt.Errorf("creating storage client from connection string: %w", err) + } + s := &AzureStorage{ - newClient: azureContainerClientFactory, + client: client, connectionString: connectionString, containerName: containerName, opts: opts, @@ -77,90 +72,34 @@ func NewAzureStorage(ctx context.Context, connectionString, containerName string // Get returns a DEK from from Azure Blob Storage by key ID. func (s *AzureStorage) Get(ctx context.Context, keyID string) ([]byte, error) { - client, err := s.newBlobClient(ctx, keyID) + res, err := s.client.DownloadStream(ctx, s.containerName, keyID, s.opts.download) if err != nil { - return nil, err - } - - // the Azure SDK requires an io.WriterAt, the AWS SDK provides a utility function to create one from a byte slice - keyBuffer := manager.NewWriteAtBuffer([]byte{}) - - opts := azblob.DownloadOptions{ - RetryReaderOptionsPerBlock: azblob.RetryReaderOptions{ - MaxRetryRequests: 5, - TreatEarlyCloseAsError: true, - }, - } - - if err := client.DownloadToWriterAt(ctx, 0, 0, keyBuffer, opts); err != nil { - var storeErr *azblob.StorageError - if errors.As(err, &storeErr) && (storeErr.ErrorCode == azblob.StorageErrorCodeBlobNotFound) { + if bloberror.HasCode(err, bloberror.BlobNotFound) { return nil, ErrDEKUnset } return nil, fmt.Errorf("downloading DEK from storage: %w", err) } - - return keyBuffer.Bytes(), nil + defer res.Body.Close() + return io.ReadAll(res.Body) } // Put saves a DEK to Azure Blob Storage by key ID. func (s *AzureStorage) Put(ctx context.Context, keyID string, encDEK []byte) error { - client, err := s.newBlobClient(ctx, keyID) - if err != nil { - return err - } - - if _, err := client.Upload(ctx, readSeekNopCloser{bytes.NewReader(encDEK)}, s.opts.upload); err != nil { + if _, err := s.client.UploadStream(ctx, s.containerName, keyID, bytes.NewReader(encDEK), s.opts.upload); err != nil { return fmt.Errorf("uploading DEK to storage: %w", err) } + return nil } // createContainerOrContinue creates a new storage container if necessary, or continues if it already exists. func (s *AzureStorage) createContainerOrContinue(ctx context.Context) error { - client, err := s.newClient(ctx, s.connectionString, s.containerName, s.opts.service) - if err != nil { - return err - } - - var storeErr *azblob.StorageError - _, err = client.Create(ctx, &azblob.ContainerCreateOptions{ + _, err := s.client.CreateContainer(ctx, s.containerName, &azblob.CreateContainerOptions{ Metadata: config.StorageTags, }) - if (err == nil) || (errors.As(err, &storeErr) && (storeErr.ErrorCode == azblob.StorageErrorCodeContainerAlreadyExists)) { + if (err == nil) || bloberror.HasCode(err, bloberror.ContainerAlreadyExists) { return nil } return fmt.Errorf("creating storage container: %w", err) } - -// newBlobClient is a convenience function to create BlockBlobClients. -func (s *AzureStorage) newBlobClient(ctx context.Context, blobName string) (azureBlobAPI, error) { - c, err := s.newClient(ctx, s.connectionString, s.containerName, s.opts.service) - if err != nil { - return nil, err - } - return c.NewBlockBlobClient(blobName) -} - -func azureContainerClientFactory(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error) { - service, err := azblob.NewServiceClientFromConnectionString(connectionString, opts) - if err != nil { - return nil, fmt.Errorf("creating storage client from connection string: %w", err) - } - - containerClient, err := service.NewContainerClient(containerName) - if err != nil { - return nil, fmt.Errorf("creating storage container client: %w", err) - } - return &wrappedAzureClient{*containerClient}, err -} - -// readSeekNopCloser is a wrapper for io.ReadSeeker implementing the Close method. This is required by the Azure SDK. -type readSeekNopCloser struct { - io.ReadSeeker -} - -func (n readSeekNopCloser) Close() error { - return nil -} diff --git a/kms/internal/storage/azurestorage_test.go b/kms/internal/storage/azurestorage_test.go index 5349134c3f..5660001a0a 100644 --- a/kms/internal/storage/azurestorage_test.go +++ b/kms/internal/storage/azurestorage_test.go @@ -7,91 +7,35 @@ SPDX-License-Identifier: AGPL-3.0-only package storage import ( + "bytes" "context" "errors" "io" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" "github.com/stretchr/testify/assert" ) -type stubAzureContainerAPI struct { - newClientErr error - createErr error - createCalled *bool - blockBlobAPI stubAzureBlockBlobAPI -} - -func newStubClientFactory(stub stubAzureContainerAPI) func(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error) { - return func(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error) { - return stub, stub.newClientErr - } -} - -func (s stubAzureContainerAPI) Create(ctx context.Context, options *azblob.ContainerCreateOptions) (azblob.ContainerCreateResponse, error) { - *s.createCalled = true - return azblob.ContainerCreateResponse{}, s.createErr -} - -func (s stubAzureContainerAPI) NewBlockBlobClient(blobName string) (azureBlobAPI, error) { - return s.blockBlobAPI, nil -} - -type stubAzureBlockBlobAPI struct { - downloadBlobToWriterAtErr error - downloadBlobToWriterOutput []byte - uploadErr error - uploadData chan []byte -} - -func (s stubAzureBlockBlobAPI) DownloadToWriterAt(ctx context.Context, offset int64, count int64, writer io.WriterAt, o azblob.DownloadOptions) error { - if _, err := writer.WriteAt(s.downloadBlobToWriterOutput, 0); err != nil { - panic(err) - } - return s.downloadBlobToWriterAtErr -} - -func (s stubAzureBlockBlobAPI) Upload(ctx context.Context, body io.ReadSeekCloser, options *azblob.BlockBlobUploadOptions) (azblob.BlockBlobUploadResponse, error) { - res, err := io.ReadAll(body) - if err != nil { - panic(err) - } - s.uploadData <- res - return azblob.BlockBlobUploadResponse{}, s.uploadErr -} - func TestAzureGet(t *testing.T) { - someErr := errors.New("error") - testCases := map[string]struct { - client stubAzureContainerAPI + client stubAzureBlobAPI unsetError bool wantErr bool }{ "success": { - client: stubAzureContainerAPI{ - blockBlobAPI: stubAzureBlockBlobAPI{downloadBlobToWriterOutput: []byte("test-data")}, - }, - }, - "creating client fails": { - client: stubAzureContainerAPI{newClientErr: someErr}, - wantErr: true, + client: stubAzureBlobAPI{downloadData: []byte{0x1, 0x2, 0x3}}, }, - "DownloadBlobToBuffer fails": { - client: stubAzureContainerAPI{ - blockBlobAPI: stubAzureBlockBlobAPI{downloadBlobToWriterAtErr: someErr}, - }, + "DownloadBuffer fails": { + client: stubAzureBlobAPI{downloadErr: errors.New("failed")}, wantErr: true, }, "BlobNotFound error": { - client: stubAzureContainerAPI{ - blockBlobAPI: stubAzureBlockBlobAPI{ - downloadBlobToWriterAtErr: &azblob.StorageError{ - ErrorCode: azblob.StorageErrorCodeBlobNotFound, - }, - }, - }, + client: stubAzureBlobAPI{downloadErr: &azcore.ResponseError{ErrorCode: string(bloberror.BlobNotFound)}}, unsetError: true, wantErr: true, }, @@ -102,7 +46,7 @@ func TestAzureGet(t *testing.T) { assert := assert.New(t) client := &AzureStorage{ - newClient: newStubClientFactory(tc.client), + client: &tc.client, connectionString: "test", containerName: "test", opts: &AzureOpts{}, @@ -117,33 +61,24 @@ func TestAzureGet(t *testing.T) { } else { assert.False(errors.Is(err, ErrDEKUnset)) } - - } else { - assert.NoError(err) - assert.Equal(tc.client.blockBlobAPI.downloadBlobToWriterOutput, out) + return } + assert.NoError(err) + assert.Equal(tc.client.downloadData, out) }) } } func TestAzurePut(t *testing.T) { - someErr := errors.New("error") - testCases := map[string]struct { - client stubAzureContainerAPI + client stubAzureBlobAPI wantErr bool }{ "success": { - client: stubAzureContainerAPI{}, - }, - "creating client fails": { - client: stubAzureContainerAPI{newClientErr: someErr}, - wantErr: true, + client: stubAzureBlobAPI{}, }, "Upload fails": { - client: stubAzureContainerAPI{ - blockBlobAPI: stubAzureBlockBlobAPI{uploadErr: someErr}, - }, + client: stubAzureBlobAPI{uploadErr: errors.New("failed")}, wantErr: true, }, } @@ -153,10 +88,9 @@ func TestAzurePut(t *testing.T) { assert := assert.New(t) testData := []byte{0x1, 0x2, 0x3} - tc.client.blockBlobAPI.uploadData = make(chan []byte, len(testData)) client := &AzureStorage{ - newClient: newStubClientFactory(tc.client), + client: &tc.client, connectionString: "test", containerName: "test", opts: &AzureOpts{}, @@ -165,32 +99,27 @@ func TestAzurePut(t *testing.T) { err := client.Put(context.Background(), "test-key", testData) if tc.wantErr { assert.Error(err) - } else { - assert.NoError(err) - assert.Equal(testData, <-tc.client.blockBlobAPI.uploadData) + return } + assert.NoError(err) + assert.Equal(testData, tc.client.uploadData) }) } } func TestCreateContainerOrContinue(t *testing.T) { - someErr := errors.New("error") testCases := map[string]struct { - client stubAzureContainerAPI + client stubAzureBlobAPI wantErr bool }{ "success": { - client: stubAzureContainerAPI{}, + client: stubAzureBlobAPI{}, }, "container already exists": { - client: stubAzureContainerAPI{createErr: &azblob.StorageError{ErrorCode: azblob.StorageErrorCodeContainerAlreadyExists}}, + client: stubAzureBlobAPI{createErr: &azcore.ResponseError{ErrorCode: string(bloberror.ContainerAlreadyExists)}}, }, - "creating client fails": { - client: stubAzureContainerAPI{newClientErr: someErr}, - wantErr: true, - }, - "Create fails": { - client: stubAzureContainerAPI{createErr: someErr}, + "CreateContainer fails": { + client: stubAzureBlobAPI{createErr: errors.New("failed")}, wantErr: true, }, } @@ -199,9 +128,8 @@ func TestCreateContainerOrContinue(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - tc.client.createCalled = new(bool) client := &AzureStorage{ - newClient: newStubClientFactory(tc.client), + client: &tc.client, connectionString: "test", containerName: "test", opts: &AzureOpts{}, @@ -212,8 +140,37 @@ func TestCreateContainerOrContinue(t *testing.T) { assert.Error(err) } else { assert.NoError(err) - assert.True(*tc.client.createCalled) + assert.True(tc.client.createCalled) } }) } } + +type stubAzureBlobAPI struct { + createErr error + createCalled bool + downloadErr error + downloadData []byte + uploadErr error + uploadData []byte +} + +func (s *stubAzureBlobAPI) CreateContainer(context.Context, string, *container.CreateOptions) (azblob.CreateContainerResponse, error) { + s.createCalled = true + return azblob.CreateContainerResponse{}, s.createErr +} + +func (s *stubAzureBlobAPI) DownloadStream(context.Context, string, string, *blob.DownloadStreamOptions) (blob.DownloadStreamResponse, error) { + res := blob.DownloadStreamResponse{} + res.Body = io.NopCloser(bytes.NewReader(s.downloadData)) + return res, s.downloadErr +} + +func (s *stubAzureBlobAPI) UploadStream(_ context.Context, _, _ string, data io.Reader, _ *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) { + uploadData, err := io.ReadAll(data) + if err != nil { + return azblob.UploadStreamResponse{}, err + } + s.uploadData = uploadData + return azblob.UploadStreamResponse{}, s.uploadErr +}