From d60b698c2ded0acc679d8943529d5b80c7a3e908 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Thu, 6 May 2021 14:11:57 +0100 Subject: [PATCH] Allow Agent auto auth to read symlinked JWT files (#11502) --- command/agent/auth/jwt/jwt.go | 17 ++++- command/agent/auth/jwt/jwt_test.go | 108 +++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 command/agent/auth/jwt/jwt_test.go diff --git a/command/agent/auth/jwt/jwt.go b/command/agent/auth/jwt/jwt.go index 403a57fae1dcd..0c97bee905ec3 100644 --- a/command/agent/auth/jwt/jwt.go +++ b/command/agent/auth/jwt/jwt.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io/fs" "io/ioutil" "net/http" "os" @@ -31,6 +32,8 @@ type jwtMethod struct { latestToken *atomic.Value } +// NewJWTAuthMethod returns an implementation of Agent's auth.AuthMethod +// interface for JWT auth. func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { if conf == nil { return nil, errors.New("empty config") @@ -86,7 +89,7 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { return j, nil } -func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) { +func (j *jwtMethod) Authenticate(_ context.Context, _ *api.Client) (string, http.Header, map[string]interface{}, error) { j.logger.Trace("beginning authentication") j.ingressToken() @@ -160,8 +163,16 @@ func (j *jwtMethod) ingressToken() { j.logger.Debug("new jwt file found") - if !fi.Mode().IsRegular() { - j.logger.Error("jwt file is not a regular file") + // Check that the path refers to a file. + // If it's a symlink, it could still be a symlink to a directory, + // but ioutil.ReadFile below will return a descriptive error. + switch mode := fi.Mode(); { + case mode.IsRegular(): + // regular file + case mode&fs.ModeSymlink != 0: + // symlink + default: + j.logger.Error("jwt file is not a regular file or symlink") return } diff --git a/command/agent/auth/jwt/jwt_test.go b/command/agent/auth/jwt/jwt_test.go new file mode 100644 index 0000000000000..ef33bfb7e720b --- /dev/null +++ b/command/agent/auth/jwt/jwt_test.go @@ -0,0 +1,108 @@ +package jwt + +import ( + "bytes" + "io/ioutil" + "os" + "path" + "strings" + "sync/atomic" + "testing" + + "github.com/hashicorp/go-hclog" +) + +func TestIngressToken(t *testing.T) { + const ( + dir = "dir" + file = "file" + empty = "empty" + missing = "missing" + symlinked = "symlinked" + ) + + rootDir, err := ioutil.TempDir("", "vault-agent-jwt-auth-test") + if err != nil { + t.Fatalf("failed to create temp dir: %s", err) + } + defer os.RemoveAll(rootDir) + + setupTestDir := func() string { + testDir, err := ioutil.TempDir(rootDir, "") + if err != nil { + t.Fatal(err) + } + err = ioutil.WriteFile(path.Join(testDir, file), []byte("test"), 0644) + if err != nil { + t.Fatal(err) + } + _, err = os.Create(path.Join(testDir, empty)) + if err != nil { + t.Fatal(err) + } + err = os.Mkdir(path.Join(testDir, dir), 0755) + if err != nil { + t.Fatal(err) + } + err = os.Symlink(path.Join(testDir, file), path.Join(testDir, symlinked)) + if err != nil { + t.Fatal(err) + } + + return testDir + } + + for _, tc := range []struct { + name string + path string + errString string + }{ + { + "happy path", + file, + "", + }, + { + "path is directory", + dir, + "[ERROR] jwt file is not a regular file or symlink", + }, + { + "path is symlink", + symlinked, + "", + }, + { + "path is missing (implies nothing for ingressToken to do)", + missing, + "", + }, + { + "path is empty file", + empty, + "[WARN] empty jwt file read", + }, + } { + testDir := setupTestDir() + logBuffer := bytes.Buffer{} + jwtAuth := &jwtMethod{ + logger: hclog.New(&hclog.LoggerOptions{ + Output: &logBuffer, + }), + latestToken: new(atomic.Value), + path: path.Join(testDir, tc.path), + } + + jwtAuth.ingressToken() + + if tc.errString != "" { + if !strings.Contains(logBuffer.String(), tc.errString) { + t.Fatal("logs did no contain expected error", tc.errString, logBuffer.String()) + } + } else { + if strings.Contains(logBuffer.String(), "[ERROR]") || strings.Contains(logBuffer.String(), "[WARN]") { + t.Fatal("logs contained unexpected error", logBuffer.String()) + } + } + } +}