Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Agent auto auth to read symlinked JWT files #11502

Merged
merged 4 commits into from May 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 14 additions & 3 deletions command/agent/auth/jwt/jwt.go
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io/fs"
"io/ioutil"
"net/http"
"os"
Expand Down Expand Up @@ -31,6 +32,8 @@ type jwtMethod struct {
latestToken *atomic.Value
}

// NewJWTAuthMethod returns an implementation of Agent's auth.AuthMethod
// interface for JWT auth.
func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil {
return nil, errors.New("empty config")
Expand Down Expand Up @@ -86,7 +89,7 @@ func NewJWTAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
return j, nil
}

func (j *jwtMethod) Authenticate(_ context.Context, client *api.Client) (string, http.Header, map[string]interface{}, error) {
func (j *jwtMethod) Authenticate(_ context.Context, _ *api.Client) (string, http.Header, map[string]interface{}, error) {
j.logger.Trace("beginning authentication")

j.ingressToken()
Expand Down Expand Up @@ -160,8 +163,16 @@ func (j *jwtMethod) ingressToken() {

j.logger.Debug("new jwt file found")

if !fi.Mode().IsRegular() {
j.logger.Error("jwt file is not a regular file")
// Check that the path refers to a file.
// If it's a symlink, it could still be a symlink to a directory,
// but ioutil.ReadFile below will return a descriptive error.
switch mode := fi.Mode(); {
case mode.IsRegular():
// regular file
case mode&fs.ModeSymlink != 0:
// symlink
default:
j.logger.Error("jwt file is not a regular file or symlink")
return
}

Expand Down
108 changes: 108 additions & 0 deletions command/agent/auth/jwt/jwt_test.go
@@ -0,0 +1,108 @@
package jwt

import (
"bytes"
"io/ioutil"
"os"
"path"
"strings"
"sync/atomic"
"testing"

"github.com/hashicorp/go-hclog"
)

func TestIngressToken(t *testing.T) {
const (
dir = "dir"
file = "file"
empty = "empty"
missing = "missing"
symlinked = "symlinked"
)

rootDir, err := ioutil.TempDir("", "vault-agent-jwt-auth-test")
if err != nil {
t.Fatalf("failed to create temp dir: %s", err)
}
defer os.RemoveAll(rootDir)

setupTestDir := func() string {
testDir, err := ioutil.TempDir(rootDir, "")
if err != nil {
t.Fatal(err)
}
err = ioutil.WriteFile(path.Join(testDir, file), []byte("test"), 0644)
if err != nil {
t.Fatal(err)
}
_, err = os.Create(path.Join(testDir, empty))
if err != nil {
t.Fatal(err)
}
err = os.Mkdir(path.Join(testDir, dir), 0755)
if err != nil {
t.Fatal(err)
}
err = os.Symlink(path.Join(testDir, file), path.Join(testDir, symlinked))
if err != nil {
t.Fatal(err)
}

return testDir
}

for _, tc := range []struct {
name string
path string
errString string
}{
{
"happy path",
file,
"",
},
{
"path is directory",
dir,
"[ERROR] jwt file is not a regular file or symlink",
},
{
"path is symlink",
symlinked,
"",
},
{
"path is missing (implies nothing for ingressToken to do)",
missing,
"",
},
{
"path is empty file",
empty,
"[WARN] empty jwt file read",
},
} {
testDir := setupTestDir()
logBuffer := bytes.Buffer{}
jwtAuth := &jwtMethod{
logger: hclog.New(&hclog.LoggerOptions{
Output: &logBuffer,
}),
latestToken: new(atomic.Value),
path: path.Join(testDir, tc.path),
}

jwtAuth.ingressToken()

if tc.errString != "" {
if !strings.Contains(logBuffer.String(), tc.errString) {
t.Fatal("logs did no contain expected error", tc.errString, logBuffer.String())
}
} else {
if strings.Contains(logBuffer.String(), "[ERROR]") || strings.Contains(logBuffer.String(), "[WARN]") {
t.Fatal("logs contained unexpected error", logBuffer.String())
}
}
}
}