diff --git a/pkg/custom_detectors/custom_detectors.go b/pkg/custom_detectors/custom_detectors.go
new file mode 100644
index 000000000000..c41972896c3a
--- /dev/null
+++ b/pkg/custom_detectors/custom_detectors.go
@@ -0,0 +1,178 @@
+package custom_detectors
+
+import (
+	"fmt"
+	"regexp"
+	"strconv"
+	"strings"
+
+	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/custom_detectorspb"
+)
+
+// customRegex is a CustomRegex that is guaranteed to be valid.
+type customRegex *custom_detectorspb.CustomRegex
+
+// ValidateKeywords returns an error if keywords is empty or contains an empty
+// keyword.
+func ValidateKeywords(keywords []string) error {
+	if len(keywords) == 0 {
+		return fmt.Errorf("no keywords")
+	}
+
+	for _, keyword := range keywords {
+		if len(keyword) == 0 {
+			return fmt.Errorf("empty keyword")
+		}
+	}
+	return nil
+}
+
+// ValidateRegex returns an error if regex is empty or if any of its values
+// fails to compile as a regular expression.
+func ValidateRegex(regex map[string]string) error {
+	if len(regex) == 0 {
+		return fmt.Errorf("no regex")
+	}
+
+	for _, r := range regex {
+		if _, err := regexp.Compile(r); err != nil {
+			// Wrap the compile error so callers can see why the pattern is invalid.
+			return fmt.Errorf("invalid regex %q: %w", r, err)
+		}
+	}
+
+	return nil
+}
+
+// ValidateVerifyEndpoint returns an error if endpoint is empty, or if it uses
+// the plaintext http scheme while unsafe is false.
+func ValidateVerifyEndpoint(endpoint string, unsafe bool) error {
+	if len(endpoint) == 0 {
+		return fmt.Errorf("no endpoint")
+	}
+
+	// URL schemes are case-insensitive, so "HTTP://" must be treated like
+	// "http://" rather than slipping past the unsafe check.
+	const httpScheme = "http://"
+	isHTTP := len(endpoint) >= len(httpScheme) && strings.EqualFold(endpoint[:len(httpScheme)], httpScheme)
+	if isHTTP && !unsafe {
+		return fmt.Errorf("http endpoint must have unsafe=true")
+	}
+	return nil
+}
+
+// ValidateVerifyHeaders returns an error if any header does not contain a
+// colon.
+func ValidateVerifyHeaders(headers []string) error {
+	for _, header := range headers {
+		if !strings.Contains(header, ":") {
+			return fmt.Errorf("header %q must contain a colon", header)
+		}
+	}
+	return nil
+}
+
+// ValidateVerifyRanges checks that every entry is either a single HTTP status
+// code ("200") or an inclusive "lower-upper" range ("300-350"), and that all
+// codes fall within 100-599.
+func ValidateVerifyRanges(ranges []string) error {
+	const httpLowerBound = 100
+	const httpUpperBound = 599
+
+	for _, successRange := range ranges {
+		if !strings.Contains(successRange, "-") {
+			httpCode, err := strconv.Atoi(successRange)
+			if err != nil {
+				return fmt.Errorf("unable to convert http code to int %q", successRange)
+			}
+
+			if httpCode < httpLowerBound || httpCode > httpUpperBound {
+				return fmt.Errorf("invalid http status code %q", successRange)
+			}
+
+			continue
+		}
+
+		httpRange := strings.Split(successRange, "-")
+		if len(httpRange) != 2 {
+			return fmt.Errorf("invalid range format %q", successRange)
+		}
+
+		lowerBound, err := strconv.Atoi(httpRange[0])
+		if err != nil {
+			return fmt.Errorf("unable to convert lower bound to int %q", successRange)
+		}
+
+		upperBound, err := strconv.Atoi(httpRange[1])
+		if err != nil {
+			return fmt.Errorf("unable to convert upper bound to int %q", successRange)
+		}
+
+		if lowerBound > upperBound {
+			return fmt.Errorf("lower bound greater than upper bound on range %q", successRange)
+		}
+
+		if lowerBound < httpLowerBound || upperBound > httpUpperBound {
+			return fmt.Errorf("invalid http status code range %q", successRange)
+		}
+	}
+	return nil
+}
+
+// ValidateRegexVars returns an error if any body references a {name} variable
+// that is not a key of regex.
+func ValidateRegexVars(regex map[string]string, body ...string) error {
+	for _, b := range body {
+		matches := NewRegexVarString(b).variables
+
+		for match := range matches {
+			if _, ok := regex[match]; !ok {
+				return fmt.Errorf("body %q contains an unknown variable", b)
+			}
+		}
+	}
+
+	return nil
+}
+
+// NewCustomRegex validates pb and returns it as a customRegex. The first
+// validation failure is returned as the error.
+func NewCustomRegex(pb *custom_detectorspb.CustomRegex) (customRegex, error) {
+	if pb == nil {
+		return nil, fmt.Errorf("nil custom regex")
+	}
+
+	// TODO: Return all validation errors.
+	if err := ValidateKeywords(pb.Keywords); err != nil {
+		return nil, err
+	}
+
+	if err := ValidateRegex(pb.Regex); err != nil {
+		return nil, err
+	}
+
+	for _, verify := range pb.Verify {
+		if err := ValidateVerifyEndpoint(verify.Endpoint, verify.Unsafe); err != nil {
+			return nil, err
+		}
+
+		if err := ValidateVerifyHeaders(verify.Headers); err != nil {
+			return nil, err
+		}
+
+		if err := ValidateVerifyRanges(verify.SuccessRanges); err != nil {
+			return nil, err
+		}
+
+		// Copy into a fresh slice: appending directly to verify.Headers could
+		// write into its backing array when it has spare capacity.
+		bodies := make([]string, 0, len(verify.Headers)+1)
+		bodies = append(bodies, verify.Headers...)
+		bodies = append(bodies, verify.Endpoint)
+		if err := ValidateRegexVars(pb.Regex, bodies...); err != nil {
+			return nil, err
+		}
+	}
+
+	return pb, nil
+}
diff --git a/pkg/custom_detectors/custom_detectors_test.go b/pkg/custom_detectors/custom_detectors_test.go
index 96a16001999d..aedb1fcc2739 100644
--- a/pkg/custom_detectors/custom_detectors_test.go
+++ b/pkg/custom_detectors/custom_detectors_test.go
@@ -66,3 +66,238 @@ func TestCustomDetectorsParsing(t *testing.T) {
 	assert.NoError(t, protoyaml.UnmarshalStrict([]byte(testYamlConfig), &messages))
 	assertExpected(t, messages.Detectors[0])
 }
+
+func TestCustomDetectorsKeywordValidation(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   []string
+		wantErr bool
+	}{
+		{
+			name:    "Test empty list of keywords",
+			input:   []string{},
+			wantErr: true,
+		},
+		{
+			name:    "Test empty keyword",
+			input:   []string{""},
+			wantErr: true,
+		},
+		{
+			name:    "Test valid keywords",
+			input:   []string{"hello", "world"},
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := ValidateKeywords(tt.input)
+
+			if (got != nil && !tt.wantErr) || (got == nil && tt.wantErr) {
+				t.Errorf("ValidateKeywords() error = %v, wantErr %v", got, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestCustomDetectorsRegexValidation(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   map[string]string
+		wantErr bool
+	}{
+		{
+			name: "Test valid regex",
+			input: map[string]string{
+				"id_pat_example": "([a-zA-Z0-9]{32})",
+			},
+			wantErr: false,
+		},
+		{
+			name:    "Test empty regex map",
+			input:   map[string]string{},
+			wantErr: true,
+		},
+		{
+			name:    "Test invalid regex",
+			input:   map[string]string{"bad": "("},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := ValidateRegex(tt.input)
+
+			if (got != nil && !tt.wantErr) || (got == nil && tt.wantErr) {
+				t.Errorf("ValidateRegex() error = %v, wantErr %v", got, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestCustomDetectorsVerifyEndpointValidation(t *testing.T) {
+	tests := []struct {
+		name     string
+		endpoint string
+		unsafe   bool
+		wantErr  bool
+	}{
+		{
+			name:     "Test http endpoint with unsafe flag",
+			endpoint: "http://localhost:8000/{id_pat_example}",
+			unsafe:   true,
+			wantErr:  false,
+		},
+		{
+			name:     "Test http endpoint without unsafe flag",
+			endpoint: "http://localhost:8000/{id_pat_example}",
+			unsafe:   false,
+			wantErr:  true,
+		},
+		{
+			name:     "Test uppercase http scheme without unsafe flag",
+			endpoint: "HTTP://localhost:8000/{id_pat_example}",
+			unsafe:   false,
+			wantErr:  true,
+		},
+		{
+			name:     "Test https endpoint with unsafe flag",
+			endpoint: "https://localhost:8000/{id_pat_example}",
+			unsafe:   true,
+			wantErr:  false,
+		},
+		{
+			name:     "Test https endpoint without unsafe flag",
+			endpoint: "https://localhost:8000/{id_pat_example}",
+			unsafe:   false,
+			wantErr:  false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := ValidateVerifyEndpoint(tt.endpoint, tt.unsafe)
+
+			if (got != nil && !tt.wantErr) || (got == nil && tt.wantErr) {
+				t.Errorf("ValidateVerifyEndpoint() error = %v, wantErr %v", got, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestCustomDetectorsVerifyHeadersValidation(t *testing.T) {
+	tests := []struct {
+		name    string
+		headers []string
+		wantErr bool
+	}{
+		{
+			name:    "Test single header",
+			headers: []string{"Authorization: Bearer {secret_pat_example.0}"},
+			wantErr: false,
+		},
+		{
+			name:    "Test invalid header",
+			headers: []string{"Hello world"},
+			wantErr: true,
+		},
+		{
+			name:    "Test ugly header",
+			headers: []string{"Hello:::::::world::hi:"},
+			wantErr: false,
+		},
+		{
+			name:    "Test empty header",
+			headers: []string{},
+			wantErr: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := ValidateVerifyHeaders(tt.headers)
+
+			if (got != nil && !tt.wantErr) || (got == nil && tt.wantErr) {
+				t.Errorf("ValidateVerifyHeaders() error = %v, wantErr %v", got, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestCustomDetectorsVerifyRangeValidation(t *testing.T) {
+	tests := []struct {
+		name    string
+		ranges  []string
+		wantErr bool
+	}{
+		{
+			name:    "Test multiple mixed ranges",
+			ranges:  []string{"200", "300-350"},
+			wantErr: false,
+		},
+		{
+			name:    "Test invalid non-number range",
+			ranges:  []string{"hi"},
+			wantErr: true,
+		},
+		{
+			name:    "Test invalid lower to upper range",
+			ranges:  []string{"200-100"},
+			wantErr: true,
+		},
+		{
+			name:    "Test invalid http range",
+			ranges:  []string{"400-1000"},
+			wantErr: true,
+		},
+		{
+			name:    "Test multiple ranges with invalid inputs",
+			ranges:  []string{"322", "hello-world", "100-200"},
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := ValidateVerifyRanges(tt.ranges)
+
+			if (got != nil && !tt.wantErr) || (got == nil && tt.wantErr) {
+				t.Errorf("ValidateVerifyRanges() error = %v, wantErr %v", got, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestCustomDetectorsVerifyRegexVarsValidation(t *testing.T) {
+	tests := []struct {
+		name    string
+		regex   map[string]string
+		body    string
+		wantErr bool
+	}{
+		{
+			name:    "Regex defined but not used in body",
+			regex:   map[string]string{"id": "[0-9]{1,10}", "id_pat_example": "([a-zA-Z0-9]{32})"},
+			body:    "hello world",
+			wantErr: false,
+		},
+		{
+			name:    "Regex defined and is used in body",
+			regex:   map[string]string{"id": "[0-9]{1,10}", "id_pat_example": "([a-zA-Z0-9]{32})"},
+			body:    "hello world {id}",
+			wantErr: false,
+		},
+		{
+			name:    "Regex var in body but not defined",
+			regex:   map[string]string{"id": "[0-9]{1,10}", "id_pat_example": "([a-zA-Z0-9]{32})"},
+			body:    "hello world {hello}",
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := ValidateRegexVars(tt.regex, tt.body)
+
+			if (got != nil && !tt.wantErr) || (got == nil && tt.wantErr) {
+				t.Errorf("ValidateRegexVars() error = %v, wantErr %v", got, tt.wantErr)
+			}
+		})
+	}
+}
diff --git a/pkg/custom_detectors/regex_varstring.go b/pkg/custom_detectors/regex_varstring.go
new file mode 100644
index 000000000000..e0efc1364183
--- /dev/null
+++ b/pkg/custom_detectors/regex_varstring.go
@@ -0,0 +1,48 @@
+package custom_detectors
+
+import (
+	"regexp"
+	"strconv"
+	"strings"
+)
+
+// nameGroupRegex matches `{ name . group }` ignoring any whitespace.
+var nameGroupRegex = regexp.MustCompile(`{\s*([a-zA-Z0-9-_]+)\s*(\.\s*[0-9]*)?\s*}`)
+
+// RegexVarString is a string with embedded {name.group} variables. A name may
+// only contain alphanumeric, hyphen, and underscore characters. Group is
+// optional but if provided it must be a non-negative integer. If the group is
+// omitted it defaults to 0.
+type RegexVarString struct {
+	original string
+	// map from name to group
+	variables map[string]int
+}
+
+// NewRegexVarString parses original for {name.group} variables and records
+// the group for each name. A variable whose group cannot be converted to an
+// int is skipped. A later occurrence of a name overwrites an earlier one.
+func NewRegexVarString(original string) RegexVarString {
+	variables := make(map[string]int)
+
+	matches := nameGroupRegex.FindAllStringSubmatch(original, -1)
+	for _, match := range matches {
+		name, group := match[1], 0
+		// The second submatch, when present, is a period followed by optional
+		// whitespace and digits. Strip the period and whitespace first so that
+		// "{foo.}" and "{foo. }" both fall back to group 0.
+		if digits := strings.TrimSpace(strings.TrimPrefix(match[2], ".")); digits != "" {
+			g, err := strconv.Atoi(digits)
+			if err != nil {
+				continue
+			}
+			group = g
+		}
+		variables[name] = group
+	}
+
+	return RegexVarString{
+		original:  original,
+		variables: variables,
+	}
+}
diff --git a/pkg/custom_detectors/regex_varstring_test.go b/pkg/custom_detectors/regex_varstring_test.go
new file mode 100644
index 000000000000..7e7bec85b561
--- /dev/null
+++ b/pkg/custom_detectors/regex_varstring_test.go
@@ -0,0 +1,83 @@
+package custom_detectors
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestVarString(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		wantVars map[string]int
+	}{
+		{
+			name:     "empty",
+			input:    "{}",
+			wantVars: map[string]int{},
+		},
+		{
+			name:  "no subgroup",
+			input: "{hello}",
+			wantVars: map[string]int{
+				"hello": 0,
+			},
+		},
+		{
+			name:  "with subgroup",
+			input: "{hello.123}",
+			wantVars: map[string]int{
+				"hello": 123,
+			},
+		},
+		{
+			name:  "subgroup with spaces",
+			input: "{\thell0 . 123 }",
+			wantVars: map[string]int{
+				"hell0": 123,
+			},
+		},
+		{
+			name:  "multiple groups",
+			input: "foo {bar} {bazz.buzz} {buzz.2}",
+			wantVars: map[string]int{
+				"bar":  0,
+				"buzz": 2,
+			},
+		},
+		{
+			name:  "nested groups",
+			input: "{foo {bar}}",
+			wantVars: map[string]int{
+				"bar": 0,
+			},
+		},
+		{
+			name:  "decimal without number",
+			input: "{foo.}",
+			wantVars: map[string]int{
+				"foo": 0,
+			},
+		},
+		{
+			name:  "decimal followed by whitespace",
+			input: "{foo. }",
+			wantVars: map[string]int{
+				"foo": 0,
+			},
+		},
+		{
+			name:     "negative number",
+			input:    "{foo.-1}",
+			wantVars: map[string]int{},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := NewRegexVarString(tt.input)
+			assert.Equal(t, tt.input, got.original)
+			assert.Equal(t, tt.wantVars, got.variables)
+		})
+	}
+}