From d6cb7f78471874119358ddbc54c447c76497f250 Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Fri, 9 Dec 2022 15:29:03 -0800 Subject: [PATCH 1/6] Handle error when scanning s3 bucket. --- pkg/sources/s3/s3.go | 69 ++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 37 deletions(-)
diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index cce0e0748bf8..cc5ba566e477 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go
@@ -15,7 +15,7 @@ import ( diskbufferreader "github.com/bill-rich/disk-buffer-reader" "github.com/go-errors/errors" "github.com/go-logr/logr" - "golang.org/x/sync/semaphore" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb"
@@ -29,16 +29,16 @@ import ( ) type Source struct { - name string - sourceId int64 - jobId int64 - verify bool - concurrency int - aCtx context.Context - log logr.Logger + name string + sourceId int64 + jobId int64 + verify bool + aCtx context.Context + log logr.Logger sources.Progress errorCount *sync.Map conn *sourcespb.S3 + jobPool *errgroup.Group } // Ensure the Source satisfies the interface at compile time
@@ -66,7 +66,8 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, s.sourceId = sourceId s.jobId = jobId s.verify = verify - s.concurrency = concurrency + s.jobPool = &errgroup.Group{} + s.jobPool.SetLimit(concurrency) s.errorCount = &sync.Map{} var conn sourcespb.S3
@@ -108,23 +109,23 @@ func (s *Source) newClient(region string) (*s3.S3, error) { // Chunks emits chunks of bytes over a channel. func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error { - client, err := s.newClient("us-east-1") + const defaultAWSRegion = "us-east-1" + + client, err := s.newClient(defaultAWSRegion) if err != nil { return errors.WrapPrefix(err, "could not create s3 client", 0) } - bucketsToScan := []string{} + var bucketsToScan []string switch s.conn.GetCredential().(type) { case *sourcespb.S3_AccessKey, *sourcespb.S3_CloudEnvironment: if len(s.conn.Buckets) == 0 { res, err := client.ListBuckets(&s3.ListBucketsInput{}) if err != nil { - s.log.Error(err, "could not list s3 buckets") - return errors.WrapPrefix(err, "could not list s3 buckets", 0) + return fmt.Errorf("could not list s3 buckets: %w", err) } - buckets := res.Buckets - for _, bucket := range buckets { + for _, bucket := range res.Buckets { bucketsToScan = append(bucketsToScan, *bucket.Name) } } else {
@@ -150,7 +151,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err continue } var regionalClient *s3.S3 - if region != "us-east-1" { + if region != defaultAWSRegion { regionalClient, err = s.newClient(region) if err != nil { s.log.Error(err, "could not make regional s3 client")
@@ -170,7 +171,6 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { s.log.Error(err, "could not list objects in s3 bucket", "bucket", bucket) - return errors.WrapPrefix(err, fmt.Sprintf("could not list objects in s3 bucket: %s", bucket), 0) } } s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s", s.name), "")
@@ -180,9 +180,8 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err // pageChunker emits chunks onto the given channel from a page func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan *sources.Chunk, bucket string, page *s3.ListObjectsV2Output, errorCount *sync.Map) { - sem := 
semaphore.NewWeighted(int64(s.concurrency)) - var wg sync.WaitGroup for _, obj := range page.Contents { + obj := obj if common.IsDone(ctx) { return } @@ -215,19 +214,12 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan return } - err := sem.Acquire(ctx, 1) - if err != nil { - s.log.Error(err, "could not acquire semaphore") - continue - } - wg.Add(1) - go func(ctx context.Context, wg *sync.WaitGroup, sem *semaphore.Weighted, obj *s3.Object) { + s.jobPool.Go(func() error { defer common.RecoverWithExit(ctx) - defer sem.Release(1) - defer wg.Done() if (*obj.Key)[len(*obj.Key)-1:] == "/" { - return + s.log.V(5).Info("Skipping directory", "object", *obj.Key) + return nil } path := strings.Split(*obj.Key, "/") @@ -239,7 +231,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan } if nErr.(int) > 3 { s.log.V(2).Info("Skipped due to excessive errors", "object", *obj.Key) - return + return nil } // files break with spaces, must replace with + @@ -261,7 +253,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan } if nErr.(int) > 3 { s.log.V(3).Info("Skipped due to excessive errors", "object", *obj.Key) - return + return nil } nErr = nErr.(int) + 1 errorCount.Store(prefix, nErr) @@ -269,14 +261,14 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan if nErr.(int) > 3 { s.log.V(2).Info("Too many consecutive errors, excluding prefix", "prefix", prefix) } - return + return nil } defer res.Body.Close() reader, err := diskbufferreader.New(res.Body) if err != nil { s.log.Error(err, "Could not create reader.") - return + return nil } defer reader.Close() @@ -303,7 +295,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan Verify: s.verify, } if handlers.HandleFile(ctx, reader, chunkSkel, chunksChan) { - return + return nil } if err := reader.Reset(); err != nil { @@ -315,7 +307,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan chunkData, err := io.ReadAll(reader) if err != nil { s.log.Error(err, "Could not read file data.") - return + return nil } chunk.Data = chunkData chunksChan <- &chunk @@ -327,9 +319,12 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan if nErr.(int) > 0 { errorCount.Store(prefix, 0) } - }(ctx, &wg, sem, obj) + return nil + }) + + _ = s.jobPool.Wait() + s.log.V(5).Info("Finished processing object", "object", *obj.Key) } - wg.Wait() } // S3 links currently have the general format of: From 571f1c4e1f66693134db267552166d945d8e7afc Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Fri, 9 Dec 2022 15:32:07 -0800 Subject: [PATCH 2/6] move wait outside loop. --- pkg/sources/s3/s3.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index cc5ba566e477..be837f23793b 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -322,9 +322,10 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan return nil }) - _ = s.jobPool.Wait() s.log.V(5).Info("Finished processing object", "object", *obj.Key) } + + _ = s.jobPool.Wait() } // S3 links currently have the general format of: From af2e4e8b93788f49e59718e8a60ea9d61498d597 Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Fri, 9 Dec 2022 15:34:24 -0800 Subject: [PATCH 3/6] Add logging. 
--- pkg/sources/s3/s3.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index be837f23793b..e9ba0a8e6e9a 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -322,10 +322,11 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan return nil }) - s.log.V(5).Info("Finished processing object", "object", *obj.Key) + s.log.V(3).Info("Finished processing object", "object", *obj.Key) } _ = s.jobPool.Wait() + s.log.V(1).Info("Finished processing page", "page", page.Name) } // S3 links currently have the general format of: From c8dae15b64b49ba3116d93cc9279a1efaebf361f Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Fri, 9 Dec 2022 16:25:50 -0800 Subject: [PATCH 4/6] revert changes. --- pkg/sources/s3/s3.go | 54 +++++++++++++++++++++------------------ pkg/sources/s3/s3_test.go | 4 +-- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index e9ba0a8e6e9a..13e643daa5a5 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -16,6 +16,7 @@ import ( "github.com/go-errors/errors" "github.com/go-logr/logr" "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -29,12 +30,13 @@ import ( ) type Source struct { - name string - sourceId int64 - jobId int64 - verify bool - aCtx context.Context - log logr.Logger + name string + sourceId int64 + jobId int64 + verify bool + concurrency int + aCtx context.Context + log logr.Logger sources.Progress errorCount *sync.Map conn *sourcespb.S3 @@ -66,8 +68,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64, s.sourceId = sourceId s.jobId = jobId s.verify = verify - s.jobPool = &errgroup.Group{} - s.jobPool.SetLimit(concurrency) + s.concurrency = concurrency s.errorCount = &sync.Map{} var conn sourcespb.S3 @@ -180,8 +181,9 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err // pageChunker emits chunks onto the given channel from a page func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan *sources.Chunk, bucket string, page *s3.ListObjectsV2Output, errorCount *sync.Map) { + sem := semaphore.NewWeighted(int64(s.concurrency)) + var wg sync.WaitGroup for _, obj := range page.Contents { - obj := obj if common.IsDone(ctx) { return } @@ -214,12 +216,19 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan return } - s.jobPool.Go(func() error { + err := sem.Acquire(ctx, 1) + if err != nil { + s.log.Error(err, "could not acquire semaphore") + continue + } + wg.Add(1) + go func(ctx context.Context, wg *sync.WaitGroup, sem *semaphore.Weighted, obj *s3.Object) { defer common.RecoverWithExit(ctx) + defer sem.Release(1) + defer wg.Done() if (*obj.Key)[len(*obj.Key)-1:] == "/" { - s.log.V(5).Info("Skipping directory", "object", *obj.Key) - return nil + return } path := strings.Split(*obj.Key, "/") @@ -231,7 +240,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan } if nErr.(int) > 3 { s.log.V(2).Info("Skipped due to excessive errors", "object", *obj.Key) - return nil + return } // files break with spaces, must replace with + @@ -253,7 +262,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan } if nErr.(int) > 3 { s.log.V(3).Info("Skipped due to excessive errors", "object", *obj.Key) - return nil + return } nErr = nErr.(int) + 1 
errorCount.Store(prefix, nErr) @@ -261,14 +270,14 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan if nErr.(int) > 3 { s.log.V(2).Info("Too many consecutive errors, excluding prefix", "prefix", prefix) } - return nil + return } defer res.Body.Close() reader, err := diskbufferreader.New(res.Body) if err != nil { s.log.Error(err, "Could not create reader.") - return nil + return } defer reader.Close() @@ -295,7 +304,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan Verify: s.verify, } if handlers.HandleFile(ctx, reader, chunkSkel, chunksChan) { - return nil + return } if err := reader.Reset(); err != nil { @@ -307,7 +316,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan chunkData, err := io.ReadAll(reader) if err != nil { s.log.Error(err, "Could not read file data.") - return nil + return } chunk.Data = chunkData chunksChan <- &chunk @@ -319,14 +328,9 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan if nErr.(int) > 0 { errorCount.Store(prefix, 0) } - return nil - }) - - s.log.V(3).Info("Finished processing object", "object", *obj.Key) + }(ctx, &wg, sem, obj) } - - _ = s.jobPool.Wait() - s.log.V(1).Info("Finished processing page", "page", page.Name) + wg.Wait() } // S3 links currently have the general format of: diff --git a/pkg/sources/s3/s3_test.go b/pkg/sources/s3/s3_test.go index 1054608f3542..e00b35d441e4 100644 --- a/pkg/sources/s3/s3_test.go +++ b/pkg/sources/s3/s3_test.go @@ -67,7 +67,7 @@ func TestSource_Chunks(t *testing.T) { var cancelOnce sync.Once defer cancelOnce.Do(cancel) - s := Source{} + s := &Source{} log.SetLevel(log.DebugLevel) conn, err := anypb.New(tt.init.connection) @@ -75,7 +75,7 @@ func TestSource_Chunks(t *testing.T) { t.Fatal(err) } - err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 10) + err = s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8) if (err != nil) != tt.wantErr { t.Errorf("Source.Init() error = %v, wantErr %v", err, tt.wantErr) return From 49c15e6b368eeb4ada2856fd962b45d54cc3498b Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Fri, 9 Dec 2022 16:26:28 -0800 Subject: [PATCH 5/6] remove. --- pkg/sources/s3/s3.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index 13e643daa5a5..bbdf8d9ac795 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -15,7 +15,6 @@ import ( diskbufferreader "github.com/bill-rich/disk-buffer-reader" "github.com/go-errors/errors" "github.com/go-logr/logr" - "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" @@ -40,7 +39,6 @@ type Source struct { sources.Progress errorCount *sync.Map conn *sourcespb.S3 - jobPool *errgroup.Group } // Ensure the Source satisfies the interface at compile time From 0fb57b63bc7dc157a761a07ea07c6086f7a1f7e8 Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Fri, 9 Dec 2022 16:31:51 -0800 Subject: [PATCH 6/6] revert. 
--- pkg/sources/s3/s3.go | 3 ++- pkg/sources/s3/s3_test.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index bbdf8d9ac795..f895d805f299 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -124,7 +124,8 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { return fmt.Errorf("could not list s3 buckets: %w", err) } - for _, bucket := range res.Buckets { + buckets := res.Buckets + for _, bucket := range buckets { bucketsToScan = append(bucketsToScan, *bucket.Name) } } else { diff --git a/pkg/sources/s3/s3_test.go b/pkg/sources/s3/s3_test.go index e00b35d441e4..f437df735621 100644 --- a/pkg/sources/s3/s3_test.go +++ b/pkg/sources/s3/s3_test.go @@ -67,7 +67,7 @@ func TestSource_Chunks(t *testing.T) { var cancelOnce sync.Once defer cancelOnce.Do(cancel) - s := &Source{} + s := Source{} log.SetLevel(log.DebugLevel) conn, err := anypb.New(tt.init.connection)
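
For reference, below is a minimal, self-contained sketch of the bounded-concurrency pattern that patch 1 switches to: an errgroup.Group with SetLimit standing in for the semaphore.Weighted plus sync.WaitGroup pair. The processObject helper and the key slice are illustrative placeholders rather than code from this repository, and patches 4-6 ultimately revert the source back to the semaphore approach.

package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

// processObject stands in for the per-object work (download, chunk, emit);
// it is a hypothetical placeholder, not a function from the patched source.
func processObject(key string) error {
	fmt.Println("processing", key)
	return nil
}

func main() {
	pool := &errgroup.Group{}
	// SetLimit bounds how many submitted goroutines may run at once, the role
	// the weighted semaphore plays in the original pageChunker.
	pool.SetLimit(8)

	keys := []string{"a.txt", "b.txt", "dir/c.txt"}
	for _, key := range keys {
		key := key // capture the loop variable for the closure
		pool.Go(func() error {
			return processObject(key)
		})
	}

	// Wait blocks until every submitted worker has returned; it belongs
	// outside the object loop, which is what patch 2 corrects.
	if err := pool.Wait(); err != nil {
		fmt.Println("scan error:", err)
	}
}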