Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CustomRegex validation #939

Merged
merged 6 commits into from Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
151 changes: 151 additions & 0 deletions pkg/custom_detectors/custom_detectors.go
@@ -0,0 +1,151 @@
package custom_detectors

import (
"fmt"
"regexp"
"strconv"
"strings"

"github.com/trufflesecurity/trufflehog/v3/pkg/pb/custom_detectorspb"
)

// customRegex is a CustomRegex that is guaranteed to be valid.
//
// It is an unexported named pointer type over the protobuf message. In the
// code visible here, the only producer of this type is NewCustomRegex, which
// returns it after the keyword, regex and verifier checks have all passed.
type customRegex *custom_detectorspb.CustomRegex

// ValidateKeywords checks that at least one keyword is supplied and that
// none of the supplied keywords is the empty string.
//
// It returns nil when the list is acceptable, "no keywords" when the list is
// empty (or nil), and "empty keyword" as soon as an empty entry is found.
func ValidateKeywords(keywords []string) error {
	if len(keywords) < 1 {
		return fmt.Errorf("no keywords")
	}

	for i := range keywords {
		if keywords[i] == "" {
			return fmt.Errorf("empty keyword")
		}
	}

	return nil
}

// ValidateRegex checks that at least one named pattern is supplied and that
// every pattern compiles with the standard regexp package.
//
// regex maps a variable name to its regular-expression source. On failure the
// returned error names the offending entry and wraps the compile error, so
// callers can see why the pattern was rejected and inspect it with errors.As.
// When several patterns are invalid, which one is reported is unspecified
// because map iteration order is random.
func ValidateRegex(regex map[string]string) error {
	if len(regex) == 0 {
		return fmt.Errorf("no regex")
	}

	for name, pattern := range regex {
		if _, err := regexp.Compile(pattern); err != nil {
			// Keep the original "invalid regex <pattern>" prefix and append
			// the variable name and the underlying syntax error.
			return fmt.Errorf("invalid regex %q for %q: %w", pattern, name, err)
		}
	}

	return nil
}

// ValidateVerifyEndpoint checks that a verification endpoint is present and
// that a plaintext HTTP endpoint is only accepted when the caller has opted
// in with unsafe=true.
//
// The scheme comparison is case-insensitive because URI schemes are
// case-insensitive: "HTTP://host" is just as unencrypted as "http://host"
// and must not slip past the unsafe check. Endpoints with any other prefix
// (for example "https://") are accepted regardless of unsafe.
func ValidateVerifyEndpoint(endpoint string, unsafe bool) error {
	if len(endpoint) == 0 {
		return fmt.Errorf("no endpoint")
	}

	const insecureScheme = "http://"
	isPlainHTTP := len(endpoint) >= len(insecureScheme) &&
		strings.EqualFold(endpoint[:len(insecureScheme)], insecureScheme)

	if isPlainHTTP && !unsafe {
		return fmt.Errorf("http endpoint must have unsafe=true")
	}
	return nil
}

// ValidateVerifyHeaders checks that every header line contains at least one
// colon. An empty or nil list is valid. The first header without a colon is
// reported in the returned error.
func ValidateVerifyHeaders(headers []string) error {
	for i := range headers {
		if strings.IndexByte(headers[i], ':') >= 0 {
			continue
		}
		return fmt.Errorf("header %q must contain a colon", headers[i])
	}
	return nil
}

// ValidateVerifyRanges checks a list of HTTP success ranges.
//
// Each entry is either a single status code ("200") or an inclusive range
// written as "lower-upper" ("300-350"). Every code must lie within 100..599
// and a range's lower bound must not exceed its upper bound. An empty or nil
// list is valid.
//
// When an entry cannot be parsed as an integer, the returned error wraps the
// underlying strconv error so callers can see the cause.
func ValidateVerifyRanges(ranges []string) error {
	const (
		httpLowerBound = 100
		httpUpperBound = 599
	)

	for _, successRange := range ranges {
		// No dash: the entry must be a single status code.
		if !strings.Contains(successRange, "-") {
			httpCode, err := strconv.Atoi(successRange)
			if err != nil {
				return fmt.Errorf("unable to convert http code to int %q: %w", successRange, err)
			}

			if httpCode < httpLowerBound || httpCode > httpUpperBound {
				return fmt.Errorf("invalid http status code %q", successRange)
			}

			continue
		}

		// Exactly one dash is allowed; "1-2-3" or similar is malformed.
		httpRange := strings.Split(successRange, "-")
		if len(httpRange) != 2 {
			return fmt.Errorf("invalid range format %q", successRange)
		}

		lowerBound, err := strconv.Atoi(httpRange[0])
		if err != nil {
			return fmt.Errorf("unable to convert lower bound to int %q: %w", successRange, err)
		}

		upperBound, err := strconv.Atoi(httpRange[1])
		if err != nil {
			return fmt.Errorf("unable to convert upper bound to int %q: %w", successRange, err)
		}

		if lowerBound > upperBound {
			return fmt.Errorf("lower bound greater than upper bound on range %q", successRange)
		}

		if lowerBound < httpLowerBound || upperBound > httpUpperBound {
			return fmt.Errorf("invalid http status code range %q", successRange)
		}
	}
	return nil
}

// ValidateRegexVars checks that every variable referenced in the given
// bodies has a matching key in regex.
//
// Variables are obtained from the `variables` field of NewRegexVarString's
// result (defined outside this block; presumably it extracts "{name}" style
// placeholders, confirm against its implementation). The first body that
// references a name missing from regex is reported in the returned error.
func ValidateRegexVars(regex map[string]string, body ...string) error {
	for i := range body {
		for name := range NewRegexVarString(body[i]).variables {
			if _, defined := regex[name]; defined {
				continue
			}
			return fmt.Errorf("body %q contains an unknown variable", body[i])
		}
	}

	return nil
}

// NewCustomRegex validates pb and, on success, returns it as a customRegex.
//
// The checks run in this order and stop at the first failure: keywords,
// regex patterns, then for each verifier its endpoint, headers, success
// ranges, and finally that every variable referenced by the headers and the
// endpoint is a key of pb.Regex. A nil pb is rejected with an error instead
// of panicking on field access.
func NewCustomRegex(pb *custom_detectorspb.CustomRegex) (customRegex, error) {
	if pb == nil {
		return nil, fmt.Errorf("nil custom regex")
	}

	// TODO: Return all validation errors.
	if err := ValidateKeywords(pb.Keywords); err != nil {
		return nil, err
	}

	if err := ValidateRegex(pb.Regex); err != nil {
		return nil, err
	}

	for _, verify := range pb.Verify {
		if err := ValidateVerifyEndpoint(verify.Endpoint, verify.Unsafe); err != nil {
			return nil, err
		}

		if err := ValidateVerifyHeaders(verify.Headers); err != nil {
			return nil, err
		}

		if err := ValidateVerifyRanges(verify.SuccessRanges); err != nil {
			return nil, err
		}

		// Check headers and endpoint in two calls rather than appending the
		// endpoint to verify.Headers: append may write into spare capacity of
		// the caller's backing array. The order (headers, then endpoint) is
		// the same as before.
		if err := ValidateRegexVars(pb.Regex, verify.Headers...); err != nil {
			return nil, err
		}

		if err := ValidateRegexVars(pb.Regex, verify.Endpoint); err != nil {
			return nil, err
		}
	}

	return pb, nil
}
224 changes: 224 additions & 0 deletions pkg/custom_detectors/custom_detectors_test.go
Expand Up @@ -66,3 +66,227 @@ func TestCustomDetectorsParsing(t *testing.T) {
assert.NoError(t, protoyaml.UnmarshalStrict([]byte(testYamlConfig), &messages))
assertExpected(t, messages.Detectors[0])
}

// TestCustomDetectorsKeywordValidation runs ValidateKeywords against an empty
// list, a list holding an empty keyword, and a list of valid keywords.
func TestCustomDetectorsKeywordValidation(t *testing.T) {
	type testCase struct {
		name    string
		input   []string
		wantErr bool
	}

	cases := []testCase{
		{name: "Test empty list of keywords", input: []string{}, wantErr: true},
		{name: "Test empty keyword", input: []string{""}, wantErr: true},
		{name: "Test valid keywords", input: []string{"hello", "world"}, wantErr: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateKeywords(tc.input)

			// Fail when "got an error" disagrees with "wanted an error".
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateKeywords() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}

// TestCustomDetectorsRegexValidation runs ValidateRegex against a valid
// pattern map, an empty map, and a map containing a pattern that does not
// compile, so that every return path of ValidateRegex is exercised.
func TestCustomDetectorsRegexValidation(t *testing.T) {
	tests := []struct {
		name    string
		input   map[string]string
		wantErr bool
	}{
		{
			name: "Test valid regex",
			input: map[string]string{
				"id_pat_example": "([a-zA-Z0-9]{32})",
			},
			wantErr: false,
		},
		{
			name:    "Test empty map of regex",
			input:   map[string]string{},
			wantErr: true,
		},
		{
			// Unbalanced parenthesis: regexp.Compile must reject this.
			name: "Test invalid regex",
			input: map[string]string{
				"id_pat_example": "([a-zA-Z0-9]{32}",
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := ValidateRegex(tt.input)

			if (got != nil) != tt.wantErr {
				t.Errorf("ValidateRegex() error = %v, wantErr %v", got, tt.wantErr)
			}
		})
	}
}

// TestCustomDetectorsVerifyEndpointValidation runs ValidateVerifyEndpoint
// over the four combinations of http/https endpoint and unsafe flag. Only
// http without unsafe is expected to fail.
func TestCustomDetectorsVerifyEndpointValidation(t *testing.T) {
	type testCase struct {
		name     string
		endpoint string
		unsafe   bool
		wantErr  bool
	}

	cases := []testCase{
		{
			name:     "Test http endpoint with unsafe flag",
			endpoint: "http://localhost:8000/{id_pat_example}",
			unsafe:   true,
			wantErr:  false,
		},
		{
			name:     "Test http endpoint without unsafe flag",
			endpoint: "http://localhost:8000/{id_pat_example}",
			unsafe:   false,
			wantErr:  true,
		},
		{
			name:     "Test https endpoint with unsafe flag",
			endpoint: "https://localhost:8000/{id_pat_example}",
			unsafe:   true,
			wantErr:  false,
		},
		{
			name:     "Test https endpoint without unsafe flag",
			endpoint: "https://localhost:8000/{id_pat_example}",
			unsafe:   false,
			wantErr:  false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateVerifyEndpoint(tc.endpoint, tc.unsafe)

			// Fail when "got an error" disagrees with "wanted an error".
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateVerifyEndpoint() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}

// TestCustomDetectorsVerifyHeadersValidation runs ValidateVerifyHeaders over
// a well-formed header, a header with no colon, a header with many colons,
// and an empty header list.
func TestCustomDetectorsVerifyHeadersValidation(t *testing.T) {
	type testCase struct {
		name    string
		headers []string
		wantErr bool
	}

	cases := []testCase{
		{
			name:    "Test single header",
			headers: []string{"Authorization: Bearer {secret_pat_example.0}"},
			wantErr: false,
		},
		{
			name:    "Test invalid header",
			headers: []string{"Hello world"},
			wantErr: true,
		},
		{
			name:    "Test ugly header",
			headers: []string{"Hello:::::::world::hi:"},
			wantErr: false,
		},
		{
			name:    "Test empty header",
			headers: []string{},
			wantErr: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateVerifyHeaders(tc.headers)

			// Fail when "got an error" disagrees with "wanted an error".
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateVerifyHeaders() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}

// TestCustomDetectorsVerifyRangeValidation runs ValidateVerifyRanges over a
// valid mix of single codes and ranges, plus several malformed inputs:
// non-numeric text, an inverted range, an out-of-bounds range, and a list
// with one bad entry among good ones.
func TestCustomDetectorsVerifyRangeValidation(t *testing.T) {
	type testCase struct {
		name    string
		ranges  []string
		wantErr bool
	}

	cases := []testCase{
		{
			name:    "Test multiple mixed ranges",
			ranges:  []string{"200", "300-350"},
			wantErr: false,
		},
		{
			name:    "Test invalid non-number range",
			ranges:  []string{"hi"},
			wantErr: true,
		},
		{
			name:    "Test invalid lower to upper range",
			ranges:  []string{"200-100"},
			wantErr: true,
		},
		{
			name:    "Test invalid http range",
			ranges:  []string{"400-1000"},
			wantErr: true,
		},
		{
			name:    "Test multiple ranges with invalid inputs",
			ranges:  []string{"322", "hello-world", "100-200"},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateVerifyRanges(tc.ranges)

			// Fail when "got an error" disagrees with "wanted an error".
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateVerifyRanges() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}

// TestCustomDetectorsVerifyRegexVarsValidation runs ValidateRegexVars with a
// fixed pair of regex definitions against a body that uses no variable, a
// body that uses a defined variable, and a body that uses an undefined one.
func TestCustomDetectorsVerifyRegexVarsValidation(t *testing.T) {
	type testCase struct {
		name    string
		regex   map[string]string
		body    string
		wantErr bool
	}

	cases := []testCase{
		{
			name:    "Regex defined but not used in body",
			regex:   map[string]string{"id": "[0-9]{1,10}", "id_pat_example": "([a-zA-Z0-9]{32})"},
			body:    "hello world",
			wantErr: false,
		},
		{
			name:    "Regex defined and is used in body",
			regex:   map[string]string{"id": "[0-9]{1,10}", "id_pat_example": "([a-zA-Z0-9]{32})"},
			body:    "hello world {id}",
			wantErr: false,
		},
		{
			name:    "Regex var in body but not defined",
			regex:   map[string]string{"id": "[0-9]{1,10}", "id_pat_example": "([a-zA-Z0-9]{32})"},
			body:    "hello world {hello}",
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateRegexVars(tc.regex, tc.body)

			// Fail when "got an error" disagrees with "wanted an error".
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateRegexVars() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}