diff --git a/CODEOWNERS b/CODEOWNERS index 7bfeb6ca6d..87ba38ad00 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -15,7 +15,7 @@ # appsec /appsec @DataDog/appsec-go /internal/appsec @DataDog/appsec-go -/contrib/**/appsec.go @DataDog/appsec-go +/contrib/**/*appsec*.go @DataDog/appsec-go /.github/workflows/appsec.yml @DataDog/appsec-go # telemetry diff --git a/contrib/gin-gonic/gin/appsec.go b/contrib/gin-gonic/gin/appsec.go index 27d85e74b9..1b0946343a 100644 --- a/contrib/gin-gonic/gin/appsec.go +++ b/contrib/gin-gonic/gin/appsec.go @@ -6,8 +6,6 @@ package gin import ( - "net" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec" @@ -18,8 +16,8 @@ import ( // useAppSec executes the AppSec logic related to the operation start and // returns the function to be executed upon finishing the operation func useAppSec(c *gin.Context, span tracer.Span) func() { - req := c.Request instrumentation.SetAppSecEnabledTags(span) + var params map[string]string if l := len(c.Params); l > 0 { params = make(map[string]string, l) @@ -27,18 +25,20 @@ func useAppSec(c *gin.Context, span tracer.Span) func() { params[p.Key] = p.Value } } - args := httpsec.MakeHandlerOperationArgs(req, params) + + req := c.Request + ipTags, clientIP := httpsec.ClientIPTags(req.Header, req.RemoteAddr) + instrumentation.SetStringTags(span, ipTags) + + args := httpsec.MakeHandlerOperationArgs(req, clientIP, params) ctx, op := httpsec.StartOperation(req.Context(), args) c.Request = req.WithContext(ctx) + return func() { events := op.Finish(httpsec.HandlerOperationRes{Status: c.Writer.Status()}) + instrumentation.SetTags(span, op.Tags()) if len(events) > 0 { - remoteIP, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - remoteIP = req.RemoteAddr - } - httpsec.SetSecurityEventTags(span, events, remoteIP, args.Headers, c.Writer.Header()) + httpsec.SetSecurityEventTags(span, events, args.Headers, c.Writer.Header()) } - instrumentation.SetTags(span, op.Tags()) } } diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go index 8006107f67..aa60298cf3 100644 --- a/contrib/google.golang.org/grpc/appsec.go +++ b/contrib/google.golang.org/grpc/appsec.go @@ -7,16 +7,17 @@ package grpc import ( "encoding/json" - "net" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/grpcsec" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" ) // UnaryHandler wrapper to use when AppSec is enabled to monitor its execution. @@ -24,15 +25,28 @@ func appsecUnaryHandlerMiddleware(span ddtrace.Span, handler grpc.UnaryHandler) instrumentation.SetAppSecEnabledTags(span) return func(ctx context.Context, req interface{}) (interface{}, error) { md, _ := metadata.FromIncomingContext(ctx) - op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md}, nil) + var remoteAddr string + if p, ok := peer.FromContext(ctx); ok { + remoteAddr = p.Addr.String() + } + ipTags, clientIP := httpsec.ClientIPTags(md, remoteAddr) + instrumentation.SetStringTags(span, ipTags) + + op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md, ClientIP: clientIP}, nil) defer func() { events := op.Finish(grpcsec.HandlerOperationRes{}) instrumentation.SetTags(span, op.Tags()) if len(events) == 0 { return } - setAppSecTags(ctx, span, events) + setAppSecEventsTags(ctx, span, events) }() + + if op.BlockedCode != nil { + op.AddTag(httpsec.BlockedRequestTag, true) + return nil, status.Errorf(*op.BlockedCode, "Request blocked") + } + defer grpcsec.StartReceiveOperation(grpcsec.ReceiveOperationArgs{}, op).Finish(grpcsec.ReceiveOperationRes{Message: req}) return handler(ctx, req) } @@ -43,15 +57,28 @@ func appsecStreamHandlerMiddleware(span ddtrace.Span, handler grpc.StreamHandler instrumentation.SetAppSecEnabledTags(span) return func(srv interface{}, stream grpc.ServerStream) error { md, _ := metadata.FromIncomingContext(stream.Context()) - op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md}, nil) + var remoteAddr string + if p, ok := peer.FromContext(stream.Context()); ok { + remoteAddr = p.Addr.String() + } + ipTags, clientIP := httpsec.ClientIPTags(md, remoteAddr) + instrumentation.SetStringTags(span, ipTags) + + op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md, ClientIP: clientIP}, nil) defer func() { events := op.Finish(grpcsec.HandlerOperationRes{}) instrumentation.SetTags(span, op.Tags()) if len(events) == 0 { return } - setAppSecTags(stream.Context(), span, events) + setAppSecEventsTags(stream.Context(), span, events) }() + + if op.BlockedCode != nil { + op.AddTag(httpsec.BlockedRequestTag, true) + return status.Error(*op.BlockedCode, "Request blocked") + } + return handler(srv, appsecServerStream{ServerStream: stream, handlerOperation: op}) } } @@ -72,11 +99,7 @@ func (ss appsecServerStream) RecvMsg(m interface{}) error { } // Set the AppSec tags when security events were found. -func setAppSecTags(ctx context.Context, span ddtrace.Span, events []json.RawMessage) { +func setAppSecEventsTags(ctx context.Context, span ddtrace.Span, events []json.RawMessage) { md, _ := metadata.FromIncomingContext(ctx) - var addr net.Addr - if p, ok := peer.FromContext(ctx); ok { - addr = p.Addr - } - grpcsec.SetSecurityEventTags(span, events, addr, md) + grpcsec.SetSecurityEventTags(span, events, md) } diff --git a/contrib/google.golang.org/grpc/appsec_test.go b/contrib/google.golang.org/grpc/appsec_test.go index 14862e7966..23063eee28 100644 --- a/contrib/google.golang.org/grpc/appsec_test.go +++ b/contrib/google.golang.org/grpc/appsec_test.go @@ -14,7 +14,9 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) func TestAppSec(t *testing.T) { @@ -94,3 +96,90 @@ func TestAppSec(t *testing.T) { require.True(t, strings.Contains(event, "ua0-600-55x")) // canary rule attack attempt }) } + +// Test that http blocking works by using custom rules/rules data +func TestBlocking(t *testing.T) { + t.Setenv("DD_APPSEC_RULES", "../../../internal/appsec/testdata/blocking.json") + appsec.Start() + defer appsec.Stop() + if !appsec.Enabled() { + t.Skip("appsec disabled") + } + + rig, err := newRig(false) + require.NoError(t, err) + defer rig.Close() + + client := rig.client + + t.Run("unary-block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + // Send a XSS attack in the payload along with the canary value in the RPC metadata + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4")) + reply, err := client.Ping(ctx, &FixtureRequest{Name: ""}) + + require.Nil(t, reply) + require.Equal(t, codes.Aborted, status.Code(err)) + + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + // The request should have the attack attempts + event, _ := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "blk-001-001")) + }) + + t.Run("unary-no-block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + // Send a XSS attack in the payload along with the canary value in the RPC metadata + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5")) + reply, err := client.Ping(ctx, &FixtureRequest{Name: ""}) + + require.Equal(t, "passed", reply.Message) + require.Equal(t, codes.OK, status.Code(err)) + }) + + t.Run("stream-block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4")) + stream, err := client.StreamPing(ctx) + require.NoError(t, err) + reply, err := stream.Recv() + + require.Equal(t, codes.Aborted, status.Code(err)) + require.Nil(t, reply) + + finished := mt.FinishedSpans() + require.Len(t, finished, 1) + // The request should have the attack attempts + event, _ := finished[0].Tag("_dd.appsec.json").(string) + require.NotNil(t, event) + require.True(t, strings.Contains(event, "blk-001-001")) + }) + + t.Run("stream-no-block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5")) + stream, err := client.StreamPing(ctx) + require.NoError(t, err) + + // Send a XSS attack + err = stream.Send(&FixtureRequest{Name: ""}) + require.NoError(t, err) + reply, err := stream.Recv() + require.Equal(t, codes.OK, status.Code(err)) + require.Equal(t, "passed", reply.Message) + + err = stream.CloseSend() + require.NoError(t, err) + }) + +} diff --git a/contrib/labstack/echo.v4/appsec.go b/contrib/labstack/echo.v4/appsec.go index e5dce33396..1b8b90dfd3 100644 --- a/contrib/labstack/echo.v4/appsec.go +++ b/contrib/labstack/echo.v4/appsec.go @@ -6,8 +6,6 @@ package echo import ( - "net" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec" @@ -16,23 +14,24 @@ import ( ) func useAppSec(c echo.Context, span tracer.Span) func() { - req := c.Request() instrumentation.SetAppSecEnabledTags(span) + params := make(map[string]string) for _, n := range c.ParamNames() { params[n] = c.Param(n) } - args := httpsec.MakeHandlerOperationArgs(req, params) + + req := c.Request() + ipTags, clientIP := httpsec.ClientIPTags(req.Header, req.RemoteAddr) + instrumentation.SetStringTags(span, ipTags) + + args := httpsec.MakeHandlerOperationArgs(req, clientIP, params) ctx, op := httpsec.StartOperation(req.Context(), args) c.SetRequest(req.WithContext(ctx)) return func() { events := op.Finish(httpsec.HandlerOperationRes{Status: c.Response().Status}) if len(events) > 0 { - remoteIP, _, err := net.SplitHostPort(req.RemoteAddr) - if err != nil { - remoteIP = req.RemoteAddr - } - httpsec.SetSecurityEventTags(span, events, remoteIP, args.Headers, c.Response().Writer.Header()) + httpsec.SetSecurityEventTags(span, events, args.Headers, c.Response().Writer.Header()) } instrumentation.SetTags(span, op.Tags()) } diff --git a/internal/appsec/dyngo/instrumentation/common.go b/internal/appsec/dyngo/instrumentation/common.go index 9d0d66736e..8d2d06e5d9 100644 --- a/internal/appsec/dyngo/instrumentation/common.go +++ b/internal/appsec/dyngo/instrumentation/common.go @@ -49,7 +49,7 @@ func (m *TagsHolder) Tags() map[string]interface{} { // See httpsec/http.go and grpcsec/grpc.go. type SecurityEventsHolder struct { events []json.RawMessage - mu sync.Mutex + mu sync.RWMutex } // AddSecurityEvents adds the security events to the collected events list. @@ -62,9 +62,18 @@ func (s *SecurityEventsHolder) AddSecurityEvents(events ...json.RawMessage) { // Events returns the list of stored events. func (s *SecurityEventsHolder) Events() []json.RawMessage { + s.mu.RLock() + defer s.mu.RUnlock() return s.events } +// ClearEvents clears the list of stored events +func (s *SecurityEventsHolder) ClearEvents() { + s.mu.Lock() + defer s.mu.Unlock() + s.events = s.events[0:0] +} + // SetTags fills the span tags using the key/value pairs found in `tags` func SetTags(span TagSetter, tags map[string]interface{}) { for k, v := range tags { @@ -72,6 +81,14 @@ func SetTags(span TagSetter, tags map[string]interface{}) { } } +// SetStringTags fills the span tags using the key/value pairs of strings found +// in `tags` +func SetStringTags(span TagSetter, tags map[string]string) { + for k, v := range tags { + span.SetTag(k, v) + } +} + // SetAppSecEnabledTags sets the AppSec-specific span tags that are expected to be in // the web service entry span (span of type `web`) when AppSec is enabled. func SetAppSecEnabledTags(span TagSetter) { @@ -102,9 +119,11 @@ func SetEventSpanTags(span TagSetter, events []json.RawMessage) error { // Create the value of the security event tag. // TODO(Julio-Guerra): a future libddwaf version should return something -// avoiding us the following events concatenation logic which currently -// involves unserializing the top-level JSON arrays to concatenate them -// together. +// +// avoiding us the following events concatenation logic which currently +// involves unserializing the top-level JSON arrays to concatenate them +// together. +// // TODO(Julio-Guerra): avoid serializing the json in the request hot path func makeEventTagValue(events []json.RawMessage) (json.RawMessage, error) { var v interface{} diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/actions.go b/internal/appsec/dyngo/instrumentation/grpcsec/actions.go new file mode 100644 index 0000000000..845ac5d3ca --- /dev/null +++ b/internal/appsec/dyngo/instrumentation/grpcsec/actions.go @@ -0,0 +1,69 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2022 Datadog, Inc. + +//go:build appsec +// +build appsec + +package grpcsec + +import ( + "sync" + + "google.golang.org/grpc/codes" +) + +// Action is used to identify any action kind +type Action interface { + isAction() +} + +// ActionsHandler handles WAF actions registration and execution +type ActionsHandler struct { + mu sync.RWMutex + actions map[string]Action +} + +// NewActionsHandler returns an action handler holding the default ASM actions. +// Currently, only the default "block" action is supported +func NewActionsHandler() ActionsHandler { + // Register the default "block" action as specified in the blocking RFC + actions := map[string]Action{"block": &BlockRequestAction{Status: codes.Aborted}} + + return ActionsHandler{ + actions: actions, + } +} + +// RegisterAction registers a specific action to the actions handler. If the action kind is unknown +// the action will have no effect +func (h *ActionsHandler) RegisterAction(id string, a Action) { + h.mu.Lock() + defer h.mu.Unlock() + h.actions[id] = a +} + +// Apply executes the action identified by `id` +func (h *ActionsHandler) Apply(id string, op *HandlerOperation) bool { + h.mu.RLock() + a, ok := h.actions[id] + h.mu.RUnlock() + if !ok { + return false + } + // Currently, only the "block_request" type is supported, so we only need to check for blockRequestParams + if p, ok := a.(*BlockRequestAction); ok { + op.BlockedCode = &p.Status + return true + } + return false +} + +// BlockRequestAction is the struct used to perform the request blocking action +type BlockRequestAction struct { + // Status is the return code to use when blocking the request + Status codes.Code +} + +func (*BlockRequestAction) isAction() {} diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go b/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go index 8eb2c638f3..c581afd637 100644 --- a/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go +++ b/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go @@ -15,6 +15,8 @@ import ( "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation" + + "google.golang.org/grpc/codes" ) // Abstract gRPC server handler operation definitions. It is based on two @@ -39,12 +41,14 @@ type ( dyngo.Operation instrumentation.TagsHolder instrumentation.SecurityEventsHolder + BlockedCode *codes.Code } // HandlerOperationArgs is the grpc handler arguments. HandlerOperationArgs struct { // Message received by the gRPC handler. // Corresponds to the address `grpc.server.request.metadata`. Metadata map[string][]string + ClientIP instrumentation.NetaddrIP } // HandlerOperationRes is the grpc handler results. Empty as of today. HandlerOperationRes struct{} diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/tags.go b/internal/appsec/dyngo/instrumentation/grpcsec/tags.go index 845ab3a044..871e81c266 100644 --- a/internal/appsec/dyngo/instrumentation/grpcsec/tags.go +++ b/internal/appsec/dyngo/instrumentation/grpcsec/tags.go @@ -7,7 +7,6 @@ package grpcsec import ( "encoding/json" - "net" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation" @@ -17,28 +16,20 @@ import ( // SetSecurityEventTags sets the AppSec-specific span tags when a security event // occurred into the service entry span. -func SetSecurityEventTags(span ddtrace.Span, events []json.RawMessage, addr net.Addr, md map[string][]string) { - if err := setSecurityEventTags(span, events, addr, md); err != nil { +func SetSecurityEventTags(span ddtrace.Span, events []json.RawMessage, md map[string][]string) { + if err := setSecurityEventTags(span, events, md); err != nil { log.Error("appsec: %v", err) } } -func setSecurityEventTags(span ddtrace.Span, events []json.RawMessage, addr net.Addr, md map[string][]string) error { +func setSecurityEventTags(span ddtrace.Span, events []json.RawMessage, md map[string][]string) error { if err := instrumentation.SetEventSpanTags(span, events); err != nil { return err } - var ip string - switch actual := addr.(type) { - case *net.UDPAddr: - ip = actual.IP.String() - case *net.TCPAddr: - ip = actual.IP.String() - } - if ip != "" { - span.SetTag("network.client.ip", ip) - } + for h, v := range httpsec.NormalizeHTTPHeaders(md) { span.SetTag("grpc.metadata."+h, v) } + return nil } diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go b/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go index 5e0e21db69..8e34bdb802 100644 --- a/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go +++ b/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go @@ -12,9 +12,12 @@ import ( "testing" "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec" "gopkg.in/DataDog/dd-trace-go.v1/internal/samplernames" ) @@ -52,105 +55,116 @@ func TestSetSecurityEventTags(t *testing.T) { }, } { eventCase := eventCase - for _, addrCase := range []struct { - name string - addr net.Addr - expectedTag string + for _, metadataCase := range []struct { + name string + md map[string][]string + expectedTags map[string]string }{ { - name: "tcp-ipv4-address", - addr: &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789}, - expectedTag: "1.2.3.4", + name: "zero-metadata", }, { - name: "tcp-ipv6-address", - addr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 6789}, - expectedTag: "::1", - }, - { - name: "udp-ipv4-address", - addr: &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789}, - expectedTag: "1.2.3.4", + name: "xff-metadata", + md: map[string][]string{ + "x-forwarded-for": {"1.2.3.4", "4.5.6.7"}, + ":authority": {"something"}, + }, + expectedTags: map[string]string{ + "grpc.metadata.x-forwarded-for": "1.2.3.4,4.5.6.7", + }, }, { - name: "udp-ipv6-address", - addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 6789}, - expectedTag: "::1", + name: "xff-metadata", + md: map[string][]string{ + "x-forwarded-for": {"1.2.3.4"}, + ":authority": {"something"}, + }, + expectedTags: map[string]string{ + "grpc.metadata.x-forwarded-for": "1.2.3.4", + }, }, { - name: "unix-socket-address", - addr: &net.UnixAddr{Name: "/var/my.sock"}, + name: "no-monitored-metadata", + md: map[string][]string{ + ":authority": {"something"}, + }, }, } { - addrCase := addrCase - for _, metadataCase := range []struct { - name string - md map[string][]string - expectedTags map[string]string - }{ - { - name: "zero-metadata", - }, - { - name: "xff-metadata", - md: map[string][]string{ - "x-forwarded-for": {"1.2.3.4", "4.5.6.7"}, - ":authority": {"something"}, - }, - expectedTags: map[string]string{ - "grpc.metadata.x-forwarded-for": "1.2.3.4,4.5.6.7", - }, - }, - { - name: "xff-metadata", - md: map[string][]string{ - "x-forwarded-for": {"1.2.3.4"}, - ":authority": {"something"}, - }, - expectedTags: map[string]string{ - "grpc.metadata.x-forwarded-for": "1.2.3.4", - }, - }, - { - name: "no-monitored-metadata", - md: map[string][]string{ - ":authority": {"something"}, - }, - }, - } { - metadataCase := metadataCase - t.Run(fmt.Sprintf("%s-%s-%s", eventCase.name, addrCase.name, metadataCase.name), func(t *testing.T) { - var span MockSpan - err := setSecurityEventTags(&span, eventCase.events, addrCase.addr, metadataCase.md) - if eventCase.expectedError { - require.Error(t, err) - return - } - require.NoError(t, err) - - expectedTags := map[string]interface{}{ - "_dd.appsec.json": eventCase.expectedTag, - "manual.keep": true, - "appsec.event": true, - "_dd.origin": "appsec", - } - - if addr := addrCase.expectedTag; addr != "" { - expectedTags["network.client.ip"] = addr - } - - for k, v := range metadataCase.expectedTags { - expectedTags[k] = v - } - - require.Equal(t, expectedTags, span.tags) - require.False(t, span.finished) - }) - } + metadataCase := metadataCase + t.Run(fmt.Sprintf("%s-%s", eventCase.name, metadataCase.name), func(t *testing.T) { + var span MockSpan + err := setSecurityEventTags(&span, eventCase.events, metadataCase.md) + if eventCase.expectedError { + require.Error(t, err) + return + } + require.NoError(t, err) + + expectedTags := map[string]interface{}{ + "_dd.appsec.json": eventCase.expectedTag, + "manual.keep": true, + "appsec.event": true, + "_dd.origin": "appsec", + } + + for k, v := range metadataCase.expectedTags { + expectedTags[k] = v + } + + require.Equal(t, expectedTags, span.tags) + require.False(t, span.finished) + }) } } } +func TestClientIP(t *testing.T) { + for _, tc := range []struct { + name string + addr net.Addr + md metadata.MD + expectedClientIP string + }{ + { + name: "tcp-ipv4-address", + addr: &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789}, + expectedClientIP: "1.2.3.4", + }, + { + name: "tcp-ipv4-address", + addr: &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789}, + md: map[string][]string{"x-client-ip": {"127.0.0.1, 2.3.4.5"}}, + expectedClientIP: "2.3.4.5", + }, + { + name: "tcp-ipv6-address", + addr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 6789}, + expectedClientIP: "::1", + }, + { + name: "udp-ipv4-address", + addr: &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789}, + expectedClientIP: "1.2.3.4", + }, + { + name: "udp-ipv6-address", + addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 6789}, + expectedClientIP: "::1", + }, + { + name: "unix-socket-address", + addr: &net.UnixAddr{Name: "/var/my.sock"}, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + _, clientIP := httpsec.ClientIPTags(tc.md, tc.addr.String()) + expectedClientIP, _ := instrumentation.NetaddrParseIP(tc.expectedClientIP) + require.Equal(t, expectedClientIP.String(), clientIP.String()) + }) + } +} + type MockSpan struct { tags map[string]interface{} finished bool diff --git a/internal/appsec/dyngo/instrumentation/httpsec/actions.go b/internal/appsec/dyngo/instrumentation/httpsec/actions.go new file mode 100644 index 0000000000..c6d2925071 --- /dev/null +++ b/internal/appsec/dyngo/instrumentation/httpsec/actions.go @@ -0,0 +1,112 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2022 Datadog, Inc. + +package httpsec + +import ( + "net/http" + "strings" + "sync" + + "gopkg.in/DataDog/dd-trace-go.v1/internal/log" +) + +// Action is used to identify any action kind +type Action interface { + isAction() +} + +// BlockRequestAction is the action that holds the HTTP handler to use to block the request +type BlockRequestAction struct { + // handler is the http handler to use to block the request + handler http.Handler +} + +func (*BlockRequestAction) isAction() {} + +// NewBlockRequestAction creates, initializes and returns a new BlockRequestAction +func NewBlockRequestAction(status int, template string) BlockRequestAction { + htmlHandler := newBlockRequestHandler(status, "text/html", blockedTemplateHTML) + jsonHandler := newBlockRequestHandler(status, "application/json", blockedTemplateJSON) + var action BlockRequestAction + switch template { + case "json": + action.handler = jsonHandler + break + case "html": + action.handler = htmlHandler + break + default: + action.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := jsonHandler + hdr := r.Header.Get("Accept") + htmlIdx := strings.Index(hdr, "text/html") + jsonIdx := strings.Index(hdr, "application/json") + // Switch to html handler if text/html comes before application/json in the Accept header + if htmlIdx != -1 && (jsonIdx == -1 || htmlIdx < jsonIdx) { + h = htmlHandler + } + h.ServeHTTP(w, r) + }) + break + } + return action + +} + +func newBlockRequestHandler(status int, ct string, payload []byte) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", ct) + w.WriteHeader(status) + w.Write(payload) + }) +} + +// ActionsHandler handles actions registration and their application to operations +type ActionsHandler struct { + mu sync.RWMutex + actions map[string]Action +} + +// NewActionsHandler returns an action handler holding the default ASM actions. +// Currently, only the default "block" action is supported +func NewActionsHandler() *ActionsHandler { + handler := ActionsHandler{ + actions: map[string]Action{}, + } + // Register the default "block" action as specified in the RFC for HTTP blocking + block := NewBlockRequestAction(403, "auto") + handler.RegisterAction("block", &block) + + return &handler +} + +// RegisterAction registers a specific action to the handler. If the action kind is unknown +// the action will not be registered +func (h *ActionsHandler) RegisterAction(id string, a Action) { + h.mu.Lock() + defer h.mu.Unlock() + h.actions[id] = a +} + +// Apply applies the action identified by `id` for the given operation +// Returns true if the applied action will interrupt the request flow (block, redirect, etc...) +func (h *ActionsHandler) Apply(id string, op *Operation) bool { + h.mu.RLock() + a, ok := h.actions[id] + h.mu.RUnlock() + if !ok { + log.Debug("appsec: ignoring the returned waf action: unknown action id `%s`", id) + return false + } + op.AddAction(a) + + switch a.(type) { + case *BlockRequestAction: + return true + default: + return false + } +} diff --git a/internal/appsec/dyngo/instrumentation/httpsec/actions_test.go b/internal/appsec/dyngo/instrumentation/httpsec/actions_test.go new file mode 100644 index 0000000000..5f48a23df2 --- /dev/null +++ b/internal/appsec/dyngo/instrumentation/httpsec/actions_test.go @@ -0,0 +1,155 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package httpsec + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewBlockRequestAction(t *testing.T) { + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + mux.HandleFunc("/json", NewBlockRequestAction(403, "json").handler.ServeHTTP) + mux.HandleFunc("/html", NewBlockRequestAction(403, "html").handler.ServeHTTP) + mux.HandleFunc("/auto", NewBlockRequestAction(403, "auto").handler.ServeHTTP) + defer srv.Close() + + t.Run("json", func(t *testing.T) { + for _, tc := range []struct { + name string + accept string + }{ + { + name: "no-accept", + }, + { + name: "irrelevant-accept", + accept: "text/html", + }, + { + name: "accept", + accept: "application/json", + }, + } { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("POST", srv.URL+"/json", nil) + req.Header.Set("Accept", tc.accept) + require.NoError(t, err) + res, err := srv.Client().Do(req) + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.Equal(t, 403, res.StatusCode) + require.Equal(t, blockedTemplateJSON, body) + }) + } + }) + + t.Run("html", func(t *testing.T) { + for _, tc := range []struct { + name string + accept string + }{ + { + name: "no-accept", + }, + { + name: "irrelevant-accept", + accept: "application/json", + }, + { + name: "accept", + accept: "text/html", + }, + } { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("POST", srv.URL+"/html", nil) + require.NoError(t, err) + res, err := srv.Client().Do(req) + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.Equal(t, 403, res.StatusCode) + require.Equal(t, blockedTemplateHTML, body) + }) + } + }) + + t.Run("auto", func(t *testing.T) { + for _, tc := range []struct { + name string + accept string + expected []byte + }{ + { + name: "no-accept", + expected: blockedTemplateJSON, + }, + { + name: "json-accept-1", + accept: "application/json", + expected: blockedTemplateJSON, + }, + { + name: "json-accept-2", + accept: "application/json,text/html", + expected: blockedTemplateJSON, + }, + { + name: "json-accept-3", + accept: "irrelevant/content,application/json,text/html", + expected: blockedTemplateJSON, + }, + { + name: "json-accept-4", + accept: "irrelevant/content,application/json,text/html,application/json", + expected: blockedTemplateJSON, + }, + { + name: "html-accept-1", + accept: "text/html", + expected: blockedTemplateHTML, + }, + { + name: "html-accept-2", + accept: "text/html,application/json", + expected: blockedTemplateHTML, + }, + { + name: "html-accept-3", + accept: "irrelevant/content,text/html,application/json", + expected: blockedTemplateHTML, + }, + { + name: "html-accept-4", + accept: "irrelevant/content,text/html,application/json,text/html", + expected: blockedTemplateHTML, + }, + { + name: "irrelevant-accept", + accept: "irrelevant/irrelevant,application/html", + expected: blockedTemplateJSON, + }, + } { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("POST", srv.URL+"/auto", nil) + req.Header.Set("Accept", tc.accept) + require.NoError(t, err) + res, err := srv.Client().Do(req) + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.Equal(t, 403, res.StatusCode) + require.Equal(t, tc.expected, body) + }) + } + }) +} diff --git a/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.html b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.html new file mode 100644 index 0000000000..8c48babc80 --- /dev/null +++ b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.html @@ -0,0 +1 @@ + You've been blocked

Sorry, you cannot access this page. Please contact the customer service team.

diff --git a/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.json b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.json new file mode 100644 index 0000000000..bbcafb6cb1 --- /dev/null +++ b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.json @@ -0,0 +1,8 @@ +{ + "errors": [ + { + "title": "You've been blocked", + "detail": "Sorry, you cannot access this page. Please contact the customer service team. Security provided by Datadog." + } + ] +} diff --git a/internal/appsec/dyngo/instrumentation/httpsec/http.go b/internal/appsec/dyngo/instrumentation/httpsec/http.go index b070065eeb..17606dee73 100644 --- a/internal/appsec/dyngo/instrumentation/httpsec/http.go +++ b/internal/appsec/dyngo/instrumentation/httpsec/http.go @@ -12,11 +12,14 @@ package httpsec import ( "context" + // Blank import needed to use embed for the default blocked response payloads + _ "embed" "encoding/json" - "net" "net/http" + "os" "reflect" "strings" + "sync" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo" @@ -38,6 +41,8 @@ type ( Query map[string][]string // PathParams corresponds to the address `server.request.path_params` PathParams map[string]string + // ClientIP corresponds to the addres `http.client_ip` + ClientIP instrumentation.NetaddrIP } // HandlerOperationRes is the HTTP handler operation results. @@ -68,17 +73,36 @@ func MonitorParsedBody(ctx context.Context, body interface{}) { } } +// applyActions executes the operation's actions and returns the resulting http handler +func applyActions(op *Operation) http.Handler { + defer op.ClearActions() + for _, action := range op.Actions() { + switch a := action.(type) { + case *BlockRequestAction: + op.AddTag(BlockedRequestTag, true) + return a.handler + default: + log.Error("appsec: ignoring security action: unexpected action type %T", a) + } + } + return nil +} + // WrapHandler wraps the given HTTP handler with the abstract HTTP operation defined by HandlerOperationArgs and // HandlerOperationRes. func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string]string) http.Handler { instrumentation.SetAppSecEnabledTags(span) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - SetIPTags(span, r) + ipTags, clientIP := ClientIPTags(r.Header, r.RemoteAddr) + instrumentation.SetStringTags(span, ipTags) - args := MakeHandlerOperationArgs(r, pathParams) + args := MakeHandlerOperationArgs(r, clientIP, pathParams) ctx, op := StartOperation(r.Context(), args) r = r.WithContext(ctx) + + if h := applyActions(op); h != nil { + handler = h + } defer func() { var status int if mw, ok := w.(interface{ Status() int }); ok { @@ -91,21 +115,19 @@ func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string] return } - remoteIP, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - remoteIP = r.RemoteAddr - } - SetSecurityEventTags(span, events, remoteIP, args.Headers, w.Header()) + applyActions(op) + SetSecurityEventTags(span, events, args.Headers, w.Header()) }() handler.ServeHTTP(w, r) + }) } // MakeHandlerOperationArgs creates the HandlerOperationArgs out of a standard // http.Request along with the given current span. It returns an empty structure // when appsec is disabled. -func MakeHandlerOperationArgs(r *http.Request, pathParams map[string]string) HandlerOperationArgs { +func MakeHandlerOperationArgs(r *http.Request, clientIP instrumentation.NetaddrIP, pathParams map[string]string) HandlerOperationArgs { headers := make(http.Header, len(r.Header)) for k, v := range r.Header { k := strings.ToLower(k) @@ -123,6 +145,7 @@ func MakeHandlerOperationArgs(r *http.Request, pathParams map[string]string) Han Cookies: cookies, Query: r.URL.Query(), // TODO(Julio-Guerra): avoid actively parsing the query values thanks to dynamic instrumentation PathParams: pathParams, + ClientIP: clientIP, } } @@ -149,6 +172,8 @@ type ( dyngo.Operation instrumentation.TagsHolder instrumentation.SecurityEventsHolder + mu sync.RWMutex + actions []Action } // SDKBodyOperation type representing an SDK body. It must be created with @@ -187,6 +212,27 @@ func (op *Operation) Finish(res HandlerOperationRes) []json.RawMessage { return op.Events() } +// Actions returns the actions linked to the operation +func (op *Operation) Actions() []Action { + op.mu.RLock() + defer op.mu.RUnlock() + return op.actions +} + +// AddAction adds an action to the operation +func (op *Operation) AddAction(a Action) { + op.mu.Lock() + defer op.mu.Unlock() + op.actions = append(op.actions, a) +} + +// ClearActions clears all the actions linked to the operation +func (op *Operation) ClearActions() { + op.mu.Lock() + defer op.mu.Unlock() + op.actions = op.actions[0:0] +} + // StartSDKBodyOperation starts the SDKBody operation and emits a start event func StartSDKBodyOperation(parent *Operation, args SDKBodyOperationArgs) *SDKBodyOperation { op := &SDKBodyOperation{Operation: dyngo.NewOperation(parent)} @@ -261,3 +307,31 @@ func (OnSDKBodyOperationFinish) ListenedType() reflect.Type { return sdkBodyOper func (f OnSDKBodyOperationFinish) Call(op dyngo.Operation, v interface{}) { f(op.(*SDKBodyOperation), v.(SDKBodyOperationRes)) } + +// blockedTemplateJSON is the default JSON template used to write responses for blocked requests +// +//go:embed blocked-template.json +var blockedTemplateJSON []byte + +// blockedTemplateHTML is the default HTML template used to write responses for blocked requests +// +//go:embed blocked-template.html +var blockedTemplateHTML []byte + +const ( + envBlockedTemplateHTML = "DD_APPSEC_HTTP_BLOCKED_TEMPLATE_HTML" + envBlockedTemplateJSON = "DD_APPSEC_HTTP_BLOCKED_TEMPLATE_JSON" +) + +func init() { + for env, template := range map[string]*[]byte{envBlockedTemplateJSON: &blockedTemplateJSON, envBlockedTemplateHTML: &blockedTemplateHTML} { + if path, ok := os.LookupEnv(env); ok { + if t, err := os.ReadFile(path); err != nil { + log.Warn("Could not read template at %s: %v", path, err) + } else { + *template = t + } + } + + } +} diff --git a/internal/appsec/dyngo/instrumentation/httpsec/ip_default.go b/internal/appsec/dyngo/instrumentation/httpsec/ip_default.go deleted file mode 100644 index a96f74ef02..0000000000 --- a/internal/appsec/dyngo/instrumentation/httpsec/ip_default.go +++ /dev/null @@ -1,22 +0,0 @@ -// Unless explicitly stated otherwise all files in this repository are licensed -// under the Apache License Version 2.0. -// This product includes software developed at Datadog (https://www.datadoghq.com/). -// Copyright 2022 Datadog, Inc. - -//go:build !go1.19 -// +build !go1.19 - -package httpsec - -import "inet.af/netaddr" - -type netaddrIP = netaddr.IP -type netaddrIPPrefix = netaddr.IPPrefix - -var ( - netaddrParseIP = netaddr.ParseIP - netaddrParseIPPrefix = netaddr.ParseIPPrefix - netaddrMustParseIP = netaddr.MustParseIP - netaddrIPv4 = netaddr.IPv4 - netaddrIPv6Raw = netaddr.IPv6Raw -) diff --git a/internal/appsec/dyngo/instrumentation/httpsec/ip_go119.go b/internal/appsec/dyngo/instrumentation/httpsec/ip_go119.go deleted file mode 100644 index dbea4b60d3..0000000000 --- a/internal/appsec/dyngo/instrumentation/httpsec/ip_go119.go +++ /dev/null @@ -1,26 +0,0 @@ -// Unless explicitly stated otherwise all files in this repository are licensed -// under the Apache License Version 2.0. -// This product includes software developed at Datadog (https://www.datadoghq.com/). -// Copyright 2022 Datadog, Inc. - -//go:build go1.19 -// +build go1.19 - -package httpsec - -import "net/netip" - -type netaddrIP = netip.Addr -type netaddrIPPrefix = netip.Prefix - -var ( - netaddrParseIP = netip.ParseAddr - netaddrParseIPPrefix = netip.ParsePrefix - netaddrMustParseIP = netip.MustParseAddr - netaddrIPv6Raw = netip.AddrFrom16 -) - -func netaddrIPv4(a, b, c, d byte) netaddrIP { - e := [4]byte{a, b, c, d} - return netip.AddrFrom4(e) -} diff --git a/internal/appsec/dyngo/instrumentation/httpsec/tags.go b/internal/appsec/dyngo/instrumentation/httpsec/tags.go index faf300d950..5f07f140a7 100644 --- a/internal/appsec/dyngo/instrumentation/httpsec/tags.go +++ b/internal/appsec/dyngo/instrumentation/httpsec/tags.go @@ -8,7 +8,6 @@ package httpsec import ( "encoding/json" "net" - "net/http" "os" "sort" "strings" @@ -21,15 +20,21 @@ import ( const ( // envClientIPHeader is the name of the env var used to specify the IP header to be used for client IP collection. envClientIPHeader = "DD_TRACE_CLIENT_IP_HEADER" - // multipleIPHeaders sets the multiple ip header tag used internally to tell the backend an error occurred when + + // multipleIPHeadersTag sets the multiple ip header tag used internally to tell the backend an error occurred when // retrieving an HTTP request client IP. - multipleIPHeaders = "_dd.multiple-ip-headers" + multipleIPHeadersTag = "_dd.multiple-ip-headers" + + // BlockedRequestTag used to convey whether a request is blocked + BlockedRequestTag = "appsec.blocked" ) var ( - ipv6SpecialNetworks = []*netaddrIPPrefix{ + ipv6SpecialNetworks = []*instrumentation.NetaddrIPPrefix{ ippref("fec0::/10"), // site local } + + // List of IP-related headers leveraged to retrieve the public client IP address. defaultIPHeaders = []string{ "x-forwarded-for", "x-real-ip", @@ -41,6 +46,7 @@ var ( "via", "true-client-ip", } + // List of HTTP headers we collect and send. collectedHTTPHeaders = append(defaultIPHeaders, "host", @@ -53,21 +59,22 @@ var ( "accept", "accept-encoding", "accept-language") - clientIPHeader string + + clientIPHeaderCfg string ) func init() { // Required by sort.SearchStrings + sort.Strings(defaultIPHeaders[:]) sort.Strings(collectedHTTPHeaders[:]) - clientIPHeader = os.Getenv(envClientIPHeader) + clientIPHeaderCfg = os.Getenv(envClientIPHeader) } // SetSecurityEventTags sets the AppSec-specific span tags when a security event occurred into the service entry span. -func SetSecurityEventTags(span instrumentation.TagSetter, events []json.RawMessage, remoteIP string, headers, respHeaders map[string][]string) { +func SetSecurityEventTags(span instrumentation.TagSetter, events []json.RawMessage, headers, respHeaders map[string][]string) { if err := instrumentation.SetEventSpanTags(span, events); err != nil { log.Error("appsec: unexpected error while creating the appsec event tags: %v", err) } - span.SetTag("network.client.ip", remoteIP) for h, v := range NormalizeHTTPHeaders(headers) { span.SetTag("http.request.headers."+h, v) } @@ -96,69 +103,106 @@ func NormalizeHTTPHeaders(headers map[string][]string) (normalized map[string]st } // ippref returns the IP network from an IP address string s. If not possible, it returns nil. -func ippref(s string) *netaddrIPPrefix { - if prefix, err := netaddrParseIPPrefix(s); err == nil { +func ippref(s string) *instrumentation.NetaddrIPPrefix { + if prefix, err := instrumentation.NetaddrParseIPPrefix(s); err == nil { return &prefix } return nil } -// SetIPTags sets the IP related span tags for a given request -// See https://docs.datadoghq.com/tracing/configure_data_security#configuring-a-client-ip-header for more information. -func SetIPTags(span instrumentation.TagSetter, r *http.Request) { - ipHeaders := defaultIPHeaders - if len(clientIPHeader) > 0 { - ipHeaders = []string{clientIPHeader} - } - - var ( - headers []string - ips []string - ) - for _, hdr := range ipHeaders { - if v := r.Header.Get(hdr); v != "" { - headers = append(headers, hdr) - ips = append(ips, v) +// ClientIPTags generates the IP related span tags for a given request headers +func ClientIPTags(hdrs map[string][]string, remoteAddr string) (tags map[string]string, clientIP instrumentation.NetaddrIP) { + tags = map[string]string{} + monitoredHeaders := defaultIPHeaders + if clientIPHeaderCfg != "" { + monitoredHeaders = []string{clientIPHeaderCfg} + } + + // Filter the list of headers + foundHeaders := map[string][]string{} + for k, v := range hdrs { + k = strings.ToLower(k) + if i := sort.SearchStrings(monitoredHeaders, k); i < len(monitoredHeaders) && monitoredHeaders[i] == k { + if len(v) >= 1 && v[0] != "" { + foundHeaders[k] = v + } + } + } + + // If more than one IP header is present, report them and don't return any client ip + if len(foundHeaders) > 1 { + var headers []string + for header, ips := range foundHeaders { + tags[ext.HTTPRequestHeaders+"."+header] = strings.Join(ips, ",") + headers = append(headers, header) } + sort.Strings(headers) // produce a predictable value + tags[multipleIPHeadersTag] = strings.Join(headers, ",") + return tags, instrumentation.NetaddrIP{} } - if l := len(ips); l == 0 { - if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) { - span.SetTag(ext.HTTPClientIP, remoteIP.String()) + // Walk IP-related headers + var foundIP instrumentation.NetaddrIP + for _, v := range foundHeaders { + // Handle multi-value headers by flattening the list of values + var ips []string + for _, ip := range v { + ips = append(ips, strings.Split(ip, ",")...) } - } else if l == 1 { - for _, ipstr := range strings.Split(ips[0], ",") { + + // Look for the first valid or global IP address in the comma-separated list + for _, ipstr := range ips { ip := parseIP(strings.TrimSpace(ipstr)) - if ip.IsValid() && isGlobal(ip) { - span.SetTag(ext.HTTPClientIP, ip.String()) + if !ip.IsValid() { + continue + } + // Replace foundIP if still not valid in order to keep the oldest + if !foundIP.IsValid() { + foundIP = ip + } + if isGlobal(ip) { + foundIP = ip break } } - } else { - for i := range ips { - span.SetTag(ext.HTTPRequestHeaders+"."+headers[i], ips[i]) - } - span.SetTag(multipleIPHeaders, strings.Join(headers, ",")) } + + // Decide which IP address is the client one by starting with the remote IP + remoteIP := parseIP(remoteAddr) + if remoteIP.IsValid() { + tags["network.client.ip"] = remoteIP.String() + clientIP = remoteIP + } + + // The IP address found in the headers supersedes a private remote IP address. + if foundIP.IsValid() && !isGlobal(remoteIP) || isGlobal(foundIP) { + clientIP = foundIP + } + + if clientIP.IsValid() { + tags[ext.HTTPClientIP] = clientIP.String() + } + + return tags, clientIP } -func parseIP(s string) netaddrIP { - if ip, err := netaddrParseIP(s); err == nil { +func parseIP(s string) instrumentation.NetaddrIP { + if ip, err := instrumentation.NetaddrParseIP(s); err == nil { return ip } if h, _, err := net.SplitHostPort(s); err == nil { - if ip, err := netaddrParseIP(h); err == nil { + if ip, err := instrumentation.NetaddrParseIP(h); err == nil { return ip } } - return netaddrIP{} + return instrumentation.NetaddrIP{} } -func isGlobal(ip netaddrIP) bool { +func isGlobal(ip instrumentation.NetaddrIP) bool { // IsPrivate also checks for ipv6 ULA. // We care to check for these addresses are not considered public, hence not global. // See https://www.rfc-editor.org/rfc/rfc4193.txt for more details. - isGlobal := !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() + isGlobal := ip.IsValid() && !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast() if !isGlobal || !ip.Is6() { return isGlobal } diff --git a/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go b/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go index 1dfe92c6fe..2cb9bee406 100644 --- a/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go +++ b/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go @@ -11,6 +11,7 @@ import ( "testing" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation" "github.com/stretchr/testify/require" ) @@ -58,7 +59,7 @@ type ipTestCase struct { name string remoteAddr string headers map[string]string - expectedIP netaddrIP + expectedIP instrumentation.NetaddrIP multiHeaders string clientIPHeader string } @@ -68,170 +69,221 @@ func genIPTestCases() []ipTestCase { ipv6Global := randGlobalIPv6().String() ipv4Private := randPrivateIPv4().String() ipv6Private := randPrivateIPv6().String() - tcs := []ipTestCase{} + + tcs := []ipTestCase{ + { + name: "ipv4-global-remoteaddr", + remoteAddr: ipv4Global, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), + }, + { + name: "ipv4-private-remoteaddr", + remoteAddr: ipv4Private, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Private), + }, + { + name: "ipv6-global-remoteaddr", + remoteAddr: ipv6Global, + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), + }, + { + name: "ipv6-private-remoteaddr", + remoteAddr: ipv6Private, + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Private), + }, + } + // Simple ipv4 test cases over all headers for _, header := range defaultIPHeaders { - tcs = append(tcs, ipTestCase{ - name: "ipv4-global." + header, - headers: map[string]string{header: ipv4Global}, - expectedIP: netaddrMustParseIP(ipv4Global), - }) - tcs = append(tcs, ipTestCase{ - name: "ipv4-private." + header, - headers: map[string]string{header: ipv4Private}, - expectedIP: netaddrIP{}, - }) + tcs = append(tcs, + ipTestCase{ + name: "ipv4-global." + header, + remoteAddr: ipv4Private, + headers: map[string]string{header: ipv4Global}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), + }, + ipTestCase{ + name: "ipv4-private." + header, + headers: map[string]string{header: ipv4Private}, + remoteAddr: ipv6Private, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Private), + }, + ipTestCase{ + name: "ipv4-global-remoteaddr-local-ip-header." + header, + remoteAddr: ipv4Global, + headers: map[string]string{header: ipv4Private}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), + }, + ipTestCase{ + name: "ipv4-global-remoteaddr-global-ip-header." + header, + remoteAddr: ipv6Global, + headers: map[string]string{header: ipv4Global}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), + }) } + // Simple ipv6 test cases over all headers for _, header := range defaultIPHeaders { tcs = append(tcs, ipTestCase{ name: "ipv6-global." + header, + remoteAddr: ipv4Private, headers: map[string]string{header: ipv6Global}, - expectedIP: netaddrMustParseIP(ipv6Global), - }) - tcs = append(tcs, ipTestCase{ - name: "ipv6-private." + header, - headers: map[string]string{header: ipv6Private}, - expectedIP: netaddrIP{}, - }) + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), + }, + ipTestCase{ + name: "ipv6-private." + header, + headers: map[string]string{header: ipv6Private}, + remoteAddr: ipv4Private, + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Private), + }, + ipTestCase{ + name: "ipv6-global-remoteaddr-local-ip-header." + header, + remoteAddr: ipv6Global, + headers: map[string]string{header: ipv6Private}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), + }, + ipTestCase{ + name: "ipv6-global-remoteaddr-global-ip-header." + header, + remoteAddr: ipv4Global, + headers: map[string]string{header: ipv6Global}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), + }) } + // private and global in same header tcs = append([]ipTestCase{ { name: "ipv4-private+global", headers: map[string]string{"x-forwarded-for": ipv4Private + "," + ipv4Global}, - expectedIP: netaddrMustParseIP(ipv4Global), + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), }, { name: "ipv4-global+private", headers: map[string]string{"x-forwarded-for": ipv4Global + "," + ipv4Private}, - expectedIP: netaddrMustParseIP(ipv4Global), + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), }, { name: "ipv6-private+global", headers: map[string]string{"x-forwarded-for": ipv6Private + "," + ipv6Global}, - expectedIP: netaddrMustParseIP(ipv6Global), + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), }, { name: "ipv6-global+private", headers: map[string]string{"x-forwarded-for": ipv6Global + "," + ipv6Private}, - expectedIP: netaddrMustParseIP(ipv6Global), + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), + }, + { + name: "mixed-global+global", + headers: map[string]string{"x-forwarded-for": ipv4Private + "," + ipv6Global + "," + ipv4Global}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), + }, + { + name: "mixed-global+global", + headers: map[string]string{"x-forwarded-for": ipv4Private + "," + ipv4Global + "," + ipv6Global}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), }, }, tcs...) + // Invalid IPs (or a mix of valid/invalid over a single or multiple headers) tcs = append([]ipTestCase{ { name: "invalid-ipv4", headers: map[string]string{"x-forwarded-for": "127..0.0.1"}, - expectedIP: netaddrIP{}, + expectedIP: instrumentation.NetaddrIP{}, + }, + { + name: "invalid-ipv4-header-valid-remoteaddr", + headers: map[string]string{"x-forwarded-for": "127..0.0.1"}, + remoteAddr: ipv4Private, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Private), }, { name: "invalid-ipv4-recover", - headers: map[string]string{"x-forwarded-for": "127..0.0.1, " + ipv4Global}, - expectedIP: netaddrMustParseIP(ipv4Global), + headers: map[string]string{"x-forwarded-for": "127..0.0.1, " + ipv6Private + "," + ipv4Global}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), }, { name: "ipv4-multi-header-1", headers: map[string]string{"x-forwarded-for": "127.0.0.1", "forwarded-for": ipv4Global}, - expectedIP: netaddrIP{}, - multiHeaders: "x-forwarded-for,forwarded-for", + expectedIP: instrumentation.NetaddrIP{}, + multiHeaders: "forwarded-for,x-forwarded-for", }, { name: "ipv4-multi-header-2", headers: map[string]string{"forwarded-for": ipv4Global, "x-forwarded-for": "127.0.0.1"}, - expectedIP: netaddrIP{}, - multiHeaders: "x-forwarded-for,forwarded-for", + expectedIP: instrumentation.NetaddrIP{}, + multiHeaders: "forwarded-for,x-forwarded-for", }, { name: "invalid-ipv6", headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::"}, - expectedIP: netaddrIP{}, + expectedIP: instrumentation.NetaddrIP{}, }, { name: "invalid-ipv6-recover", headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::, " + ipv6Global}, - expectedIP: netaddrMustParseIP(ipv6Global), + expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global), }, { name: "ipv6-multi-header-1", headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::", "forwarded-for": ipv6Global}, - expectedIP: netaddrIP{}, - multiHeaders: "x-forwarded-for,forwarded-for", + expectedIP: instrumentation.NetaddrIP{}, + multiHeaders: "forwarded-for,x-forwarded-for", }, { name: "ipv6-multi-header-2", headers: map[string]string{"forwarded-for": ipv6Global, "x-forwarded-for": "2001:0db8:2001:zzzz::"}, - expectedIP: netaddrIP{}, - multiHeaders: "x-forwarded-for,forwarded-for", + expectedIP: instrumentation.NetaddrIP{}, + multiHeaders: "forwarded-for,x-forwarded-for", }, - }, tcs...) - tcs = append([]ipTestCase{ { name: "no-headers", - expectedIP: netaddrIP{}, + expectedIP: instrumentation.NetaddrIP{}, }, { name: "header-case", - expectedIP: netaddrMustParseIP(ipv4Global), headers: map[string]string{"X-fOrWaRdEd-FoR": ipv4Global}, + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), }, { name: "user-header", - expectedIP: netaddrMustParseIP(ipv4Global), headers: map[string]string{"x-forwarded-for": ipv6Global, "custom-header": ipv4Global}, clientIPHeader: "custom-header", + expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global), }, { name: "user-header-not-found", - expectedIP: netaddrIP{}, headers: map[string]string{"x-forwarded-for": ipv4Global}, clientIPHeader: "custom-header", + expectedIP: instrumentation.NetaddrIP{}, }, }, tcs...) return tcs } -type mockspan struct { - tags map[string]interface{} -} - -func (m *mockspan) SetTag(tag string, value interface{}) { - if m.tags == nil { - m.tags = make(map[string]interface{}) - } - m.tags[tag] = value -} - -func (m *mockspan) Tag(tag string) interface{} { - if m.tags == nil { - return nil - } - return m.tags[tag] -} - func TestIPHeaders(t *testing.T) { - // Make sure to restore the real value of clientIPHeader at the end of the test - defer func(s string) { clientIPHeader = s }(clientIPHeader) + // Make sure to restore the real value of clientIPHeaderCfg at the end of the test + defer func(s string) { clientIPHeaderCfg = s }(clientIPHeaderCfg) for _, tc := range genIPTestCases() { t.Run(tc.name, func(t *testing.T) { header := http.Header{} for k, v := range tc.headers { header.Add(k, v) } - r := http.Request{Header: header, RemoteAddr: tc.remoteAddr} - clientIPHeader = tc.clientIPHeader - var span mockspan - SetIPTags(&span, &r) + clientIPHeaderCfg = tc.clientIPHeader + tags, clientIP := ClientIPTags(header, tc.remoteAddr) if tc.expectedIP.IsValid() { - require.Equal(t, tc.expectedIP.String(), span.Tag(ext.HTTPClientIP)) - require.Nil(t, span.Tag(multipleIPHeaders)) + expectedIP := tc.expectedIP.String() + require.Equal(t, expectedIP, tags[ext.HTTPClientIP]) + require.Equal(t, expectedIP, clientIP.String()) + require.NotContains(t, tags, multipleIPHeadersTag) } else { - require.Nil(t, span.Tag(ext.HTTPClientIP)) + require.NotContains(t, tags, ext.HTTPClientIP) if tc.multiHeaders != "" { - require.Equal(t, tc.multiHeaders, span.Tag(multipleIPHeaders)) + require.Equal(t, tc.multiHeaders, tags[multipleIPHeadersTag]) for hdr, ip := range tc.headers { - require.Equal(t, ip, span.Tag(ext.HTTPRequestHeaders+"."+hdr)) + require.Equal(t, ip, tags[ext.HTTPRequestHeaders+"."+hdr]) } } } @@ -239,12 +291,12 @@ func TestIPHeaders(t *testing.T) { } } -func randIPv4() netaddrIP { - return netaddrIPv4(uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32())) +func randIPv4() instrumentation.NetaddrIP { + return instrumentation.NetaddrIPv4(uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32())) } -func randIPv6() netaddrIP { - return netaddrIPv6Raw([16]byte{ +func randIPv6() instrumentation.NetaddrIP { + return instrumentation.NetaddrIPv6Raw([16]byte{ uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), @@ -252,7 +304,7 @@ func randIPv6() netaddrIP { }) } -func randGlobalIPv4() netaddrIP { +func randGlobalIPv4() instrumentation.NetaddrIP { for { ip := randIPv4() if isGlobal(ip) { @@ -261,7 +313,7 @@ func randGlobalIPv4() netaddrIP { } } -func randGlobalIPv6() netaddrIP { +func randGlobalIPv6() instrumentation.NetaddrIP { for { ip := randIPv6() if isGlobal(ip) { @@ -270,7 +322,7 @@ func randGlobalIPv6() netaddrIP { } } -func randPrivateIPv4() netaddrIP { +func randPrivateIPv4() instrumentation.NetaddrIP { for { ip := randIPv4() if !isGlobal(ip) && ip.IsPrivate() { @@ -279,7 +331,7 @@ func randPrivateIPv4() netaddrIP { } } -func randPrivateIPv6() netaddrIP { +func randPrivateIPv6() instrumentation.NetaddrIP { for { ip := randIPv6() if !isGlobal(ip) && ip.IsPrivate() { diff --git a/internal/appsec/dyngo/instrumentation/ip_default.go b/internal/appsec/dyngo/instrumentation/ip_default.go new file mode 100644 index 0000000000..3fde0af771 --- /dev/null +++ b/internal/appsec/dyngo/instrumentation/ip_default.go @@ -0,0 +1,30 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2022 Datadog, Inc. + +//go:build !go1.19 +// +build !go1.19 + +package instrumentation + +import "inet.af/netaddr" + +// NetaddrIP wraps an netaddr.IP value +type NetaddrIP = netaddr.IP + +// NetaddrIPPrefix wraps an netaddr.IPPrefix value +type NetaddrIPPrefix = netaddr.IPPrefix + +var ( + // NetaddrParseIP wraps the netaddr.ParseIP function + NetaddrParseIP = netaddr.ParseIP + // NetaddrParseIPPrefix wraps the netaddr.ParseIPPrefix function + NetaddrParseIPPrefix = netaddr.ParseIPPrefix + // NetaddrMustParseIP wraps the netaddr.MustParseIP function + NetaddrMustParseIP = netaddr.MustParseIP + // NetaddrIPv4 wraps the netaddr.IPv4 function + NetaddrIPv4 = netaddr.IPv4 + // NetaddrIPv6Raw wraps the netaddr.IPv6Raw function + NetaddrIPv6Raw = netaddr.IPv6Raw +) diff --git a/internal/appsec/dyngo/instrumentation/ip_go119.go b/internal/appsec/dyngo/instrumentation/ip_go119.go new file mode 100644 index 0000000000..78b3b0e561 --- /dev/null +++ b/internal/appsec/dyngo/instrumentation/ip_go119.go @@ -0,0 +1,34 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2022 Datadog, Inc. + +//go:build go1.19 +// +build go1.19 + +package instrumentation + +import "net/netip" + +// NetaddrIP wraps a netip.Addr value +type NetaddrIP = netip.Addr + +// NetaddrIPPrefix wraps a netip.Prefix value +type NetaddrIPPrefix = netip.Prefix + +var ( + // NetaddrParseIP wraps the netip.ParseAddr function + NetaddrParseIP = netip.ParseAddr + // NetaddrParseIPPrefix wraps the netip.ParsePrefix function + NetaddrParseIPPrefix = netip.ParsePrefix + // NetaddrMustParseIP wraps the netip.MustParseAddr function + NetaddrMustParseIP = netip.MustParseAddr + // NetaddrIPv6Raw wraps the netIP.AddrFrom16 function + NetaddrIPv6Raw = netip.AddrFrom16 +) + +// NetaddrIPv4 wraps the netip.AddrFrom4 function +func NetaddrIPv4(a, b, c, d byte) NetaddrIP { + e := [4]byte{a, b, c, d} + return netip.AddrFrom4(e) +} diff --git a/internal/appsec/testdata/blocking.json b/internal/appsec/testdata/blocking.json new file mode 100644 index 0000000000..8b53f4caf4 --- /dev/null +++ b/internal/appsec/testdata/blocking.json @@ -0,0 +1,42 @@ +{ + "version": "2.2", + "metadata": { + "rules_version": "1.4.2" + }, + "rules": [ + { + "id": "blk-001-001", + "name": "Block IP Addresses", + "tags": { + "type": "block_ip", + "category": "security_response" + }, + "conditions": [ + { + "parameters": { + "inputs": [ + { + "address": "http.client_ip" + } + ], + "data": "blocked_ips" + }, + "operator": "ip_match" + } + ], + "transformers": [], + "on_match": [ + "block" + ] + } + ], + "rules_data": [ + { + "id": "blocked_ips", + "type": "ip_with_expiration", + "data": [ + { "value": "1.2.3.4" } + ] + } + ] +} diff --git a/internal/appsec/waf.go b/internal/appsec/waf.go index ea13e36874..851360a7a3 100644 --- a/internal/appsec/waf.go +++ b/internal/appsec/waf.go @@ -99,9 +99,37 @@ func (a *appsec) registerWAF() (unreg dyngo.UnregisterFunc, err error) { // newWAFEventListener returns the WAF event listener to register in order to enable it. func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout time.Duration, limiter Limiter) dyngo.EventListener { var monitorRulesOnce sync.Once // per instantiation + actionHandler := httpsec.NewActionsHandler() return httpsec.OnHandlerOperationStart(func(op *httpsec.Operation, args httpsec.HandlerOperationArgs) { var body interface{} + wafCtx := waf.NewContext(handle) + if wafCtx == nil { + // The WAF event listener got concurrently released + return + } + + values := map[string]interface{}{} + for _, addr := range addresses { + if addr == httpClientIPAddr && args.ClientIP.IsValid() { + values[httpClientIPAddr] = args.ClientIP.String() + } + } + // TODO: suspicious request blocking by moving here all the addresses available when the request begins + + matches, actionIds := runWAF(wafCtx, values, timeout) + if len(matches) > 0 { + interrupt := false + for _, id := range actionIds { + interrupt = actionHandler.Apply(id, op) || interrupt + } + op.AddSecurityEvents(matches) + log.Debug("appsec: WAF detected an attack before executing the request") + if interrupt { + wafCtx.Close() + return + } + } op.On(httpsec.OnSDKBodyOperationStart(func(op *httpsec.SDKBodyOperation, args httpsec.SDKBodyOperationArgs) { body = args.Body @@ -110,13 +138,7 @@ func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout tim // At the moment, AppSec doesn't block the requests, and so we can use the fact we are in monitoring-only mode // to call the WAF only once at the end of the handler operation. op.On(httpsec.OnHandlerOperationFinish(func(op *httpsec.Operation, res httpsec.HandlerOperationRes) { - wafCtx := waf.NewContext(handle) - if wafCtx == nil { - // The WAF event listener got concurrently released - return - } defer wafCtx.Close() - // Run the WAF on the rule addresses available in the request args values := make(map[string]interface{}, len(addresses)) for _, addr := range addresses { @@ -135,19 +157,21 @@ func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout tim if query := args.Query; query != nil { values[serverRequestQueryAddr] = query } - case serverRequestPathParams: + case serverRequestPathParamsAddr: if pathParams := args.PathParams; pathParams != nil { - values[serverRequestPathParams] = pathParams + values[serverRequestPathParamsAddr] = pathParams } - case serverRequestBody: + case serverRequestBodyAddr: if body != nil { - values[serverRequestBody] = body + values[serverRequestBodyAddr] = body } case serverResponseStatusAddr: values[serverResponseStatusAddr] = res.Status } } - matches := runWAF(wafCtx, values, timeout) + // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's + // response is not supported at the moment. + matches, _ := runWAF(wafCtx, values, timeout) // Add WAF metrics. rInfo := handle.RulesetInfo() @@ -174,8 +198,9 @@ func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout tim // newGRPCWAFEventListener returns the WAF event listener to register in order // to enable it. -func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Duration, limiter Limiter) dyngo.EventListener { +func newGRPCWAFEventListener(handle *waf.Handle, addresses []string, timeout time.Duration, limiter Limiter) dyngo.EventListener { var monitorRulesOnce sync.Once // per instantiation + actionHandler := grpcsec.NewActionsHandler() return grpcsec.OnHandlerOperationStart(func(op *grpcsec.HandlerOperation, handlerArgs grpcsec.HandlerOperationArgs) { // Limit the maximum number of security events, as a streaming RPC could @@ -192,6 +217,34 @@ func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Durati mu sync.Mutex // events mutex ) + wafCtx := waf.NewContext(handle) + if wafCtx == nil { + // The WAF event listener got concurrently released + return + } + defer wafCtx.Close() + + // The same address is used for gRPC and http when it comes to client ip + values := map[string]interface{}{} + for _, addr := range addresses { + if addr == httpClientIPAddr && handlerArgs.ClientIP.IsValid() { + values[httpClientIPAddr] = handlerArgs.ClientIP.String() + } + } + + matches, actionIds := runWAF(wafCtx, values, timeout) + if len(matches) > 0 { + interrupt := false + for _, id := range actionIds { + interrupt = actionHandler.Apply(id, op) || interrupt + } + op.AddSecurityEvents(matches) + log.Debug("appsec: WAF detected an attack before executing the request") + if interrupt { + return + } + } + op.On(grpcsec.OnReceiveOperationFinish(func(_ grpcsec.ReceiveOperation, res grpcsec.ReceiveOperationRes) { if atomic.LoadUint32(&nbEvents) == maxWAFEventsPerRequest { logOnce.Do(func() { @@ -221,7 +274,9 @@ func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Durati if md := handlerArgs.Metadata; len(md) > 0 { values[grpcServerRequestMetadata] = md } - event := runWAF(wafCtx, values, timeout) + // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's + // response is not supported at the moment. + event, _ := runWAF(wafCtx, values, timeout) // WAF run durations are WAF context bound. As of now we need to keep track of those externally since // we use a new WAF context for each callback. When we are able to re-use the same WAF context across @@ -259,17 +314,17 @@ func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Durati }) } -func runWAF(wafCtx *waf.Context, values map[string]interface{}, timeout time.Duration) []byte { - matches, _, err := wafCtx.Run(values, timeout) +func runWAF(wafCtx *waf.Context, values map[string]interface{}, timeout time.Duration) ([]byte, []string) { + matches, actions, err := wafCtx.Run(values, timeout) if err != nil { if err == waf.ErrTimeout { log.Debug("appsec: waf timeout value of %s reached", timeout) } else { log.Error("appsec: unexpected waf error: %v", err) - return nil + return nil, nil } } - return matches + return matches, actions } // HTTP rule addresses currently supported by the WAF @@ -278,9 +333,10 @@ const ( serverRequestHeadersNoCookiesAddr = "server.request.headers.no_cookies" serverRequestCookiesAddr = "server.request.cookies" serverRequestQueryAddr = "server.request.query" - serverRequestPathParams = "server.request.path_params" - serverRequestBody = "server.request.body" + serverRequestPathParamsAddr = "server.request.path_params" + serverRequestBodyAddr = "server.request.body" serverResponseStatusAddr = "server.response.status" + httpClientIPAddr = "http.client_ip" ) // List of HTTP rule addresses currently supported by the WAF @@ -289,9 +345,10 @@ var httpAddresses = []string{ serverRequestHeadersNoCookiesAddr, serverRequestCookiesAddr, serverRequestQueryAddr, - serverRequestPathParams, - serverRequestBody, + serverRequestPathParamsAddr, + serverRequestBodyAddr, serverResponseStatusAddr, + httpClientIPAddr, } // gRPC rule addresses currently supported by the WAF @@ -304,6 +361,7 @@ const ( var grpcAddresses = []string{ grpcServerRequestMessage, grpcServerRequestMetadata, + httpClientIPAddr, } func init() { @@ -317,11 +375,16 @@ func init() { func supportedAddresses(ruleAddresses []string) (supportedHTTP, supportedGRPC, notSupported []string) { // Filter the supported addresses only for _, addr := range ruleAddresses { + supported := false if i := sort.SearchStrings(httpAddresses, addr); i < len(httpAddresses) && httpAddresses[i] == addr { supportedHTTP = append(supportedHTTP, addr) - } else if i := sort.SearchStrings(grpcAddresses, addr); i < len(grpcAddresses) && grpcAddresses[i] == addr { + supported = true + } + if i := sort.SearchStrings(grpcAddresses, addr); i < len(grpcAddresses) && grpcAddresses[i] == addr { supportedGRPC = append(supportedGRPC, addr) - } else { + supported = true + } + if !supported { notSupported = append(notSupported, addr) } } diff --git a/internal/appsec/waf_test.go b/internal/appsec/waf_test.go index dcf8bb4699..aee4923486 100644 --- a/internal/appsec/waf_test.go +++ b/internal/appsec/waf_test.go @@ -172,3 +172,67 @@ func TestWAF(t *testing.T) { require.NotContains(t, event, sensitivePayloadValue) }) } + +// Test that http blocking works by using custom rules/rules data +func TestBlocking(t *testing.T) { + t.Setenv("DD_APPSEC_RULES", "testdata/blocking.json") + + appsec.Start() + defer appsec.Stop() + if !appsec.Enabled() { + t.Skip("AppSec needs to be enabled for this test") + } + + // Start and trace an HTTP server + mux := httptrace.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello World!\n")) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + t.Run("block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + req, err := http.NewRequest("POST", srv.URL, nil) + if err != nil { + panic(err) + } + // Hardcoded IP header holding an IP that is blocked + req.Header.Set("x-forwarded-for", "1.2.3.4") + res, err := srv.Client().Do(req) + require.NoError(t, err) + + // Check that the request was blocked + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NotEqual(t, "Hello World!\n", string(b)) + require.Equal(t, 403, res.StatusCode) + }) + + t.Run("no-block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + req1, err := http.NewRequest("POST", srv.URL, nil) + if err != nil { + panic(err) + } + req2, err := http.NewRequest("POST", srv.URL, nil) + if err != nil { + panic(err) + } + req2.Header.Set("x-forwarded-for", "1.2.3.5") + + for _, r := range []*http.Request{req1, req2} { + res, err := srv.Client().Do(r) + require.NoError(t, err) + // Check that the request was not blocked + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + + } + }) +}