From 4184e720254f77e0a7e9c77b2a78d2aa94f99aaa Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 19 Jun 2019 09:40:57 -0400 Subject: [PATCH] Fix a deadlock if a panic happens during request handling (#6920) * Fix a deadlock if a panic happens during request handling During request handling, if a panic is created, deferred functions are run but otherwise execution stops. #5889 changed some locks to non-defers but had the side effect of causing the read lock to not be released if the request panicked. This fixes that and addresses a few other potential places where things could go wrong: 1) In sealInitCommon we always now defer a function that unlocks the read lock if it hasn't been unlocked already 2) In StepDown we defer the RUnlock but we also had two error cases that were calling it manually. These are unlikely to be hit but if they were I believe would cause a panic. * Add panic recovery test --- vault/core.go | 17 ++- .../misc/recover_from_panic_test.go | 49 ++++++++ vault/ha.go | 3 +- vault/request_handling.go | 10 +- vault/router_test.go | 101 --------------- vault/router_testing.go | 115 ++++++++++++++++++ 6 files changed, 174 insertions(+), 121 deletions(-) create mode 100644 vault/external_tests/misc/recover_from_panic_test.go create mode 100644 vault/router_testing.go diff --git a/vault/core.go b/vault/core.go index 755b368303bef..8cbfdfdf537c2 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1193,9 +1193,15 @@ func (c *Core) Seal(token string) error { func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr error) { defer metrics.MeasureSince([]string{"core", "seal-internal"}, time.Now()) + var unlocked bool + defer func() { + if !unlocked { + c.stateLock.RUnlock() + } + }() + if req == nil { retErr = multierror.Append(retErr, errors.New("nil request to seal")) - c.stateLock.RUnlock() return retErr } @@ -1207,14 +1213,12 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr if c.standby { c.logger.Error("vault cannot seal when in standby mode; please restart instead") retErr = multierror.Append(retErr, errors.New("vault cannot seal when in standby mode; please restart instead")) - c.stateLock.RUnlock() return retErr } acl, te, entity, identityPolicies, err := c.fetchACLTokenEntryAndEntity(ctx, req) if err != nil { retErr = multierror.Append(retErr, err) - c.stateLock.RUnlock() return retErr } @@ -1242,20 +1246,17 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr if err := c.auditBroker.LogRequest(ctx, logInput, c.auditedHeaders); err != nil { c.logger.Error("failed to audit request", "request_path", req.Path, "error", err) retErr = multierror.Append(retErr, errors.New("failed to audit request, cannot continue")) - c.stateLock.RUnlock() return retErr } if entity != nil && entity.Disabled { c.logger.Warn("permission denied as the entity on the token is disabled") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } if te != nil && te.EntityID != "" && entity == nil { c.logger.Warn("permission denied as the entity on the token is invalid") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } @@ -1266,13 +1267,11 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr if err != nil { c.logger.Error("failed to use token", "error", err) retErr = multierror.Append(retErr, ErrInternalError) - c.stateLock.RUnlock() return retErr } if te == nil { // Token is no longer valid retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } } @@ -1282,7 +1281,6 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr RootPrivsRequired: true, }) if !authResults.Allowed { - c.stateLock.RUnlock() retErr = multierror.Append(retErr, authResults.Error) if authResults.Error.ErrorOrNil() == nil || authResults.DeniedError { retErr = multierror.Append(retErr, logical.ErrPermissionDenied) @@ -1304,6 +1302,7 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr } // Unlock; sealing will grab the lock when needed + unlocked = true c.stateLock.RUnlock() sealErr := c.sealInternal() diff --git a/vault/external_tests/misc/recover_from_panic_test.go b/vault/external_tests/misc/recover_from_panic_test.go new file mode 100644 index 0000000000000..157a4e969494a --- /dev/null +++ b/vault/external_tests/misc/recover_from_panic_test.go @@ -0,0 +1,49 @@ +package token + +import ( + "testing" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault" +) + +// Tests the regression in +// https://github.com/hashicorp/vault/pull/6920 +func TestRecoverFromPanic(t *testing.T) { + logger := hclog.New(nil) + + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "noop": vault.NoopBackendFactory, + }, + EnableRaw: true, + Logger: logger, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + core := cluster.Cores[0] + vault.TestWaitActive(t, core.Core) + client := core.Client + + err := client.Sys().Mount("noop", &api.MountInput{ + Type: "noop", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Read("noop/panic") + if err == nil { + t.Fatal("expected error") + } + + // This will deadlock the test if we hit the condition + cluster.EnsureCoresSealed(t) +} diff --git a/vault/ha.go b/vault/ha.go index 39fb74d48f6cc..fd2ca1b4864d3 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -207,6 +207,7 @@ func (c *Core) StepDown(httpCtx context.Context, req *logical.Request) (retErr e c.stateLock.RLock() defer c.stateLock.RUnlock() + if c.Sealed() { return nil } @@ -261,14 +262,12 @@ func (c *Core) StepDown(httpCtx context.Context, req *logical.Request) (retErr e if entity != nil && entity.Disabled { c.logger.Warn("permission denied as the entity on the token is disabled") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } if te != nil && te.EntityID != "" && entity == nil { c.logger.Warn("permission denied as the entity on the token is invalid") retErr = multierror.Append(retErr, logical.ErrPermissionDenied) - c.stateLock.RUnlock() return retErr } diff --git a/vault/request_handling.go b/vault/request_handling.go index 41aac1a309f04..6d45d8e163fed 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -385,18 +385,12 @@ func (c *Core) HandleRequest(httpCtx context.Context, req *logical.Request) (res func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.Request, doLocking bool) (resp *logical.Response, err error) { if doLocking { c.stateLock.RLock() - } - unlockFunc := func() { - if doLocking { - c.stateLock.RUnlock() - } + defer c.stateLock.RUnlock() } if c.Sealed() { - unlockFunc() return nil, consts.ErrSealed } if c.standby && !c.perfStandby { - unlockFunc() return nil, consts.ErrStandby } @@ -412,7 +406,6 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R ns, err := namespace.FromContext(httpCtx) if err != nil { cancel() - unlockFunc() return nil, errwrap.Wrapf("could not parse namespace from http context: {{err}}", err) } ctx = namespace.ContextWithNamespace(ctx, ns) @@ -421,7 +414,6 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R req.SetTokenEntry(nil) cancel() - unlockFunc() return resp, err } diff --git a/vault/router_test.go b/vault/router_test.go index 29f9ae0a7f629..b20b69894fe4f 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -1,116 +1,15 @@ package vault import ( - "context" - "fmt" "reflect" "strings" - "sync" "testing" - "time" - log "github.com/hashicorp/go-hclog" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" ) -type HandlerFunc func(context.Context, *logical.Request) (*logical.Response, error) - -type NoopBackend struct { - sync.Mutex - - Root []string - Login []string - Paths []string - Requests []*logical.Request - Response *logical.Response - RequestHandler HandlerFunc - Invalidations []string - DefaultLeaseTTL time.Duration - MaxLeaseTTL time.Duration - BackendType logical.BackendType -} - -func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { - if req.TokenEntry() != nil { - panic("got a non-nil TokenEntry") - } - - var err error - resp := n.Response - if n.RequestHandler != nil { - resp, err = n.RequestHandler(ctx, req) - } - - n.Lock() - defer n.Unlock() - - requestCopy := *req - n.Paths = append(n.Paths, req.Path) - n.Requests = append(n.Requests, &requestCopy) - if req.Storage == nil { - return nil, fmt.Errorf("missing view") - } - - return resp, err -} - -func (n *NoopBackend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { - return false, false, nil -} - -func (n *NoopBackend) SpecialPaths() *logical.Paths { - return &logical.Paths{ - Root: n.Root, - Unauthenticated: n.Login, - } -} - -func (n *NoopBackend) System() logical.SystemView { - defaultLeaseTTLVal := time.Hour * 24 - maxLeaseTTLVal := time.Hour * 24 * 32 - if n.DefaultLeaseTTL > 0 { - defaultLeaseTTLVal = n.DefaultLeaseTTL - } - - if n.MaxLeaseTTL > 0 { - maxLeaseTTLVal = n.MaxLeaseTTL - } - - return logical.StaticSystemView{ - DefaultLeaseTTLVal: defaultLeaseTTLVal, - MaxLeaseTTLVal: maxLeaseTTLVal, - } -} - -func (n *NoopBackend) Cleanup(ctx context.Context) { - // noop -} - -func (n *NoopBackend) InvalidateKey(ctx context.Context, k string) { - n.Invalidations = append(n.Invalidations, k) -} - -func (n *NoopBackend) Setup(ctx context.Context, config *logical.BackendConfig) error { - return nil -} - -func (n *NoopBackend) Logger() log.Logger { - return log.NewNullLogger() -} - -func (n *NoopBackend) Initialize(ctx context.Context) error { - return nil -} - -func (n *NoopBackend) Type() logical.BackendType { - if n.BackendType == logical.TypeUnknown { - return logical.TypeLogical - } - return n.BackendType -} - func TestRouter_Mount(t *testing.T) { r := NewRouter() _, barrier, _ := mockBarrier(t) diff --git a/vault/router_testing.go b/vault/router_testing.go new file mode 100644 index 0000000000000..bc287806eda10 --- /dev/null +++ b/vault/router_testing.go @@ -0,0 +1,115 @@ +package vault + +import ( + "context" + "fmt" + "sync" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/logical" +) + +type RouterTestHandlerFunc func(context.Context, *logical.Request) (*logical.Response, error) + +type NoopBackend struct { + sync.Mutex + + Root []string + Login []string + Paths []string + Requests []*logical.Request + Response *logical.Response + RequestHandler RouterTestHandlerFunc + Invalidations []string + DefaultLeaseTTL time.Duration + MaxLeaseTTL time.Duration + BackendType logical.BackendType +} + +func NoopBackendFactory(_ context.Context, _ *logical.BackendConfig) (logical.Backend, error) { + return &NoopBackend{}, nil +} + +func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { + if req.TokenEntry() != nil { + panic("got a non-nil TokenEntry") + } + + var err error + resp := n.Response + if n.RequestHandler != nil { + resp, err = n.RequestHandler(ctx, req) + } + + n.Lock() + defer n.Unlock() + + requestCopy := *req + n.Paths = append(n.Paths, req.Path) + n.Requests = append(n.Requests, &requestCopy) + if req.Storage == nil { + return nil, fmt.Errorf("missing view") + } + + if req.Path == "panic" { + panic("as you command") + } + + return resp, err +} + +func (n *NoopBackend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) { + return false, false, nil +} + +func (n *NoopBackend) SpecialPaths() *logical.Paths { + return &logical.Paths{ + Root: n.Root, + Unauthenticated: n.Login, + } +} + +func (n *NoopBackend) System() logical.SystemView { + defaultLeaseTTLVal := time.Hour * 24 + maxLeaseTTLVal := time.Hour * 24 * 32 + if n.DefaultLeaseTTL > 0 { + defaultLeaseTTLVal = n.DefaultLeaseTTL + } + + if n.MaxLeaseTTL > 0 { + maxLeaseTTLVal = n.MaxLeaseTTL + } + + return logical.StaticSystemView{ + DefaultLeaseTTLVal: defaultLeaseTTLVal, + MaxLeaseTTLVal: maxLeaseTTLVal, + } +} + +func (n *NoopBackend) Cleanup(ctx context.Context) { + // noop +} + +func (n *NoopBackend) InvalidateKey(ctx context.Context, k string) { + n.Invalidations = append(n.Invalidations, k) +} + +func (n *NoopBackend) Setup(ctx context.Context, config *logical.BackendConfig) error { + return nil +} + +func (n *NoopBackend) Logger() log.Logger { + return log.NewNullLogger() +} + +func (n *NoopBackend) Initialize(ctx context.Context) error { + return nil +} + +func (n *NoopBackend) Type() logical.BackendType { + if n.BackendType == logical.TypeUnknown { + return logical.TypeLogical + } + return n.BackendType +}