diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go index cec9c2d424..e5ec845867 100644 --- a/contrib/google.golang.org/grpc/appsec.go +++ b/contrib/google.golang.org/grpc/appsec.go @@ -23,7 +23,8 @@ import ( func appsecUnaryHandlerMiddleware(span ddtrace.Span, handler grpc.UnaryHandler) grpc.UnaryHandler { httpsec.SetAppSecTags(span) return func(ctx context.Context, req interface{}) (interface{}, error) { - op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{}, nil) + md, _ := metadata.FromIncomingContext(ctx) + op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md}, nil) defer func() { events := op.Finish(grpcsec.HandlerOperationRes{}) if len(events) == 0 { @@ -40,7 +41,8 @@ func appsecUnaryHandlerMiddleware(span ddtrace.Span, handler grpc.UnaryHandler) func appsecStreamHandlerMiddleware(span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler { httpsec.SetAppSecTags(span) return func(srv interface{}, stream grpc.ServerStream) error { - op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{}, nil) + md, _ := metadata.FromIncomingContext(stream.Context()) + op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md}, nil) defer func() { events := op.Finish(grpcsec.HandlerOperationRes{}) if len(events) == 0 { diff --git a/contrib/google.golang.org/grpc/appsec_test.go b/contrib/google.golang.org/grpc/appsec_test.go index f424380cb4..06481c9e29 100644 --- a/contrib/google.golang.org/grpc/appsec_test.go +++ b/contrib/google.golang.org/grpc/appsec_test.go @@ -10,10 +10,11 @@ import ( "strings" "testing" - "github.com/stretchr/testify/require" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" ) func TestAppSec(t *testing.T) { @@ -33,8 +34,9 @@ func TestAppSec(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - // Send a XSS attack - res, err := client.Ping(context.Background(), &FixtureRequest{Name: ""}) + // Send a XSS attack in the payload along with the canary value in the RPC metadata + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log")) + res, err := client.Ping(ctx, &FixtureRequest{Name: ""}) // Check that the handler was properly called require.NoError(t, err) require.Equal(t, "passed", res.Message) @@ -42,17 +44,20 @@ func TestAppSec(t *testing.T) { finished := mt.FinishedSpans() require.Len(t, finished, 1) - // The request should have the XSS attack attempt event (appsec rule id crs-941-100). - event := finished[0].Tag("_dd.appsec.json") + // The request should have the attack attempts + event, _ := finished[0].Tag("_dd.appsec.json").(string) require.NotNil(t, event) - require.True(t, strings.Contains(event.(string), "crs-941-100")) + require.True(t, strings.Contains(event, "crs-941-100")) // XSS attack attempt + require.True(t, strings.Contains(event, "ua0-600-55x")) // canary rule attack attempt }) t.Run("stream", func(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - stream, err := client.StreamPing(context.Background()) + // Send a XSS attack in the payload along with the canary value in the RPC metadata + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log")) + stream, err := client.StreamPing(ctx) require.NoError(t, err) // Send a XSS attack @@ -81,11 +86,11 @@ func TestAppSec(t *testing.T) { finished := mt.FinishedSpans() require.Len(t, finished, 6) - // The request should both attacks: the XSS and SQLi attack attempt - // events (appsec rule id crs-941-100, crs-942-100). - event := finished[5].Tag("_dd.appsec.json") + // The request should have the attack attempts + event, _ := finished[5].Tag("_dd.appsec.json").(string) require.NotNil(t, event) - require.True(t, strings.Contains(event.(string), "crs-941-100")) - require.True(t, strings.Contains(event.(string), "crs-942-100")) + require.True(t, strings.Contains(event, "crs-941-100")) // XSS attack attempt + require.True(t, strings.Contains(event, "crs-942-100")) // SQL-injection attack attempt + require.True(t, strings.Contains(event, "ua0-600-55x")) // canary rule attack attempt }) } diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go b/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go index d35cd817cf..304ffb46b5 100644 --- a/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go +++ b/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go @@ -41,8 +41,12 @@ type ( events []json.RawMessage mu sync.Mutex } - // HandlerOperationArgs is the grpc handler arguments. Empty as of today. - HandlerOperationArgs struct{} + // HandlerOperationArgs is the grpc handler arguments. + HandlerOperationArgs struct { + // Message received by the gRPC handler. + // Corresponds to the address `grpc.server.request.metadata`. + Metadata map[string][]string + } // HandlerOperationRes is the grpc handler results. Empty as of today. HandlerOperationRes struct{} diff --git a/internal/appsec/waf.go b/internal/appsec/waf.go index b08343bfeb..704a9871f9 100644 --- a/internal/appsec/waf.go +++ b/internal/appsec/waf.go @@ -144,7 +144,7 @@ func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout tim // newGRPCWAFEventListener returns the WAF event listener to register in order // to enable it. func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Duration, limiter Limiter) dyngo.EventListener { - return grpcsec.OnHandlerOperationStart(func(op *grpcsec.HandlerOperation, _ grpcsec.HandlerOperationArgs) { + return grpcsec.OnHandlerOperationStart(func(op *grpcsec.HandlerOperation, handlerArgs grpcsec.HandlerOperationArgs) { // Limit the maximum number of security events, as a streaming RPC could // receive unlimited number of messages where we could find security events const maxWAFEventsPerRequest = 10 @@ -180,7 +180,11 @@ func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Durati // Note that we don't check if the address is present in the rules // as we only support one at the moment, so this callback cannot be // set when the address is not present. - event := runWAF(wafCtx, map[string]interface{}{grpcServerRequestMessage: res.Message}, timeout) + values := map[string]interface{}{grpcServerRequestMessage: res.Message} + if md := handlerArgs.Metadata; len(md) > 0 { + values[grpcServerRequestMetadata] = md + } + event := runWAF(wafCtx, values, timeout) if len(event) == 0 { return } @@ -235,12 +239,14 @@ var httpAddresses = []string{ // gRPC rule addresses currently supported by the WAF const ( - grpcServerRequestMessage = "grpc.server.request.message" + grpcServerRequestMessage = "grpc.server.request.message" + grpcServerRequestMetadata = "grpc.server.request.metadata" ) // List of gRPC rule addresses currently supported by the WAF var grpcAddresses = []string{ grpcServerRequestMessage, + grpcServerRequestMetadata, } func init() {