Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add s3 object count to trace logs #975

Merged
merged 2 commits on Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 8 additions & 4 deletions main.go
Expand Up @@ -93,7 +93,7 @@ var (
syslogTLSKey = syslogScan.Flag("key", "Path to TLS key.").String()
syslogFormat = syslogScan.Flag("format", "Log format. Can be rfc3164 or rfc5424").String()

stderrLevel = zap.NewAtomicLevel()
logLevel = zap.NewAtomicLevel()
)

func init() {
Expand All @@ -113,9 +113,13 @@ func init() {
}
switch {
case *trace:
log.SetLevel(5)
log.SetLevelForControl(logLevel, 5)
logrus.SetLevel(logrus.TraceLevel)
logrus.Debugf("running version %s", version.BuildVersion)
case *debug:
log.SetLevel(2)
log.SetLevelForControl(logLevel, 2)
logrus.SetLevel(logrus.DebugLevel)
logrus.Debugf("running version %s", version.BuildVersion)
default:
Expand Down Expand Up @@ -172,11 +176,11 @@ func run(state overseer.State) {
}
}()
}
logger, sync := log.New("trufflehog", log.WithConsoleSink(os.Stderr, log.WithLeveler(stderrLevel)))
context.SetDefaultLogger(logger)
logger, sync := log.New("trufflehog", log.WithConsoleSink(os.Stderr, log.WithLeveler(logLevel)))
ctx := context.WithLogger(context.TODO(), logger)

defer func() { _ = sync() }()

ctx := context.TODO()
e := engine.Start(ctx,
engine.WithConcurrency(*concurrency),
engine.WithDecoders(decoders.DefaultDecoders()...),
Expand Down
13 changes: 10 additions & 3 deletions pkg/sources/s3/s3.go
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -68,6 +69,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobId, sourceId int64,
s.verify = verify
s.concurrency = concurrency
s.errorCount = &sync.Map{}
s.log = aCtx.Logger()

var conn sourcespb.S3
err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
Expand Down Expand Up @@ -137,6 +139,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
return errors.Errorf("invalid configuration given for %s source", s.name)
}

objectCount := uint64(0)
for i, bucket := range bucketsToScan {
if common.IsDone(ctx) {
return nil
Expand Down Expand Up @@ -165,21 +168,21 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
err = regionalClient.ListObjectsV2PagesWithContext(
ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
func(page *s3.ListObjectsV2Output, last bool) bool {
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount)
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
return true
})

if err != nil {
s.log.Error(err, "could not list objects in s3 bucket", "bucket", bucket)
}
}
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s", s.name), "")
s.SetProgressComplete(len(bucketsToScan), len(bucketsToScan), fmt.Sprintf("Completed scanning source %s. %d objects scanned.", s.name, objectCount), "")

return nil
}

// pageChunker emits chunks onto the given channel from a page
func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan *sources.Chunk, bucket string, page *s3.ListObjectsV2Output, errorCount *sync.Map) {
func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan *sources.Chunk, bucket string, page *s3.ListObjectsV2Output, errorCount *sync.Map, pageNumber int, objectCount *uint64) {
sem := semaphore.NewWeighted(int64(s.concurrency))
var wg sync.WaitGroup
for _, obj := range page.Contents {
Expand Down Expand Up @@ -303,6 +306,8 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
Verify: s.verify,
}
if handlers.HandleFile(ctx, reader, chunkSkel, chunksChan) {
atomic.AddUint64(objectCount, 1)
s.log.V(5).Info("S3 object scanned.", "object_count", objectCount, "page_number", pageNumber)
return
}

Expand All @@ -317,6 +322,8 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
s.log.Error(err, "Could not read file data.")
return
}
atomic.AddUint64(objectCount, 1)
s.log.V(5).Info("S3 object scanned.", "object_count", objectCount, "page_number", pageNumber)
chunk.Data = chunkData
chunksChan <- &chunk

Expand Down