diff --git a/CODEOWNERS b/CODEOWNERS
index 7bfeb6ca6d..87ba38ad00 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -15,7 +15,7 @@
# appsec
/appsec @DataDog/appsec-go
/internal/appsec @DataDog/appsec-go
-/contrib/**/appsec.go @DataDog/appsec-go
+/contrib/**/*appsec*.go @DataDog/appsec-go
/.github/workflows/appsec.yml @DataDog/appsec-go
# telemetry
diff --git a/contrib/gin-gonic/gin/appsec.go b/contrib/gin-gonic/gin/appsec.go
index 27d85e74b9..1b0946343a 100644
--- a/contrib/gin-gonic/gin/appsec.go
+++ b/contrib/gin-gonic/gin/appsec.go
@@ -6,8 +6,6 @@
package gin
import (
- "net"
-
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec"
@@ -18,8 +16,8 @@ import (
// useAppSec executes the AppSec logic related to the operation start and
// returns the function to be executed upon finishing the operation
func useAppSec(c *gin.Context, span tracer.Span) func() {
- req := c.Request
instrumentation.SetAppSecEnabledTags(span)
+
var params map[string]string
if l := len(c.Params); l > 0 {
params = make(map[string]string, l)
@@ -27,18 +25,20 @@ func useAppSec(c *gin.Context, span tracer.Span) func() {
params[p.Key] = p.Value
}
}
- args := httpsec.MakeHandlerOperationArgs(req, params)
+
+ req := c.Request
+ ipTags, clientIP := httpsec.ClientIPTags(req.Header, req.RemoteAddr)
+ instrumentation.SetStringTags(span, ipTags)
+
+ args := httpsec.MakeHandlerOperationArgs(req, clientIP, params)
ctx, op := httpsec.StartOperation(req.Context(), args)
c.Request = req.WithContext(ctx)
+
return func() {
events := op.Finish(httpsec.HandlerOperationRes{Status: c.Writer.Status()})
+ instrumentation.SetTags(span, op.Tags())
if len(events) > 0 {
- remoteIP, _, err := net.SplitHostPort(req.RemoteAddr)
- if err != nil {
- remoteIP = req.RemoteAddr
- }
- httpsec.SetSecurityEventTags(span, events, remoteIP, args.Headers, c.Writer.Header())
+ httpsec.SetSecurityEventTags(span, events, args.Headers, c.Writer.Header())
}
- instrumentation.SetTags(span, op.Tags())
}
}
diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go
index 8006107f67..aa60298cf3 100644
--- a/contrib/google.golang.org/grpc/appsec.go
+++ b/contrib/google.golang.org/grpc/appsec.go
@@ -7,16 +7,17 @@ package grpc
import (
"encoding/json"
- "net"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/grpcsec"
+ "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
+ "google.golang.org/grpc/status"
)
// UnaryHandler wrapper to use when AppSec is enabled to monitor its execution.
@@ -24,15 +25,28 @@ func appsecUnaryHandlerMiddleware(span ddtrace.Span, handler grpc.UnaryHandler)
instrumentation.SetAppSecEnabledTags(span)
return func(ctx context.Context, req interface{}) (interface{}, error) {
md, _ := metadata.FromIncomingContext(ctx)
- op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md}, nil)
+ var remoteAddr string
+ if p, ok := peer.FromContext(ctx); ok {
+ remoteAddr = p.Addr.String()
+ }
+ ipTags, clientIP := httpsec.ClientIPTags(md, remoteAddr)
+ instrumentation.SetStringTags(span, ipTags)
+
+ op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md, ClientIP: clientIP}, nil)
defer func() {
events := op.Finish(grpcsec.HandlerOperationRes{})
instrumentation.SetTags(span, op.Tags())
if len(events) == 0 {
return
}
- setAppSecTags(ctx, span, events)
+ setAppSecEventsTags(ctx, span, events)
}()
+
+ if op.BlockedCode != nil {
+ op.AddTag(httpsec.BlockedRequestTag, true)
+ return nil, status.Errorf(*op.BlockedCode, "Request blocked")
+ }
+
defer grpcsec.StartReceiveOperation(grpcsec.ReceiveOperationArgs{}, op).Finish(grpcsec.ReceiveOperationRes{Message: req})
return handler(ctx, req)
}
@@ -43,15 +57,28 @@ func appsecStreamHandlerMiddleware(span ddtrace.Span, handler grpc.StreamHandler
instrumentation.SetAppSecEnabledTags(span)
return func(srv interface{}, stream grpc.ServerStream) error {
md, _ := metadata.FromIncomingContext(stream.Context())
- op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md}, nil)
+ var remoteAddr string
+ if p, ok := peer.FromContext(stream.Context()); ok {
+ remoteAddr = p.Addr.String()
+ }
+ ipTags, clientIP := httpsec.ClientIPTags(md, remoteAddr)
+ instrumentation.SetStringTags(span, ipTags)
+
+ op := grpcsec.StartHandlerOperation(grpcsec.HandlerOperationArgs{Metadata: md, ClientIP: clientIP}, nil)
defer func() {
events := op.Finish(grpcsec.HandlerOperationRes{})
instrumentation.SetTags(span, op.Tags())
if len(events) == 0 {
return
}
- setAppSecTags(stream.Context(), span, events)
+ setAppSecEventsTags(stream.Context(), span, events)
}()
+
+ if op.BlockedCode != nil {
+ op.AddTag(httpsec.BlockedRequestTag, true)
+ return status.Error(*op.BlockedCode, "Request blocked")
+ }
+
return handler(srv, appsecServerStream{ServerStream: stream, handlerOperation: op})
}
}
@@ -72,11 +99,7 @@ func (ss appsecServerStream) RecvMsg(m interface{}) error {
}
// Set the AppSec tags when security events were found.
-func setAppSecTags(ctx context.Context, span ddtrace.Span, events []json.RawMessage) {
+func setAppSecEventsTags(ctx context.Context, span ddtrace.Span, events []json.RawMessage) {
md, _ := metadata.FromIncomingContext(ctx)
- var addr net.Addr
- if p, ok := peer.FromContext(ctx); ok {
- addr = p.Addr
- }
- grpcsec.SetSecurityEventTags(span, events, addr, md)
+ grpcsec.SetSecurityEventTags(span, events, md)
}
diff --git a/contrib/google.golang.org/grpc/appsec_test.go b/contrib/google.golang.org/grpc/appsec_test.go
index 14862e7966..23063eee28 100644
--- a/contrib/google.golang.org/grpc/appsec_test.go
+++ b/contrib/google.golang.org/grpc/appsec_test.go
@@ -14,7 +14,9 @@ import (
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec"
"github.com/stretchr/testify/require"
+ "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
)
func TestAppSec(t *testing.T) {
@@ -94,3 +96,90 @@ func TestAppSec(t *testing.T) {
require.True(t, strings.Contains(event, "ua0-600-55x")) // canary rule attack attempt
})
}
+
+// Test that http blocking works by using custom rules/rules data
+func TestBlocking(t *testing.T) {
+ t.Setenv("DD_APPSEC_RULES", "../../../internal/appsec/testdata/blocking.json")
+ appsec.Start()
+ defer appsec.Stop()
+ if !appsec.Enabled() {
+ t.Skip("appsec disabled")
+ }
+
+ rig, err := newRig(false)
+ require.NoError(t, err)
+ defer rig.Close()
+
+ client := rig.client
+
+ t.Run("unary-block", func(t *testing.T) {
+ mt := mocktracer.Start()
+ defer mt.Stop()
+
+ // Send a XSS attack in the payload along with the canary value in the RPC metadata
+ ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4"))
+ reply, err := client.Ping(ctx, &FixtureRequest{Name: ""})
+
+ require.Nil(t, reply)
+ require.Equal(t, codes.Aborted, status.Code(err))
+
+ finished := mt.FinishedSpans()
+ require.Len(t, finished, 1)
+ // The request should have the attack attempts
+ event, _ := finished[0].Tag("_dd.appsec.json").(string)
+ require.NotNil(t, event)
+ require.True(t, strings.Contains(event, "blk-001-001"))
+ })
+
+ t.Run("unary-no-block", func(t *testing.T) {
+ mt := mocktracer.Start()
+ defer mt.Stop()
+
+ // Send a XSS attack in the payload along with the canary value in the RPC metadata
+ ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5"))
+ reply, err := client.Ping(ctx, &FixtureRequest{Name: ""})
+
+ require.Equal(t, "passed", reply.Message)
+ require.Equal(t, codes.OK, status.Code(err))
+ })
+
+ t.Run("stream-block", func(t *testing.T) {
+ mt := mocktracer.Start()
+ defer mt.Stop()
+
+ ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4"))
+ stream, err := client.StreamPing(ctx)
+ require.NoError(t, err)
+ reply, err := stream.Recv()
+
+ require.Equal(t, codes.Aborted, status.Code(err))
+ require.Nil(t, reply)
+
+ finished := mt.FinishedSpans()
+ require.Len(t, finished, 1)
+ // The request should have the attack attempts
+ event, _ := finished[0].Tag("_dd.appsec.json").(string)
+ require.NotNil(t, event)
+ require.True(t, strings.Contains(event, "blk-001-001"))
+ })
+
+ t.Run("stream-no-block", func(t *testing.T) {
+ mt := mocktracer.Start()
+ defer mt.Stop()
+
+ ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.5"))
+ stream, err := client.StreamPing(ctx)
+ require.NoError(t, err)
+
+ // Send a XSS attack
+ err = stream.Send(&FixtureRequest{Name: ""})
+ require.NoError(t, err)
+ reply, err := stream.Recv()
+ require.Equal(t, codes.OK, status.Code(err))
+ require.Equal(t, "passed", reply.Message)
+
+ err = stream.CloseSend()
+ require.NoError(t, err)
+ })
+
+}
diff --git a/contrib/labstack/echo.v4/appsec.go b/contrib/labstack/echo.v4/appsec.go
index e5dce33396..1b8b90dfd3 100644
--- a/contrib/labstack/echo.v4/appsec.go
+++ b/contrib/labstack/echo.v4/appsec.go
@@ -6,8 +6,6 @@
package echo
import (
- "net"
-
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec"
@@ -16,23 +14,24 @@ import (
)
func useAppSec(c echo.Context, span tracer.Span) func() {
- req := c.Request()
instrumentation.SetAppSecEnabledTags(span)
+
params := make(map[string]string)
for _, n := range c.ParamNames() {
params[n] = c.Param(n)
}
- args := httpsec.MakeHandlerOperationArgs(req, params)
+
+ req := c.Request()
+ ipTags, clientIP := httpsec.ClientIPTags(req.Header, req.RemoteAddr)
+ instrumentation.SetStringTags(span, ipTags)
+
+ args := httpsec.MakeHandlerOperationArgs(req, clientIP, params)
ctx, op := httpsec.StartOperation(req.Context(), args)
c.SetRequest(req.WithContext(ctx))
return func() {
events := op.Finish(httpsec.HandlerOperationRes{Status: c.Response().Status})
if len(events) > 0 {
- remoteIP, _, err := net.SplitHostPort(req.RemoteAddr)
- if err != nil {
- remoteIP = req.RemoteAddr
- }
- httpsec.SetSecurityEventTags(span, events, remoteIP, args.Headers, c.Response().Writer.Header())
+ httpsec.SetSecurityEventTags(span, events, args.Headers, c.Response().Writer.Header())
}
instrumentation.SetTags(span, op.Tags())
}
diff --git a/internal/appsec/dyngo/instrumentation/common.go b/internal/appsec/dyngo/instrumentation/common.go
index 9d0d66736e..8d2d06e5d9 100644
--- a/internal/appsec/dyngo/instrumentation/common.go
+++ b/internal/appsec/dyngo/instrumentation/common.go
@@ -49,7 +49,7 @@ func (m *TagsHolder) Tags() map[string]interface{} {
// See httpsec/http.go and grpcsec/grpc.go.
type SecurityEventsHolder struct {
events []json.RawMessage
- mu sync.Mutex
+ mu sync.RWMutex
}
// AddSecurityEvents adds the security events to the collected events list.
@@ -62,9 +62,18 @@ func (s *SecurityEventsHolder) AddSecurityEvents(events ...json.RawMessage) {
// Events returns the list of stored events.
func (s *SecurityEventsHolder) Events() []json.RawMessage {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
return s.events
}
+// ClearEvents clears the list of stored events
+func (s *SecurityEventsHolder) ClearEvents() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.events = s.events[0:0]
+}
+
// SetTags fills the span tags using the key/value pairs found in `tags`
func SetTags(span TagSetter, tags map[string]interface{}) {
for k, v := range tags {
@@ -72,6 +81,14 @@ func SetTags(span TagSetter, tags map[string]interface{}) {
}
}
+// SetStringTags fills the span tags using the key/value pairs of strings found
+// in `tags`
+func SetStringTags(span TagSetter, tags map[string]string) {
+ for k, v := range tags {
+ span.SetTag(k, v)
+ }
+}
+
// SetAppSecEnabledTags sets the AppSec-specific span tags that are expected to be in
// the web service entry span (span of type `web`) when AppSec is enabled.
func SetAppSecEnabledTags(span TagSetter) {
@@ -102,9 +119,11 @@ func SetEventSpanTags(span TagSetter, events []json.RawMessage) error {
// Create the value of the security event tag.
// TODO(Julio-Guerra): a future libddwaf version should return something
-// avoiding us the following events concatenation logic which currently
-// involves unserializing the top-level JSON arrays to concatenate them
-// together.
+//
+// avoiding us the following events concatenation logic which currently
+// involves unserializing the top-level JSON arrays to concatenate them
+// together.
+//
// TODO(Julio-Guerra): avoid serializing the json in the request hot path
func makeEventTagValue(events []json.RawMessage) (json.RawMessage, error) {
var v interface{}
diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/actions.go b/internal/appsec/dyngo/instrumentation/grpcsec/actions.go
new file mode 100644
index 0000000000..845ac5d3ca
--- /dev/null
+++ b/internal/appsec/dyngo/instrumentation/grpcsec/actions.go
@@ -0,0 +1,69 @@
+// Unless explicitly stated otherwise all files in this repository are licensed
+// under the Apache License Version 2.0.
+// This product includes software developed at Datadog (https://www.datadoghq.com/).
+// Copyright 2022 Datadog, Inc.
+
+//go:build appsec
+// +build appsec
+
+package grpcsec
+
+import (
+ "sync"
+
+ "google.golang.org/grpc/codes"
+)
+
+// Action is used to identify any action kind
+type Action interface {
+ isAction()
+}
+
+// ActionsHandler handles WAF actions registration and execution
+type ActionsHandler struct {
+ mu sync.RWMutex
+ actions map[string]Action
+}
+
+// NewActionsHandler returns an action handler holding the default ASM actions.
+// Currently, only the default "block" action is supported
+func NewActionsHandler() ActionsHandler {
+ // Register the default "block" action as specified in the blocking RFC
+ actions := map[string]Action{"block": &BlockRequestAction{Status: codes.Aborted}}
+
+ return ActionsHandler{
+ actions: actions,
+ }
+}
+
+// RegisterAction registers a specific action to the actions handler. If the action kind is unknown
+// the action will have no effect
+func (h *ActionsHandler) RegisterAction(id string, a Action) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ h.actions[id] = a
+}
+
+// Apply executes the action identified by `id`
+func (h *ActionsHandler) Apply(id string, op *HandlerOperation) bool {
+ h.mu.RLock()
+ a, ok := h.actions[id]
+ h.mu.RUnlock()
+ if !ok {
+ return false
+ }
+ // Currently, only the "block_request" type is supported, so we only need to check for blockRequestParams
+ if p, ok := a.(*BlockRequestAction); ok {
+ op.BlockedCode = &p.Status
+ return true
+ }
+ return false
+}
+
+// BlockRequestAction is the struct used to perform the request blocking action
+type BlockRequestAction struct {
+ // Status is the return code to use when blocking the request
+ Status codes.Code
+}
+
+func (*BlockRequestAction) isAction() {}
diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go b/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go
index 8eb2c638f3..c581afd637 100644
--- a/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go
+++ b/internal/appsec/dyngo/instrumentation/grpcsec/grpc.go
@@ -15,6 +15,8 @@ import (
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation"
+
+ "google.golang.org/grpc/codes"
)
// Abstract gRPC server handler operation definitions. It is based on two
@@ -39,12 +41,14 @@ type (
dyngo.Operation
instrumentation.TagsHolder
instrumentation.SecurityEventsHolder
+ BlockedCode *codes.Code
}
// HandlerOperationArgs is the grpc handler arguments.
HandlerOperationArgs struct {
// Message received by the gRPC handler.
// Corresponds to the address `grpc.server.request.metadata`.
Metadata map[string][]string
+ ClientIP instrumentation.NetaddrIP
}
// HandlerOperationRes is the grpc handler results. Empty as of today.
HandlerOperationRes struct{}
diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/tags.go b/internal/appsec/dyngo/instrumentation/grpcsec/tags.go
index 845ab3a044..871e81c266 100644
--- a/internal/appsec/dyngo/instrumentation/grpcsec/tags.go
+++ b/internal/appsec/dyngo/instrumentation/grpcsec/tags.go
@@ -7,7 +7,6 @@ package grpcsec
import (
"encoding/json"
- "net"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation"
@@ -17,28 +16,20 @@ import (
// SetSecurityEventTags sets the AppSec-specific span tags when a security event
// occurred into the service entry span.
-func SetSecurityEventTags(span ddtrace.Span, events []json.RawMessage, addr net.Addr, md map[string][]string) {
- if err := setSecurityEventTags(span, events, addr, md); err != nil {
+func SetSecurityEventTags(span ddtrace.Span, events []json.RawMessage, md map[string][]string) {
+ if err := setSecurityEventTags(span, events, md); err != nil {
log.Error("appsec: %v", err)
}
}
-func setSecurityEventTags(span ddtrace.Span, events []json.RawMessage, addr net.Addr, md map[string][]string) error {
+func setSecurityEventTags(span ddtrace.Span, events []json.RawMessage, md map[string][]string) error {
if err := instrumentation.SetEventSpanTags(span, events); err != nil {
return err
}
- var ip string
- switch actual := addr.(type) {
- case *net.UDPAddr:
- ip = actual.IP.String()
- case *net.TCPAddr:
- ip = actual.IP.String()
- }
- if ip != "" {
- span.SetTag("network.client.ip", ip)
- }
+
for h, v := range httpsec.NormalizeHTTPHeaders(md) {
span.SetTag("grpc.metadata."+h, v)
}
+
return nil
}
diff --git a/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go b/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go
index 5e0e21db69..8e34bdb802 100644
--- a/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go
+++ b/internal/appsec/dyngo/instrumentation/grpcsec/tags_test.go
@@ -12,9 +12,12 @@ import (
"testing"
"github.com/stretchr/testify/require"
+ "google.golang.org/grpc/metadata"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
+ "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation"
+ "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation/httpsec"
"gopkg.in/DataDog/dd-trace-go.v1/internal/samplernames"
)
@@ -52,105 +55,116 @@ func TestSetSecurityEventTags(t *testing.T) {
},
} {
eventCase := eventCase
- for _, addrCase := range []struct {
- name string
- addr net.Addr
- expectedTag string
+ for _, metadataCase := range []struct {
+ name string
+ md map[string][]string
+ expectedTags map[string]string
}{
{
- name: "tcp-ipv4-address",
- addr: &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789},
- expectedTag: "1.2.3.4",
+ name: "zero-metadata",
},
{
- name: "tcp-ipv6-address",
- addr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 6789},
- expectedTag: "::1",
- },
- {
- name: "udp-ipv4-address",
- addr: &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789},
- expectedTag: "1.2.3.4",
+ name: "xff-metadata",
+ md: map[string][]string{
+ "x-forwarded-for": {"1.2.3.4", "4.5.6.7"},
+ ":authority": {"something"},
+ },
+ expectedTags: map[string]string{
+ "grpc.metadata.x-forwarded-for": "1.2.3.4,4.5.6.7",
+ },
},
{
- name: "udp-ipv6-address",
- addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 6789},
- expectedTag: "::1",
+ name: "xff-metadata",
+ md: map[string][]string{
+ "x-forwarded-for": {"1.2.3.4"},
+ ":authority": {"something"},
+ },
+ expectedTags: map[string]string{
+ "grpc.metadata.x-forwarded-for": "1.2.3.4",
+ },
},
{
- name: "unix-socket-address",
- addr: &net.UnixAddr{Name: "/var/my.sock"},
+ name: "no-monitored-metadata",
+ md: map[string][]string{
+ ":authority": {"something"},
+ },
},
} {
- addrCase := addrCase
- for _, metadataCase := range []struct {
- name string
- md map[string][]string
- expectedTags map[string]string
- }{
- {
- name: "zero-metadata",
- },
- {
- name: "xff-metadata",
- md: map[string][]string{
- "x-forwarded-for": {"1.2.3.4", "4.5.6.7"},
- ":authority": {"something"},
- },
- expectedTags: map[string]string{
- "grpc.metadata.x-forwarded-for": "1.2.3.4,4.5.6.7",
- },
- },
- {
- name: "xff-metadata",
- md: map[string][]string{
- "x-forwarded-for": {"1.2.3.4"},
- ":authority": {"something"},
- },
- expectedTags: map[string]string{
- "grpc.metadata.x-forwarded-for": "1.2.3.4",
- },
- },
- {
- name: "no-monitored-metadata",
- md: map[string][]string{
- ":authority": {"something"},
- },
- },
- } {
- metadataCase := metadataCase
- t.Run(fmt.Sprintf("%s-%s-%s", eventCase.name, addrCase.name, metadataCase.name), func(t *testing.T) {
- var span MockSpan
- err := setSecurityEventTags(&span, eventCase.events, addrCase.addr, metadataCase.md)
- if eventCase.expectedError {
- require.Error(t, err)
- return
- }
- require.NoError(t, err)
-
- expectedTags := map[string]interface{}{
- "_dd.appsec.json": eventCase.expectedTag,
- "manual.keep": true,
- "appsec.event": true,
- "_dd.origin": "appsec",
- }
-
- if addr := addrCase.expectedTag; addr != "" {
- expectedTags["network.client.ip"] = addr
- }
-
- for k, v := range metadataCase.expectedTags {
- expectedTags[k] = v
- }
-
- require.Equal(t, expectedTags, span.tags)
- require.False(t, span.finished)
- })
- }
+ metadataCase := metadataCase
+ t.Run(fmt.Sprintf("%s-%s", eventCase.name, metadataCase.name), func(t *testing.T) {
+ var span MockSpan
+ err := setSecurityEventTags(&span, eventCase.events, metadataCase.md)
+ if eventCase.expectedError {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+
+ expectedTags := map[string]interface{}{
+ "_dd.appsec.json": eventCase.expectedTag,
+ "manual.keep": true,
+ "appsec.event": true,
+ "_dd.origin": "appsec",
+ }
+
+ for k, v := range metadataCase.expectedTags {
+ expectedTags[k] = v
+ }
+
+ require.Equal(t, expectedTags, span.tags)
+ require.False(t, span.finished)
+ })
}
}
}
+func TestClientIP(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ addr net.Addr
+ md metadata.MD
+ expectedClientIP string
+ }{
+ {
+ name: "tcp-ipv4-address",
+ addr: &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789},
+ expectedClientIP: "1.2.3.4",
+ },
+ {
+ name: "tcp-ipv4-address",
+ addr: &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789},
+ md: map[string][]string{"x-client-ip": {"127.0.0.1, 2.3.4.5"}},
+ expectedClientIP: "2.3.4.5",
+ },
+ {
+ name: "tcp-ipv6-address",
+ addr: &net.TCPAddr{IP: net.ParseIP("::1"), Port: 6789},
+ expectedClientIP: "::1",
+ },
+ {
+ name: "udp-ipv4-address",
+ addr: &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 6789},
+ expectedClientIP: "1.2.3.4",
+ },
+ {
+ name: "udp-ipv6-address",
+ addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 6789},
+ expectedClientIP: "::1",
+ },
+ {
+ name: "unix-socket-address",
+ addr: &net.UnixAddr{Name: "/var/my.sock"},
+ },
+ } {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ _, clientIP := httpsec.ClientIPTags(tc.md, tc.addr.String())
+ expectedClientIP, _ := instrumentation.NetaddrParseIP(tc.expectedClientIP)
+ require.Equal(t, expectedClientIP.String(), clientIP.String())
+ })
+ }
+}
+
type MockSpan struct {
tags map[string]interface{}
finished bool
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/actions.go b/internal/appsec/dyngo/instrumentation/httpsec/actions.go
new file mode 100644
index 0000000000..c6d2925071
--- /dev/null
+++ b/internal/appsec/dyngo/instrumentation/httpsec/actions.go
@@ -0,0 +1,112 @@
+// Unless explicitly stated otherwise all files in this repository are licensed
+// under the Apache License Version 2.0.
+// This product includes software developed at Datadog (https://www.datadoghq.com/).
+// Copyright 2022 Datadog, Inc.
+
+package httpsec
+
+import (
+ "net/http"
+ "strings"
+ "sync"
+
+ "gopkg.in/DataDog/dd-trace-go.v1/internal/log"
+)
+
+// Action is used to identify any action kind
+type Action interface {
+ isAction()
+}
+
+// BlockRequestAction is the action that holds the HTTP handler to use to block the request
+type BlockRequestAction struct {
+ // handler is the http handler to use to block the request
+ handler http.Handler
+}
+
+func (*BlockRequestAction) isAction() {}
+
+// NewBlockRequestAction creates, initializes and returns a new BlockRequestAction
+func NewBlockRequestAction(status int, template string) BlockRequestAction {
+ htmlHandler := newBlockRequestHandler(status, "text/html", blockedTemplateHTML)
+ jsonHandler := newBlockRequestHandler(status, "application/json", blockedTemplateJSON)
+ var action BlockRequestAction
+ switch template {
+ case "json":
+ action.handler = jsonHandler
+ break
+ case "html":
+ action.handler = htmlHandler
+ break
+ default:
+ action.handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ h := jsonHandler
+ hdr := r.Header.Get("Accept")
+ htmlIdx := strings.Index(hdr, "text/html")
+ jsonIdx := strings.Index(hdr, "application/json")
+ // Switch to html handler if text/html comes before application/json in the Accept header
+ if htmlIdx != -1 && (jsonIdx == -1 || htmlIdx < jsonIdx) {
+ h = htmlHandler
+ }
+ h.ServeHTTP(w, r)
+ })
+ break
+ }
+ return action
+
+}
+
+func newBlockRequestHandler(status int, ct string, payload []byte) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", ct)
+ w.WriteHeader(status)
+ w.Write(payload)
+ })
+}
+
+// ActionsHandler handles actions registration and their application to operations
+type ActionsHandler struct {
+ mu sync.RWMutex
+ actions map[string]Action
+}
+
+// NewActionsHandler returns an action handler holding the default ASM actions.
+// Currently, only the default "block" action is supported
+func NewActionsHandler() *ActionsHandler {
+ handler := ActionsHandler{
+ actions: map[string]Action{},
+ }
+ // Register the default "block" action as specified in the RFC for HTTP blocking
+ block := NewBlockRequestAction(403, "auto")
+ handler.RegisterAction("block", &block)
+
+ return &handler
+}
+
+// RegisterAction registers a specific action to the handler. If the action kind is unknown
+// the action will not be registered
+func (h *ActionsHandler) RegisterAction(id string, a Action) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ h.actions[id] = a
+}
+
+// Apply applies the action identified by `id` for the given operation
+// Returns true if the applied action will interrupt the request flow (block, redirect, etc...)
+func (h *ActionsHandler) Apply(id string, op *Operation) bool {
+ h.mu.RLock()
+ a, ok := h.actions[id]
+ h.mu.RUnlock()
+ if !ok {
+ log.Debug("appsec: ignoring the returned waf action: unknown action id `%s`", id)
+ return false
+ }
+ op.AddAction(a)
+
+ switch a.(type) {
+ case *BlockRequestAction:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/actions_test.go b/internal/appsec/dyngo/instrumentation/httpsec/actions_test.go
new file mode 100644
index 0000000000..5f48a23df2
--- /dev/null
+++ b/internal/appsec/dyngo/instrumentation/httpsec/actions_test.go
@@ -0,0 +1,155 @@
+// Unless explicitly stated otherwise all files in this repository are licensed
+// under the Apache License Version 2.0.
+// This product includes software developed at Datadog (https://www.datadoghq.com/).
+// Copyright 2016 Datadog, Inc.
+
+package httpsec
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewBlockRequestAction(t *testing.T) {
+ mux := http.NewServeMux()
+ srv := httptest.NewServer(mux)
+ mux.HandleFunc("/json", NewBlockRequestAction(403, "json").handler.ServeHTTP)
+ mux.HandleFunc("/html", NewBlockRequestAction(403, "html").handler.ServeHTTP)
+ mux.HandleFunc("/auto", NewBlockRequestAction(403, "auto").handler.ServeHTTP)
+ defer srv.Close()
+
+ t.Run("json", func(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ accept string
+ }{
+ {
+ name: "no-accept",
+ },
+ {
+ name: "irrelevant-accept",
+ accept: "text/html",
+ },
+ {
+ name: "accept",
+ accept: "application/json",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ req, err := http.NewRequest("POST", srv.URL+"/json", nil)
+ req.Header.Set("Accept", tc.accept)
+ require.NoError(t, err)
+ res, err := srv.Client().Do(req)
+ require.NoError(t, err)
+ defer res.Body.Close()
+ body, err := io.ReadAll(res.Body)
+ require.Equal(t, 403, res.StatusCode)
+ require.Equal(t, blockedTemplateJSON, body)
+ })
+ }
+ })
+
+ t.Run("html", func(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ accept string
+ }{
+ {
+ name: "no-accept",
+ },
+ {
+ name: "irrelevant-accept",
+ accept: "application/json",
+ },
+ {
+ name: "accept",
+ accept: "text/html",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ req, err := http.NewRequest("POST", srv.URL+"/html", nil)
+ require.NoError(t, err)
+ res, err := srv.Client().Do(req)
+ require.NoError(t, err)
+ defer res.Body.Close()
+ body, err := io.ReadAll(res.Body)
+ require.Equal(t, 403, res.StatusCode)
+ require.Equal(t, blockedTemplateHTML, body)
+ })
+ }
+ })
+
+ t.Run("auto", func(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ accept string
+ expected []byte
+ }{
+ {
+ name: "no-accept",
+ expected: blockedTemplateJSON,
+ },
+ {
+ name: "json-accept-1",
+ accept: "application/json",
+ expected: blockedTemplateJSON,
+ },
+ {
+ name: "json-accept-2",
+ accept: "application/json,text/html",
+ expected: blockedTemplateJSON,
+ },
+ {
+ name: "json-accept-3",
+ accept: "irrelevant/content,application/json,text/html",
+ expected: blockedTemplateJSON,
+ },
+ {
+ name: "json-accept-4",
+ accept: "irrelevant/content,application/json,text/html,application/json",
+ expected: blockedTemplateJSON,
+ },
+ {
+ name: "html-accept-1",
+ accept: "text/html",
+ expected: blockedTemplateHTML,
+ },
+ {
+ name: "html-accept-2",
+ accept: "text/html,application/json",
+ expected: blockedTemplateHTML,
+ },
+ {
+ name: "html-accept-3",
+ accept: "irrelevant/content,text/html,application/json",
+ expected: blockedTemplateHTML,
+ },
+ {
+ name: "html-accept-4",
+ accept: "irrelevant/content,text/html,application/json,text/html",
+ expected: blockedTemplateHTML,
+ },
+ {
+ name: "irrelevant-accept",
+ accept: "irrelevant/irrelevant,application/html",
+ expected: blockedTemplateJSON,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ req, err := http.NewRequest("POST", srv.URL+"/auto", nil)
+ req.Header.Set("Accept", tc.accept)
+ require.NoError(t, err)
+ res, err := srv.Client().Do(req)
+ require.NoError(t, err)
+ defer res.Body.Close()
+ body, err := io.ReadAll(res.Body)
+ require.Equal(t, 403, res.StatusCode)
+ require.Equal(t, tc.expected, body)
+ })
+ }
+ })
+}
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.html b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.html
new file mode 100644
index 0000000000..8c48babc80
--- /dev/null
+++ b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.html
@@ -0,0 +1 @@
+
You've been blocked Sorry, you cannot access this page. Please contact the customer service team.
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.json b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.json
new file mode 100644
index 0000000000..bbcafb6cb1
--- /dev/null
+++ b/internal/appsec/dyngo/instrumentation/httpsec/blocked-template.json
@@ -0,0 +1,8 @@
+{
+ "errors": [
+ {
+ "title": "You've been blocked",
+ "detail": "Sorry, you cannot access this page. Please contact the customer service team. Security provided by Datadog."
+ }
+ ]
+}
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/http.go b/internal/appsec/dyngo/instrumentation/httpsec/http.go
index b070065eeb..17606dee73 100644
--- a/internal/appsec/dyngo/instrumentation/httpsec/http.go
+++ b/internal/appsec/dyngo/instrumentation/httpsec/http.go
@@ -12,11 +12,14 @@ package httpsec
import (
"context"
+ // Blank import needed to use embed for the default blocked response payloads
+ _ "embed"
"encoding/json"
- "net"
"net/http"
+ "os"
"reflect"
"strings"
+ "sync"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo"
@@ -38,6 +41,8 @@ type (
Query map[string][]string
// PathParams corresponds to the address `server.request.path_params`
PathParams map[string]string
+ // ClientIP corresponds to the addres `http.client_ip`
+ ClientIP instrumentation.NetaddrIP
}
// HandlerOperationRes is the HTTP handler operation results.
@@ -68,17 +73,36 @@ func MonitorParsedBody(ctx context.Context, body interface{}) {
}
}
+// applyActions executes the operation's actions and returns the resulting http handler
+func applyActions(op *Operation) http.Handler {
+ defer op.ClearActions()
+ for _, action := range op.Actions() {
+ switch a := action.(type) {
+ case *BlockRequestAction:
+ op.AddTag(BlockedRequestTag, true)
+ return a.handler
+ default:
+ log.Error("appsec: ignoring security action: unexpected action type %T", a)
+ }
+ }
+ return nil
+}
+
// WrapHandler wraps the given HTTP handler with the abstract HTTP operation defined by HandlerOperationArgs and
// HandlerOperationRes.
func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string]string) http.Handler {
instrumentation.SetAppSecEnabledTags(span)
-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- SetIPTags(span, r)
+ ipTags, clientIP := ClientIPTags(r.Header, r.RemoteAddr)
+ instrumentation.SetStringTags(span, ipTags)
- args := MakeHandlerOperationArgs(r, pathParams)
+ args := MakeHandlerOperationArgs(r, clientIP, pathParams)
ctx, op := StartOperation(r.Context(), args)
r = r.WithContext(ctx)
+
+ if h := applyActions(op); h != nil {
+ handler = h
+ }
defer func() {
var status int
if mw, ok := w.(interface{ Status() int }); ok {
@@ -91,21 +115,19 @@ func WrapHandler(handler http.Handler, span ddtrace.Span, pathParams map[string]
return
}
- remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- remoteIP = r.RemoteAddr
- }
- SetSecurityEventTags(span, events, remoteIP, args.Headers, w.Header())
+ applyActions(op)
+ SetSecurityEventTags(span, events, args.Headers, w.Header())
}()
handler.ServeHTTP(w, r)
+
})
}
// MakeHandlerOperationArgs creates the HandlerOperationArgs out of a standard
// http.Request along with the given current span. It returns an empty structure
// when appsec is disabled.
-func MakeHandlerOperationArgs(r *http.Request, pathParams map[string]string) HandlerOperationArgs {
+func MakeHandlerOperationArgs(r *http.Request, clientIP instrumentation.NetaddrIP, pathParams map[string]string) HandlerOperationArgs {
headers := make(http.Header, len(r.Header))
for k, v := range r.Header {
k := strings.ToLower(k)
@@ -123,6 +145,7 @@ func MakeHandlerOperationArgs(r *http.Request, pathParams map[string]string) Han
Cookies: cookies,
Query: r.URL.Query(), // TODO(Julio-Guerra): avoid actively parsing the query values thanks to dynamic instrumentation
PathParams: pathParams,
+ ClientIP: clientIP,
}
}
@@ -149,6 +172,8 @@ type (
dyngo.Operation
instrumentation.TagsHolder
instrumentation.SecurityEventsHolder
+ mu sync.RWMutex
+ actions []Action
}
// SDKBodyOperation type representing an SDK body. It must be created with
@@ -187,6 +212,27 @@ func (op *Operation) Finish(res HandlerOperationRes) []json.RawMessage {
return op.Events()
}
+// Actions returns the actions linked to the operation
+func (op *Operation) Actions() []Action {
+ op.mu.RLock()
+ defer op.mu.RUnlock()
+ return op.actions
+}
+
+// AddAction adds an action to the operation
+func (op *Operation) AddAction(a Action) {
+ op.mu.Lock()
+ defer op.mu.Unlock()
+ op.actions = append(op.actions, a)
+}
+
+// ClearActions clears all the actions linked to the operation
+func (op *Operation) ClearActions() {
+ op.mu.Lock()
+ defer op.mu.Unlock()
+ op.actions = op.actions[0:0]
+}
+
// StartSDKBodyOperation starts the SDKBody operation and emits a start event
func StartSDKBodyOperation(parent *Operation, args SDKBodyOperationArgs) *SDKBodyOperation {
op := &SDKBodyOperation{Operation: dyngo.NewOperation(parent)}
@@ -261,3 +307,31 @@ func (OnSDKBodyOperationFinish) ListenedType() reflect.Type { return sdkBodyOper
func (f OnSDKBodyOperationFinish) Call(op dyngo.Operation, v interface{}) {
f(op.(*SDKBodyOperation), v.(SDKBodyOperationRes))
}
+
+// blockedTemplateJSON is the default JSON template used to write responses for blocked requests
+//
+//go:embed blocked-template.json
+var blockedTemplateJSON []byte
+
+// blockedTemplateHTML is the default HTML template used to write responses for blocked requests
+//
+//go:embed blocked-template.html
+var blockedTemplateHTML []byte
+
+const (
+ envBlockedTemplateHTML = "DD_APPSEC_HTTP_BLOCKED_TEMPLATE_HTML"
+ envBlockedTemplateJSON = "DD_APPSEC_HTTP_BLOCKED_TEMPLATE_JSON"
+)
+
+func init() {
+ for env, template := range map[string]*[]byte{envBlockedTemplateJSON: &blockedTemplateJSON, envBlockedTemplateHTML: &blockedTemplateHTML} {
+ if path, ok := os.LookupEnv(env); ok {
+ if t, err := os.ReadFile(path); err != nil {
+ log.Warn("Could not read template at %s: %v", path, err)
+ } else {
+ *template = t
+ }
+ }
+
+ }
+}
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/ip_default.go b/internal/appsec/dyngo/instrumentation/httpsec/ip_default.go
deleted file mode 100644
index a96f74ef02..0000000000
--- a/internal/appsec/dyngo/instrumentation/httpsec/ip_default.go
+++ /dev/null
@@ -1,22 +0,0 @@
-// Unless explicitly stated otherwise all files in this repository are licensed
-// under the Apache License Version 2.0.
-// This product includes software developed at Datadog (https://www.datadoghq.com/).
-// Copyright 2022 Datadog, Inc.
-
-//go:build !go1.19
-// +build !go1.19
-
-package httpsec
-
-import "inet.af/netaddr"
-
-type netaddrIP = netaddr.IP
-type netaddrIPPrefix = netaddr.IPPrefix
-
-var (
- netaddrParseIP = netaddr.ParseIP
- netaddrParseIPPrefix = netaddr.ParseIPPrefix
- netaddrMustParseIP = netaddr.MustParseIP
- netaddrIPv4 = netaddr.IPv4
- netaddrIPv6Raw = netaddr.IPv6Raw
-)
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/ip_go119.go b/internal/appsec/dyngo/instrumentation/httpsec/ip_go119.go
deleted file mode 100644
index dbea4b60d3..0000000000
--- a/internal/appsec/dyngo/instrumentation/httpsec/ip_go119.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Unless explicitly stated otherwise all files in this repository are licensed
-// under the Apache License Version 2.0.
-// This product includes software developed at Datadog (https://www.datadoghq.com/).
-// Copyright 2022 Datadog, Inc.
-
-//go:build go1.19
-// +build go1.19
-
-package httpsec
-
-import "net/netip"
-
-type netaddrIP = netip.Addr
-type netaddrIPPrefix = netip.Prefix
-
-var (
- netaddrParseIP = netip.ParseAddr
- netaddrParseIPPrefix = netip.ParsePrefix
- netaddrMustParseIP = netip.MustParseAddr
- netaddrIPv6Raw = netip.AddrFrom16
-)
-
-func netaddrIPv4(a, b, c, d byte) netaddrIP {
- e := [4]byte{a, b, c, d}
- return netip.AddrFrom4(e)
-}
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/tags.go b/internal/appsec/dyngo/instrumentation/httpsec/tags.go
index faf300d950..5f07f140a7 100644
--- a/internal/appsec/dyngo/instrumentation/httpsec/tags.go
+++ b/internal/appsec/dyngo/instrumentation/httpsec/tags.go
@@ -8,7 +8,6 @@ package httpsec
import (
"encoding/json"
"net"
- "net/http"
"os"
"sort"
"strings"
@@ -21,15 +20,21 @@ import (
const (
// envClientIPHeader is the name of the env var used to specify the IP header to be used for client IP collection.
envClientIPHeader = "DD_TRACE_CLIENT_IP_HEADER"
- // multipleIPHeaders sets the multiple ip header tag used internally to tell the backend an error occurred when
+
+ // multipleIPHeadersTag sets the multiple ip header tag used internally to tell the backend an error occurred when
// retrieving an HTTP request client IP.
- multipleIPHeaders = "_dd.multiple-ip-headers"
+ multipleIPHeadersTag = "_dd.multiple-ip-headers"
+
+ // BlockedRequestTag used to convey whether a request is blocked
+ BlockedRequestTag = "appsec.blocked"
)
var (
- ipv6SpecialNetworks = []*netaddrIPPrefix{
+ ipv6SpecialNetworks = []*instrumentation.NetaddrIPPrefix{
ippref("fec0::/10"), // site local
}
+
+ // List of IP-related headers leveraged to retrieve the public client IP address.
defaultIPHeaders = []string{
"x-forwarded-for",
"x-real-ip",
@@ -41,6 +46,7 @@ var (
"via",
"true-client-ip",
}
+
// List of HTTP headers we collect and send.
collectedHTTPHeaders = append(defaultIPHeaders,
"host",
@@ -53,21 +59,22 @@ var (
"accept",
"accept-encoding",
"accept-language")
- clientIPHeader string
+
+ clientIPHeaderCfg string
)
func init() {
// Required by sort.SearchStrings
+ sort.Strings(defaultIPHeaders[:])
sort.Strings(collectedHTTPHeaders[:])
- clientIPHeader = os.Getenv(envClientIPHeader)
+ clientIPHeaderCfg = os.Getenv(envClientIPHeader)
}
// SetSecurityEventTags sets the AppSec-specific span tags when a security event occurred into the service entry span.
-func SetSecurityEventTags(span instrumentation.TagSetter, events []json.RawMessage, remoteIP string, headers, respHeaders map[string][]string) {
+func SetSecurityEventTags(span instrumentation.TagSetter, events []json.RawMessage, headers, respHeaders map[string][]string) {
if err := instrumentation.SetEventSpanTags(span, events); err != nil {
log.Error("appsec: unexpected error while creating the appsec event tags: %v", err)
}
- span.SetTag("network.client.ip", remoteIP)
for h, v := range NormalizeHTTPHeaders(headers) {
span.SetTag("http.request.headers."+h, v)
}
@@ -96,69 +103,106 @@ func NormalizeHTTPHeaders(headers map[string][]string) (normalized map[string]st
}
// ippref returns the IP network from an IP address string s. If not possible, it returns nil.
-func ippref(s string) *netaddrIPPrefix {
- if prefix, err := netaddrParseIPPrefix(s); err == nil {
+func ippref(s string) *instrumentation.NetaddrIPPrefix {
+ if prefix, err := instrumentation.NetaddrParseIPPrefix(s); err == nil {
return &prefix
}
return nil
}
-// SetIPTags sets the IP related span tags for a given request
-// See https://docs.datadoghq.com/tracing/configure_data_security#configuring-a-client-ip-header for more information.
-func SetIPTags(span instrumentation.TagSetter, r *http.Request) {
- ipHeaders := defaultIPHeaders
- if len(clientIPHeader) > 0 {
- ipHeaders = []string{clientIPHeader}
- }
-
- var (
- headers []string
- ips []string
- )
- for _, hdr := range ipHeaders {
- if v := r.Header.Get(hdr); v != "" {
- headers = append(headers, hdr)
- ips = append(ips, v)
+// ClientIPTags generates the IP related span tags for a given request headers
+func ClientIPTags(hdrs map[string][]string, remoteAddr string) (tags map[string]string, clientIP instrumentation.NetaddrIP) {
+ tags = map[string]string{}
+ monitoredHeaders := defaultIPHeaders
+ if clientIPHeaderCfg != "" {
+ monitoredHeaders = []string{clientIPHeaderCfg}
+ }
+
+ // Filter the list of headers
+ foundHeaders := map[string][]string{}
+ for k, v := range hdrs {
+ k = strings.ToLower(k)
+ if i := sort.SearchStrings(monitoredHeaders, k); i < len(monitoredHeaders) && monitoredHeaders[i] == k {
+ if len(v) >= 1 && v[0] != "" {
+ foundHeaders[k] = v
+ }
+ }
+ }
+
+ // If more than one IP header is present, report them and don't return any client ip
+ if len(foundHeaders) > 1 {
+ var headers []string
+ for header, ips := range foundHeaders {
+ tags[ext.HTTPRequestHeaders+"."+header] = strings.Join(ips, ",")
+ headers = append(headers, header)
}
+ sort.Strings(headers) // produce a predictable value
+ tags[multipleIPHeadersTag] = strings.Join(headers, ",")
+ return tags, instrumentation.NetaddrIP{}
}
- if l := len(ips); l == 0 {
- if remoteIP := parseIP(r.RemoteAddr); remoteIP.IsValid() && isGlobal(remoteIP) {
- span.SetTag(ext.HTTPClientIP, remoteIP.String())
+ // Walk IP-related headers
+ var foundIP instrumentation.NetaddrIP
+ for _, v := range foundHeaders {
+ // Handle multi-value headers by flattening the list of values
+ var ips []string
+ for _, ip := range v {
+ ips = append(ips, strings.Split(ip, ",")...)
}
- } else if l == 1 {
- for _, ipstr := range strings.Split(ips[0], ",") {
+
+ // Look for the first valid or global IP address in the comma-separated list
+ for _, ipstr := range ips {
ip := parseIP(strings.TrimSpace(ipstr))
- if ip.IsValid() && isGlobal(ip) {
- span.SetTag(ext.HTTPClientIP, ip.String())
+ if !ip.IsValid() {
+ continue
+ }
+ // Replace foundIP if still not valid in order to keep the oldest
+ if !foundIP.IsValid() {
+ foundIP = ip
+ }
+ if isGlobal(ip) {
+ foundIP = ip
break
}
}
- } else {
- for i := range ips {
- span.SetTag(ext.HTTPRequestHeaders+"."+headers[i], ips[i])
- }
- span.SetTag(multipleIPHeaders, strings.Join(headers, ","))
}
+
+ // Decide which IP address is the client one by starting with the remote IP
+ remoteIP := parseIP(remoteAddr)
+ if remoteIP.IsValid() {
+ tags["network.client.ip"] = remoteIP.String()
+ clientIP = remoteIP
+ }
+
+ // The IP address found in the headers supersedes a private remote IP address.
+ if foundIP.IsValid() && !isGlobal(remoteIP) || isGlobal(foundIP) {
+ clientIP = foundIP
+ }
+
+ if clientIP.IsValid() {
+ tags[ext.HTTPClientIP] = clientIP.String()
+ }
+
+ return tags, clientIP
}
-func parseIP(s string) netaddrIP {
- if ip, err := netaddrParseIP(s); err == nil {
+func parseIP(s string) instrumentation.NetaddrIP {
+ if ip, err := instrumentation.NetaddrParseIP(s); err == nil {
return ip
}
if h, _, err := net.SplitHostPort(s); err == nil {
- if ip, err := netaddrParseIP(h); err == nil {
+ if ip, err := instrumentation.NetaddrParseIP(h); err == nil {
return ip
}
}
- return netaddrIP{}
+ return instrumentation.NetaddrIP{}
}
-func isGlobal(ip netaddrIP) bool {
+func isGlobal(ip instrumentation.NetaddrIP) bool {
// IsPrivate also checks for ipv6 ULA.
// We care to check for these addresses are not considered public, hence not global.
// See https://www.rfc-editor.org/rfc/rfc4193.txt for more details.
- isGlobal := !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast()
+ isGlobal := ip.IsValid() && !ip.IsPrivate() && !ip.IsLoopback() && !ip.IsLinkLocalUnicast()
if !isGlobal || !ip.Is6() {
return isGlobal
}
diff --git a/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go b/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go
index 1dfe92c6fe..2cb9bee406 100644
--- a/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go
+++ b/internal/appsec/dyngo/instrumentation/httpsec/tags_test.go
@@ -11,6 +11,7 @@ import (
"testing"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
+ "gopkg.in/DataDog/dd-trace-go.v1/internal/appsec/dyngo/instrumentation"
"github.com/stretchr/testify/require"
)
@@ -58,7 +59,7 @@ type ipTestCase struct {
name string
remoteAddr string
headers map[string]string
- expectedIP netaddrIP
+ expectedIP instrumentation.NetaddrIP
multiHeaders string
clientIPHeader string
}
@@ -68,170 +69,221 @@ func genIPTestCases() []ipTestCase {
ipv6Global := randGlobalIPv6().String()
ipv4Private := randPrivateIPv4().String()
ipv6Private := randPrivateIPv6().String()
- tcs := []ipTestCase{}
+
+ tcs := []ipTestCase{
+ {
+ name: "ipv4-global-remoteaddr",
+ remoteAddr: ipv4Global,
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
+ },
+ {
+ name: "ipv4-private-remoteaddr",
+ remoteAddr: ipv4Private,
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Private),
+ },
+ {
+ name: "ipv6-global-remoteaddr",
+ remoteAddr: ipv6Global,
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
+ },
+ {
+ name: "ipv6-private-remoteaddr",
+ remoteAddr: ipv6Private,
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Private),
+ },
+ }
+
// Simple ipv4 test cases over all headers
for _, header := range defaultIPHeaders {
- tcs = append(tcs, ipTestCase{
- name: "ipv4-global." + header,
- headers: map[string]string{header: ipv4Global},
- expectedIP: netaddrMustParseIP(ipv4Global),
- })
- tcs = append(tcs, ipTestCase{
- name: "ipv4-private." + header,
- headers: map[string]string{header: ipv4Private},
- expectedIP: netaddrIP{},
- })
+ tcs = append(tcs,
+ ipTestCase{
+ name: "ipv4-global." + header,
+ remoteAddr: ipv4Private,
+ headers: map[string]string{header: ipv4Global},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
+ },
+ ipTestCase{
+ name: "ipv4-private." + header,
+ headers: map[string]string{header: ipv4Private},
+ remoteAddr: ipv6Private,
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Private),
+ },
+ ipTestCase{
+ name: "ipv4-global-remoteaddr-local-ip-header." + header,
+ remoteAddr: ipv4Global,
+ headers: map[string]string{header: ipv4Private},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
+ },
+ ipTestCase{
+ name: "ipv4-global-remoteaddr-global-ip-header." + header,
+ remoteAddr: ipv6Global,
+ headers: map[string]string{header: ipv4Global},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
+ })
}
+
// Simple ipv6 test cases over all headers
for _, header := range defaultIPHeaders {
tcs = append(tcs, ipTestCase{
name: "ipv6-global." + header,
+ remoteAddr: ipv4Private,
headers: map[string]string{header: ipv6Global},
- expectedIP: netaddrMustParseIP(ipv6Global),
- })
- tcs = append(tcs, ipTestCase{
- name: "ipv6-private." + header,
- headers: map[string]string{header: ipv6Private},
- expectedIP: netaddrIP{},
- })
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
+ },
+ ipTestCase{
+ name: "ipv6-private." + header,
+ headers: map[string]string{header: ipv6Private},
+ remoteAddr: ipv4Private,
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Private),
+ },
+ ipTestCase{
+ name: "ipv6-global-remoteaddr-local-ip-header." + header,
+ remoteAddr: ipv6Global,
+ headers: map[string]string{header: ipv6Private},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
+ },
+ ipTestCase{
+ name: "ipv6-global-remoteaddr-global-ip-header." + header,
+ remoteAddr: ipv4Global,
+ headers: map[string]string{header: ipv6Global},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
+ })
}
+
// private and global in same header
tcs = append([]ipTestCase{
{
name: "ipv4-private+global",
headers: map[string]string{"x-forwarded-for": ipv4Private + "," + ipv4Global},
- expectedIP: netaddrMustParseIP(ipv4Global),
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
},
{
name: "ipv4-global+private",
headers: map[string]string{"x-forwarded-for": ipv4Global + "," + ipv4Private},
- expectedIP: netaddrMustParseIP(ipv4Global),
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
},
{
name: "ipv6-private+global",
headers: map[string]string{"x-forwarded-for": ipv6Private + "," + ipv6Global},
- expectedIP: netaddrMustParseIP(ipv6Global),
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
},
{
name: "ipv6-global+private",
headers: map[string]string{"x-forwarded-for": ipv6Global + "," + ipv6Private},
- expectedIP: netaddrMustParseIP(ipv6Global),
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
+ },
+ {
+ name: "mixed-global+global",
+ headers: map[string]string{"x-forwarded-for": ipv4Private + "," + ipv6Global + "," + ipv4Global},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
+ },
+ {
+ name: "mixed-global+global",
+ headers: map[string]string{"x-forwarded-for": ipv4Private + "," + ipv4Global + "," + ipv6Global},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
},
}, tcs...)
+
// Invalid IPs (or a mix of valid/invalid over a single or multiple headers)
tcs = append([]ipTestCase{
{
name: "invalid-ipv4",
headers: map[string]string{"x-forwarded-for": "127..0.0.1"},
- expectedIP: netaddrIP{},
+ expectedIP: instrumentation.NetaddrIP{},
+ },
+ {
+ name: "invalid-ipv4-header-valid-remoteaddr",
+ headers: map[string]string{"x-forwarded-for": "127..0.0.1"},
+ remoteAddr: ipv4Private,
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Private),
},
{
name: "invalid-ipv4-recover",
- headers: map[string]string{"x-forwarded-for": "127..0.0.1, " + ipv4Global},
- expectedIP: netaddrMustParseIP(ipv4Global),
+ headers: map[string]string{"x-forwarded-for": "127..0.0.1, " + ipv6Private + "," + ipv4Global},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
},
{
name: "ipv4-multi-header-1",
headers: map[string]string{"x-forwarded-for": "127.0.0.1", "forwarded-for": ipv4Global},
- expectedIP: netaddrIP{},
- multiHeaders: "x-forwarded-for,forwarded-for",
+ expectedIP: instrumentation.NetaddrIP{},
+ multiHeaders: "forwarded-for,x-forwarded-for",
},
{
name: "ipv4-multi-header-2",
headers: map[string]string{"forwarded-for": ipv4Global, "x-forwarded-for": "127.0.0.1"},
- expectedIP: netaddrIP{},
- multiHeaders: "x-forwarded-for,forwarded-for",
+ expectedIP: instrumentation.NetaddrIP{},
+ multiHeaders: "forwarded-for,x-forwarded-for",
},
{
name: "invalid-ipv6",
headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::"},
- expectedIP: netaddrIP{},
+ expectedIP: instrumentation.NetaddrIP{},
},
{
name: "invalid-ipv6-recover",
headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::, " + ipv6Global},
- expectedIP: netaddrMustParseIP(ipv6Global),
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv6Global),
},
{
name: "ipv6-multi-header-1",
headers: map[string]string{"x-forwarded-for": "2001:0db8:2001:zzzz::", "forwarded-for": ipv6Global},
- expectedIP: netaddrIP{},
- multiHeaders: "x-forwarded-for,forwarded-for",
+ expectedIP: instrumentation.NetaddrIP{},
+ multiHeaders: "forwarded-for,x-forwarded-for",
},
{
name: "ipv6-multi-header-2",
headers: map[string]string{"forwarded-for": ipv6Global, "x-forwarded-for": "2001:0db8:2001:zzzz::"},
- expectedIP: netaddrIP{},
- multiHeaders: "x-forwarded-for,forwarded-for",
+ expectedIP: instrumentation.NetaddrIP{},
+ multiHeaders: "forwarded-for,x-forwarded-for",
},
- }, tcs...)
- tcs = append([]ipTestCase{
{
name: "no-headers",
- expectedIP: netaddrIP{},
+ expectedIP: instrumentation.NetaddrIP{},
},
{
name: "header-case",
- expectedIP: netaddrMustParseIP(ipv4Global),
headers: map[string]string{"X-fOrWaRdEd-FoR": ipv4Global},
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
},
{
name: "user-header",
- expectedIP: netaddrMustParseIP(ipv4Global),
headers: map[string]string{"x-forwarded-for": ipv6Global, "custom-header": ipv4Global},
clientIPHeader: "custom-header",
+ expectedIP: instrumentation.NetaddrMustParseIP(ipv4Global),
},
{
name: "user-header-not-found",
- expectedIP: netaddrIP{},
headers: map[string]string{"x-forwarded-for": ipv4Global},
clientIPHeader: "custom-header",
+ expectedIP: instrumentation.NetaddrIP{},
},
}, tcs...)
return tcs
}
-type mockspan struct {
- tags map[string]interface{}
-}
-
-func (m *mockspan) SetTag(tag string, value interface{}) {
- if m.tags == nil {
- m.tags = make(map[string]interface{})
- }
- m.tags[tag] = value
-}
-
-func (m *mockspan) Tag(tag string) interface{} {
- if m.tags == nil {
- return nil
- }
- return m.tags[tag]
-}
-
func TestIPHeaders(t *testing.T) {
- // Make sure to restore the real value of clientIPHeader at the end of the test
- defer func(s string) { clientIPHeader = s }(clientIPHeader)
+ // Make sure to restore the real value of clientIPHeaderCfg at the end of the test
+ defer func(s string) { clientIPHeaderCfg = s }(clientIPHeaderCfg)
for _, tc := range genIPTestCases() {
t.Run(tc.name, func(t *testing.T) {
header := http.Header{}
for k, v := range tc.headers {
header.Add(k, v)
}
- r := http.Request{Header: header, RemoteAddr: tc.remoteAddr}
- clientIPHeader = tc.clientIPHeader
- var span mockspan
- SetIPTags(&span, &r)
+ clientIPHeaderCfg = tc.clientIPHeader
+ tags, clientIP := ClientIPTags(header, tc.remoteAddr)
if tc.expectedIP.IsValid() {
- require.Equal(t, tc.expectedIP.String(), span.Tag(ext.HTTPClientIP))
- require.Nil(t, span.Tag(multipleIPHeaders))
+ expectedIP := tc.expectedIP.String()
+ require.Equal(t, expectedIP, tags[ext.HTTPClientIP])
+ require.Equal(t, expectedIP, clientIP.String())
+ require.NotContains(t, tags, multipleIPHeadersTag)
} else {
- require.Nil(t, span.Tag(ext.HTTPClientIP))
+ require.NotContains(t, tags, ext.HTTPClientIP)
if tc.multiHeaders != "" {
- require.Equal(t, tc.multiHeaders, span.Tag(multipleIPHeaders))
+ require.Equal(t, tc.multiHeaders, tags[multipleIPHeadersTag])
for hdr, ip := range tc.headers {
- require.Equal(t, ip, span.Tag(ext.HTTPRequestHeaders+"."+hdr))
+ require.Equal(t, ip, tags[ext.HTTPRequestHeaders+"."+hdr])
}
}
}
@@ -239,12 +291,12 @@ func TestIPHeaders(t *testing.T) {
}
}
-func randIPv4() netaddrIP {
- return netaddrIPv4(uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()))
+func randIPv4() instrumentation.NetaddrIP {
+ return instrumentation.NetaddrIPv4(uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()))
}
-func randIPv6() netaddrIP {
- return netaddrIPv6Raw([16]byte{
+func randIPv6() instrumentation.NetaddrIP {
+ return instrumentation.NetaddrIPv6Raw([16]byte{
uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()),
uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()),
uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()), uint8(rand.Uint32()),
@@ -252,7 +304,7 @@ func randIPv6() netaddrIP {
})
}
-func randGlobalIPv4() netaddrIP {
+func randGlobalIPv4() instrumentation.NetaddrIP {
for {
ip := randIPv4()
if isGlobal(ip) {
@@ -261,7 +313,7 @@ func randGlobalIPv4() netaddrIP {
}
}
-func randGlobalIPv6() netaddrIP {
+func randGlobalIPv6() instrumentation.NetaddrIP {
for {
ip := randIPv6()
if isGlobal(ip) {
@@ -270,7 +322,7 @@ func randGlobalIPv6() netaddrIP {
}
}
-func randPrivateIPv4() netaddrIP {
+func randPrivateIPv4() instrumentation.NetaddrIP {
for {
ip := randIPv4()
if !isGlobal(ip) && ip.IsPrivate() {
@@ -279,7 +331,7 @@ func randPrivateIPv4() netaddrIP {
}
}
-func randPrivateIPv6() netaddrIP {
+func randPrivateIPv6() instrumentation.NetaddrIP {
for {
ip := randIPv6()
if !isGlobal(ip) && ip.IsPrivate() {
diff --git a/internal/appsec/dyngo/instrumentation/ip_default.go b/internal/appsec/dyngo/instrumentation/ip_default.go
new file mode 100644
index 0000000000..3fde0af771
--- /dev/null
+++ b/internal/appsec/dyngo/instrumentation/ip_default.go
@@ -0,0 +1,30 @@
+// Unless explicitly stated otherwise all files in this repository are licensed
+// under the Apache License Version 2.0.
+// This product includes software developed at Datadog (https://www.datadoghq.com/).
+// Copyright 2022 Datadog, Inc.
+
+//go:build !go1.19
+// +build !go1.19
+
+package instrumentation
+
+import "inet.af/netaddr"
+
+// NetaddrIP wraps an netaddr.IP value
+type NetaddrIP = netaddr.IP
+
+// NetaddrIPPrefix wraps an netaddr.IPPrefix value
+type NetaddrIPPrefix = netaddr.IPPrefix
+
+var (
+ // NetaddrParseIP wraps the netaddr.ParseIP function
+ NetaddrParseIP = netaddr.ParseIP
+ // NetaddrParseIPPrefix wraps the netaddr.ParseIPPrefix function
+ NetaddrParseIPPrefix = netaddr.ParseIPPrefix
+ // NetaddrMustParseIP wraps the netaddr.MustParseIP function
+ NetaddrMustParseIP = netaddr.MustParseIP
+ // NetaddrIPv4 wraps the netaddr.IPv4 function
+ NetaddrIPv4 = netaddr.IPv4
+ // NetaddrIPv6Raw wraps the netaddr.IPv6Raw function
+ NetaddrIPv6Raw = netaddr.IPv6Raw
+)
diff --git a/internal/appsec/dyngo/instrumentation/ip_go119.go b/internal/appsec/dyngo/instrumentation/ip_go119.go
new file mode 100644
index 0000000000..78b3b0e561
--- /dev/null
+++ b/internal/appsec/dyngo/instrumentation/ip_go119.go
@@ -0,0 +1,34 @@
+// Unless explicitly stated otherwise all files in this repository are licensed
+// under the Apache License Version 2.0.
+// This product includes software developed at Datadog (https://www.datadoghq.com/).
+// Copyright 2022 Datadog, Inc.
+
+//go:build go1.19
+// +build go1.19
+
+package instrumentation
+
+import "net/netip"
+
+// NetaddrIP wraps a netip.Addr value
+type NetaddrIP = netip.Addr
+
+// NetaddrIPPrefix wraps a netip.Prefix value
+type NetaddrIPPrefix = netip.Prefix
+
+var (
+ // NetaddrParseIP wraps the netip.ParseAddr function
+ NetaddrParseIP = netip.ParseAddr
+ // NetaddrParseIPPrefix wraps the netip.ParsePrefix function
+ NetaddrParseIPPrefix = netip.ParsePrefix
+ // NetaddrMustParseIP wraps the netip.MustParseAddr function
+ NetaddrMustParseIP = netip.MustParseAddr
+ // NetaddrIPv6Raw wraps the netIP.AddrFrom16 function
+ NetaddrIPv6Raw = netip.AddrFrom16
+)
+
+// NetaddrIPv4 wraps the netip.AddrFrom4 function
+func NetaddrIPv4(a, b, c, d byte) NetaddrIP {
+ e := [4]byte{a, b, c, d}
+ return netip.AddrFrom4(e)
+}
diff --git a/internal/appsec/testdata/blocking.json b/internal/appsec/testdata/blocking.json
new file mode 100644
index 0000000000..8b53f4caf4
--- /dev/null
+++ b/internal/appsec/testdata/blocking.json
@@ -0,0 +1,42 @@
+{
+ "version": "2.2",
+ "metadata": {
+ "rules_version": "1.4.2"
+ },
+ "rules": [
+ {
+ "id": "blk-001-001",
+ "name": "Block IP Addresses",
+ "tags": {
+ "type": "block_ip",
+ "category": "security_response"
+ },
+ "conditions": [
+ {
+ "parameters": {
+ "inputs": [
+ {
+ "address": "http.client_ip"
+ }
+ ],
+ "data": "blocked_ips"
+ },
+ "operator": "ip_match"
+ }
+ ],
+ "transformers": [],
+ "on_match": [
+ "block"
+ ]
+ }
+ ],
+ "rules_data": [
+ {
+ "id": "blocked_ips",
+ "type": "ip_with_expiration",
+ "data": [
+ { "value": "1.2.3.4" }
+ ]
+ }
+ ]
+}
diff --git a/internal/appsec/waf.go b/internal/appsec/waf.go
index ea13e36874..851360a7a3 100644
--- a/internal/appsec/waf.go
+++ b/internal/appsec/waf.go
@@ -99,9 +99,37 @@ func (a *appsec) registerWAF() (unreg dyngo.UnregisterFunc, err error) {
// newWAFEventListener returns the WAF event listener to register in order to enable it.
func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout time.Duration, limiter Limiter) dyngo.EventListener {
var monitorRulesOnce sync.Once // per instantiation
+ actionHandler := httpsec.NewActionsHandler()
return httpsec.OnHandlerOperationStart(func(op *httpsec.Operation, args httpsec.HandlerOperationArgs) {
var body interface{}
+ wafCtx := waf.NewContext(handle)
+ if wafCtx == nil {
+ // The WAF event listener got concurrently released
+ return
+ }
+
+ values := map[string]interface{}{}
+ for _, addr := range addresses {
+ if addr == httpClientIPAddr && args.ClientIP.IsValid() {
+ values[httpClientIPAddr] = args.ClientIP.String()
+ }
+ }
+ // TODO: suspicious request blocking by moving here all the addresses available when the request begins
+
+ matches, actionIds := runWAF(wafCtx, values, timeout)
+ if len(matches) > 0 {
+ interrupt := false
+ for _, id := range actionIds {
+ interrupt = actionHandler.Apply(id, op) || interrupt
+ }
+ op.AddSecurityEvents(matches)
+ log.Debug("appsec: WAF detected an attack before executing the request")
+ if interrupt {
+ wafCtx.Close()
+ return
+ }
+ }
op.On(httpsec.OnSDKBodyOperationStart(func(op *httpsec.SDKBodyOperation, args httpsec.SDKBodyOperationArgs) {
body = args.Body
@@ -110,13 +138,7 @@ func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout tim
// At the moment, AppSec doesn't block the requests, and so we can use the fact we are in monitoring-only mode
// to call the WAF only once at the end of the handler operation.
op.On(httpsec.OnHandlerOperationFinish(func(op *httpsec.Operation, res httpsec.HandlerOperationRes) {
- wafCtx := waf.NewContext(handle)
- if wafCtx == nil {
- // The WAF event listener got concurrently released
- return
- }
defer wafCtx.Close()
-
// Run the WAF on the rule addresses available in the request args
values := make(map[string]interface{}, len(addresses))
for _, addr := range addresses {
@@ -135,19 +157,21 @@ func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout tim
if query := args.Query; query != nil {
values[serverRequestQueryAddr] = query
}
- case serverRequestPathParams:
+ case serverRequestPathParamsAddr:
if pathParams := args.PathParams; pathParams != nil {
- values[serverRequestPathParams] = pathParams
+ values[serverRequestPathParamsAddr] = pathParams
}
- case serverRequestBody:
+ case serverRequestBodyAddr:
if body != nil {
- values[serverRequestBody] = body
+ values[serverRequestBodyAddr] = body
}
case serverResponseStatusAddr:
values[serverResponseStatusAddr] = res.Status
}
}
- matches := runWAF(wafCtx, values, timeout)
+ // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's
+ // response is not supported at the moment.
+ matches, _ := runWAF(wafCtx, values, timeout)
// Add WAF metrics.
rInfo := handle.RulesetInfo()
@@ -174,8 +198,9 @@ func newHTTPWAFEventListener(handle *waf.Handle, addresses []string, timeout tim
// newGRPCWAFEventListener returns the WAF event listener to register in order
// to enable it.
-func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Duration, limiter Limiter) dyngo.EventListener {
+func newGRPCWAFEventListener(handle *waf.Handle, addresses []string, timeout time.Duration, limiter Limiter) dyngo.EventListener {
var monitorRulesOnce sync.Once // per instantiation
+ actionHandler := grpcsec.NewActionsHandler()
return grpcsec.OnHandlerOperationStart(func(op *grpcsec.HandlerOperation, handlerArgs grpcsec.HandlerOperationArgs) {
// Limit the maximum number of security events, as a streaming RPC could
@@ -192,6 +217,34 @@ func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Durati
mu sync.Mutex // events mutex
)
+ wafCtx := waf.NewContext(handle)
+ if wafCtx == nil {
+ // The WAF event listener got concurrently released
+ return
+ }
+ defer wafCtx.Close()
+
+ // The same address is used for gRPC and http when it comes to client ip
+ values := map[string]interface{}{}
+ for _, addr := range addresses {
+ if addr == httpClientIPAddr && handlerArgs.ClientIP.IsValid() {
+ values[httpClientIPAddr] = handlerArgs.ClientIP.String()
+ }
+ }
+
+ matches, actionIds := runWAF(wafCtx, values, timeout)
+ if len(matches) > 0 {
+ interrupt := false
+ for _, id := range actionIds {
+ interrupt = actionHandler.Apply(id, op) || interrupt
+ }
+ op.AddSecurityEvents(matches)
+ log.Debug("appsec: WAF detected an attack before executing the request")
+ if interrupt {
+ return
+ }
+ }
+
op.On(grpcsec.OnReceiveOperationFinish(func(_ grpcsec.ReceiveOperation, res grpcsec.ReceiveOperationRes) {
if atomic.LoadUint32(&nbEvents) == maxWAFEventsPerRequest {
logOnce.Do(func() {
@@ -221,7 +274,9 @@ func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Durati
if md := handlerArgs.Metadata; len(md) > 0 {
values[grpcServerRequestMetadata] = md
}
- event := runWAF(wafCtx, values, timeout)
+ // Run the WAF, ignoring the returned actions - if any - since blocking after the request handler's
+ // response is not supported at the moment.
+ event, _ := runWAF(wafCtx, values, timeout)
// WAF run durations are WAF context bound. As of now we need to keep track of those externally since
// we use a new WAF context for each callback. When we are able to re-use the same WAF context across
@@ -259,17 +314,17 @@ func newGRPCWAFEventListener(handle *waf.Handle, _ []string, timeout time.Durati
})
}
-func runWAF(wafCtx *waf.Context, values map[string]interface{}, timeout time.Duration) []byte {
- matches, _, err := wafCtx.Run(values, timeout)
+func runWAF(wafCtx *waf.Context, values map[string]interface{}, timeout time.Duration) ([]byte, []string) {
+ matches, actions, err := wafCtx.Run(values, timeout)
if err != nil {
if err == waf.ErrTimeout {
log.Debug("appsec: waf timeout value of %s reached", timeout)
} else {
log.Error("appsec: unexpected waf error: %v", err)
- return nil
+ return nil, nil
}
}
- return matches
+ return matches, actions
}
// HTTP rule addresses currently supported by the WAF
@@ -278,9 +333,10 @@ const (
serverRequestHeadersNoCookiesAddr = "server.request.headers.no_cookies"
serverRequestCookiesAddr = "server.request.cookies"
serverRequestQueryAddr = "server.request.query"
- serverRequestPathParams = "server.request.path_params"
- serverRequestBody = "server.request.body"
+ serverRequestPathParamsAddr = "server.request.path_params"
+ serverRequestBodyAddr = "server.request.body"
serverResponseStatusAddr = "server.response.status"
+ httpClientIPAddr = "http.client_ip"
)
// List of HTTP rule addresses currently supported by the WAF
@@ -289,9 +345,10 @@ var httpAddresses = []string{
serverRequestHeadersNoCookiesAddr,
serverRequestCookiesAddr,
serverRequestQueryAddr,
- serverRequestPathParams,
- serverRequestBody,
+ serverRequestPathParamsAddr,
+ serverRequestBodyAddr,
serverResponseStatusAddr,
+ httpClientIPAddr,
}
// gRPC rule addresses currently supported by the WAF
@@ -304,6 +361,7 @@ const (
var grpcAddresses = []string{
grpcServerRequestMessage,
grpcServerRequestMetadata,
+ httpClientIPAddr,
}
func init() {
@@ -317,11 +375,16 @@ func init() {
func supportedAddresses(ruleAddresses []string) (supportedHTTP, supportedGRPC, notSupported []string) {
// Filter the supported addresses only
for _, addr := range ruleAddresses {
+ supported := false
if i := sort.SearchStrings(httpAddresses, addr); i < len(httpAddresses) && httpAddresses[i] == addr {
supportedHTTP = append(supportedHTTP, addr)
- } else if i := sort.SearchStrings(grpcAddresses, addr); i < len(grpcAddresses) && grpcAddresses[i] == addr {
+ supported = true
+ }
+ if i := sort.SearchStrings(grpcAddresses, addr); i < len(grpcAddresses) && grpcAddresses[i] == addr {
supportedGRPC = append(supportedGRPC, addr)
- } else {
+ supported = true
+ }
+ if !supported {
notSupported = append(notSupported, addr)
}
}
diff --git a/internal/appsec/waf_test.go b/internal/appsec/waf_test.go
index dcf8bb4699..aee4923486 100644
--- a/internal/appsec/waf_test.go
+++ b/internal/appsec/waf_test.go
@@ -172,3 +172,67 @@ func TestWAF(t *testing.T) {
require.NotContains(t, event, sensitivePayloadValue)
})
}
+
+// Test that http blocking works by using custom rules/rules data
+func TestBlocking(t *testing.T) {
+ t.Setenv("DD_APPSEC_RULES", "testdata/blocking.json")
+
+ appsec.Start()
+ defer appsec.Stop()
+ if !appsec.Enabled() {
+ t.Skip("AppSec needs to be enabled for this test")
+ }
+
+ // Start and trace an HTTP server
+ mux := httptrace.NewServeMux()
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("Hello World!\n"))
+ })
+ srv := httptest.NewServer(mux)
+ defer srv.Close()
+
+ t.Run("block", func(t *testing.T) {
+ mt := mocktracer.Start()
+ defer mt.Stop()
+
+ req, err := http.NewRequest("POST", srv.URL, nil)
+ if err != nil {
+ panic(err)
+ }
+ // Hardcoded IP header holding an IP that is blocked
+ req.Header.Set("x-forwarded-for", "1.2.3.4")
+ res, err := srv.Client().Do(req)
+ require.NoError(t, err)
+
+ // Check that the request was blocked
+ b, err := io.ReadAll(res.Body)
+ require.NoError(t, err)
+ require.NotEqual(t, "Hello World!\n", string(b))
+ require.Equal(t, 403, res.StatusCode)
+ })
+
+ t.Run("no-block", func(t *testing.T) {
+ mt := mocktracer.Start()
+ defer mt.Stop()
+
+ req1, err := http.NewRequest("POST", srv.URL, nil)
+ if err != nil {
+ panic(err)
+ }
+ req2, err := http.NewRequest("POST", srv.URL, nil)
+ if err != nil {
+ panic(err)
+ }
+ req2.Header.Set("x-forwarded-for", "1.2.3.5")
+
+ for _, r := range []*http.Request{req1, req2} {
+ res, err := srv.Client().Do(r)
+ require.NoError(t, err)
+ // Check that the request was not blocked
+ b, err := io.ReadAll(res.Body)
+ require.NoError(t, err)
+ require.Equal(t, "Hello World!\n", string(b))
+
+ }
+ })
+}