diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index cce0e0748bf8..f895d805f299 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -108,20 +108,21 @@ func (s *Source) newClient(region string) (*s3.S3, error) { // Chunks emits chunks of bytes over a channel. func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error { - client, err := s.newClient("us-east-1") + const defaultAWSRegion = "us-east-1" + + client, err := s.newClient(defaultAWSRegion) if err != nil { return errors.WrapPrefix(err, "could not create s3 client", 0) } - bucketsToScan := []string{} + var bucketsToScan []string switch s.conn.GetCredential().(type) { case *sourcespb.S3_AccessKey, *sourcespb.S3_CloudEnvironment: if len(s.conn.Buckets) == 0 { res, err := client.ListBuckets(&s3.ListBucketsInput{}) if err != nil { - s.log.Error(err, "could not list s3 buckets") - return errors.WrapPrefix(err, "could not list s3 buckets", 0) + return fmt.Errorf("could not list s3 buckets: %w", err) } buckets := res.Buckets for _, bucket := range buckets { @@ -150,7 +151,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err continue } var regionalClient *s3.S3 - if region != "us-east-1" { + if region != defaultAWSRegion { regionalClient, err = s.newClient(region) if err != nil { s.log.Error(err, "could not make regional s3 client") @@ -170,7 +171,6 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { s.log.Error(err, "could not list objects in s3 bucket", "bucket", bucket) - return errors.WrapPrefix(err, fmt.Sprintf("could not list objects in s3 bucket: %s", bucket), 0) } } s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s", s.name), "") diff --git a/pkg/sources/s3/s3_test.go b/pkg/sources/s3/s3_test.go index 1054608f3542..f437df735621 100644 --- a/pkg/sources/s3/s3_test.go +++ b/pkg/sources/s3/s3_test.go @@ -75,7 +75,7 @@ func TestSource_Chunks(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10) + err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return