diff --git a/internal/appsec/waf_test.go b/internal/appsec/waf_test.go index dcf8bb4699..aee4923486 100644 --- a/internal/appsec/waf_test.go +++ b/internal/appsec/waf_test.go @@ -172,3 +172,67 @@ func TestWAF(t *testing.T) { require.NotContains(t, event, sensitivePayloadValue) }) } + +// Test that http blocking works by using custom rules/rules data +func TestBlocking(t *testing.T) { + t.Setenv("DD_APPSEC_RULES", "testdata/blocking.json") + + appsec.Start() + defer appsec.Stop() + if !appsec.Enabled() { + t.Skip("AppSec needs to be enabled for this test") + } + + // Start and trace an HTTP server + mux := httptrace.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello World!\n")) + }) + srv := httptest.NewServer(mux) + defer srv.Close() + + t.Run("block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + req, err := http.NewRequest("POST", srv.URL, nil) + if err != nil { + panic(err) + } + // Hardcoded IP header holding an IP that is blocked + req.Header.Set("x-forwarded-for", "1.2.3.4") + res, err := srv.Client().Do(req) + require.NoError(t, err) + + // Check that the request was blocked + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NotEqual(t, "Hello World!\n", string(b)) + require.Equal(t, 403, res.StatusCode) + }) + + t.Run("no-block", func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + req1, err := http.NewRequest("POST", srv.URL, nil) + if err != nil { + panic(err) + } + req2, err := http.NewRequest("POST", srv.URL, nil) + if err != nil { + panic(err) + } + req2.Header.Set("x-forwarded-for", "1.2.3.5") + + for _, r := range []*http.Request{req1, req2} { + res, err := srv.Client().Do(r) + require.NoError(t, err) + // Check that the request was not blocked + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, "Hello World!\n", string(b)) + + } + }) +}