diff --git a/pkg/detectors/jdbc/jdbc.go b/pkg/detectors/jdbc/jdbc.go index b3ee1b3481d4..385548aea64e 100644 --- a/pkg/detectors/jdbc/jdbc.go +++ b/pkg/detectors/jdbc/jdbc.go @@ -64,27 +64,94 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result } func tryRedactAnonymousJDBC(conn string) string { + if s, ok := tryRedactBasicAuth(conn); ok { + return s + } + if s, ok := tryRedactURLParams(conn); ok { + return s + } + if s, ok := tryRedactODBC(conn); ok { + return s + } + if s, ok := tryRedactRegex(conn); ok { + return s + } + return conn +} + +// Basic authentication "username:password@host" style +func tryRedactBasicAuth(conn string) (string, bool) { userPass, postfix, found := strings.Cut(conn, "@") - if found { - if index := strings.LastIndex(userPass, ":"); index != -1 { - prefix, pass := userPass[:index], userPass[index+1:] - return prefix + ":" + strings.Repeat("*", len(pass)) + "@" + postfix - } + if !found { + return "", false } + index := strings.LastIndex(userPass, ":") + if index == -1 { + return "", false + } + prefix, pass := userPass[:index], userPass[index+1:] + return prefix + ":" + strings.Repeat("*", len(pass)) + "@" + postfix, true +} + +// URL param "?password=password" style +func tryRedactURLParams(conn string) (string, bool) { prefix, paramString, found := strings.Cut(conn, "?") if !found { - return conn + return "", false } var newParams []string + found = false for _, param := range strings.Split(paramString, "&") { key, val, _ := strings.Cut(param, "=") if strings.Contains(strings.ToLower(key), "pass") { newParams = append(newParams, key+"="+strings.Repeat("*", len(val))) + found = true + continue + } + newParams = append(newParams, param) + } + if !found { + return "", false + } + return prefix + "?" + strings.Join(newParams, "&"), true +} + +// ODBC params ";password=password" style +func tryRedactODBC(conn string) (string, bool) { + var found bool + var newParams []string + for _, param := range strings.Split(conn, ";") { + key, val, _ := strings.Cut(param, "=") + if strings.Contains(strings.ToLower(key), "pass") { + newParams = append(newParams, key+"="+strings.Repeat("*", len(val))) + found = true continue } newParams = append(newParams, param) } - return prefix + "?" + strings.Join(newParams, "&") + if !found { + return "", false + } + return strings.Join(newParams, ";"), true +} + +// Naively search the string for "pass=" +func tryRedactRegex(conn string) (string, bool) { + pattern := regexp.MustCompile(`(?i)pass.*?=(.+?)\b`) + var found bool + newConn := pattern.ReplaceAllStringFunc(conn, func(s string) string { + index := strings.Index(s, "=") + if index == -1 { + // unreachable due to regex containing '=' + return s + } + found = true + return s[:index+1] + strings.Repeat("*", len(s[index+1:])) + }) + if !found { + return "", false + } + return newConn, true } var supportedSubprotocols = map[string]func(string) (jdbc, error){ diff --git a/pkg/detectors/jdbc/jdbc_integration_test.go b/pkg/detectors/jdbc/jdbc_integration_test.go index 55c0bdf49349..0c69571f3f07 100644 --- a/pkg/detectors/jdbc/jdbc_integration_test.go +++ b/pkg/detectors/jdbc/jdbc_integration_test.go @@ -27,14 +27,19 @@ func TestMain(m *testing.M) { } func runMain(m *testing.M) (int, error) { - if err := startPostgres(); err != nil { - return 0, err - } - defer stopPostgres() - if err := startMySQL(); err != nil { - return 0, err + for _, ctrl := range []struct { + start func() error + stop func() + }{ + {startPostgres, stopPostgres}, + {startMySQL, stopMySQL}, + {startSqlServer, stopSqlServer}, + } { + if err := ctrl.start(); err != nil { + return 0, err + } + defer ctrl.stop() } - defer stopMySQL() return m.Run(), nil } diff --git a/pkg/detectors/jdbc/jdbc_test.go b/pkg/detectors/jdbc/jdbc_test.go index 76a33d181572..8b2f3806aa05 100644 --- a/pkg/detectors/jdbc/jdbc_test.go +++ b/pkg/detectors/jdbc/jdbc_test.go @@ -115,6 +115,22 @@ func TestJdbc_FromChunk(t *testing.T) { }, wantErr: false, }, + { + name: "sqlserver, unverified", + args: args{ + ctx: context.Background(), + data: []byte(`jdbc:sqlserver://a.b.c.net;database=database-name;spring.datasource.password=super-secret-password`), + verify: false, + }, + want: []detectors.Result{ + { + DetectorType: detectorspb.DetectorType_JDBC, + Verified: false, + Redacted: "jdbc:sqlserver://a.b.c.net;database=database-name;spring.datasource.password=*********************", + }, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/detectors/jdbc/postgres.go b/pkg/detectors/jdbc/postgres.go index 163e14d09c06..9136879bd26e 100644 --- a/pkg/detectors/jdbc/postgres.go +++ b/pkg/detectors/jdbc/postgres.go @@ -38,7 +38,7 @@ func (s *postgresJDBC) ping(ctx context.Context) bool { } data[key] = val } - if ping(ctx, "postgres", joinKeyValues(data)) { + if ping(ctx, "postgres", joinKeyValues(data, " ")) { return true } if s.params["dbname"] != "" { @@ -48,7 +48,7 @@ func (s *postgresJDBC) ping(ctx context.Context) bool { return false } -func joinKeyValues(m map[string]string) string { +func joinKeyValues(m map[string]string, sep string) string { var data []string for k, v := range m { if v == "" { @@ -56,7 +56,7 @@ func joinKeyValues(m map[string]string) string { } data = append(data, fmt.Sprintf("%s=%s", k, v)) } - return strings.Join(data, " ") + return strings.Join(data, sep) } func parsePostgres(subname string) (jdbc, error) { diff --git a/pkg/detectors/jdbc/sqlserver.go b/pkg/detectors/jdbc/sqlserver.go index 3f545a3b7d07..ab3506a6b923 100644 --- a/pkg/detectors/jdbc/sqlserver.go +++ b/pkg/detectors/jdbc/sqlserver.go @@ -2,24 +2,49 @@ package jdbc import ( "context" + "errors" "strings" _ "github.com/denisenkom/go-mssqldb" ) type sqlServerJDBC struct { - conn string + conn string + params map[string]string } func (s *sqlServerJDBC) ping(ctx context.Context) bool { if ping(ctx, "mssql", s.conn) { return true } + if ping(ctx, "mssql", joinKeyValues(s.params, ";")) { + return true + } // try URL format return ping(ctx, "mssql", "sqlserver://"+s.conn) } func parseSqlServer(subname string) (jdbc, error) { - // expected form: //[username:password@]host/instance[?key=val[&key=val]] - return &sqlServerJDBC{strings.TrimPrefix(subname, "//")}, nil + if !strings.HasPrefix(subname, "//") { + return nil, errors.New("expected connection to start with //") + } + conn := strings.TrimPrefix(subname, "//") + params := map[string]string{ + "user id": "sa", + "database": "master", + } + for _, param := range strings.Split(conn, ";") { + key, value, found := strings.Cut(param, "=") + if !found { + continue + } + params[key] = value + if key != "password" && strings.Contains(strings.ToLower(key), "password") { + params["password"] = value + } + } + return &sqlServerJDBC{ + conn: strings.TrimPrefix(subname, "//"), + params: params, + }, nil } diff --git a/pkg/detectors/jdbc/sqlserver_integration_test.go b/pkg/detectors/jdbc/sqlserver_integration_test.go new file mode 100644 index 000000000000..590e37607116 --- /dev/null +++ b/pkg/detectors/jdbc/sqlserver_integration_test.go @@ -0,0 +1,80 @@ +//go:build detectors && integration +// +build detectors,integration + +package jdbc + +import ( + "bytes" + "context" + "errors" + "os/exec" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const ( + sqlServerPass = "Secr3tP@s5w0rd" + sqlServerUser = "sa" + sqlServerDatabase = "master" +) + +func TestSqlServer(t *testing.T) { + tests := []struct { + input string + wantErr bool + wantPing bool + }{ + { + input: "", + wantErr: true, + }, + { + input: "//odbc:server=localhost;user id=sa;database=master;password=" + sqlServerPass, + wantPing: true, + }, + { + input: "//localhost;database= master;spring.datasource.password=" + sqlServerPass, + wantPing: true, + }, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + j, err := parseSqlServer(tt.input) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantPing, j.ping(context.Background())) + }) + } +} + +var sqlServerDockerHash string + +func startSqlServer() error { + cmd := exec.Command( + "docker", "run", "--rm", "-p", "1433:1433", + "-e", "ACCEPT_EULA=1", + "-e", "MSSQL_SA_PASSWORD="+sqlServerPass, + "-d", "mcr.microsoft.com/azure-sql-edge", + ) + out, err := cmd.Output() + if err != nil { + return err + } + sqlServerDockerHash = string(bytes.TrimSpace(out)) + select { + case <-dockerLogLine(sqlServerDockerHash, "EdgeTelemetry starting up"): + return nil + case <-time.After(30 * time.Second): + stopSqlServer() + return errors.New("timeout waiting for mysql database to be ready") + } +} + +func stopSqlServer() { + exec.Command("docker", "kill", sqlServerDockerHash).Run() +}