Skip to content

Commit

Permalink
Replace unsafe usage of recover() in helper functions (#4913)
Browse files Browse the repository at this point in the history
We decided that our usage of helper functions when using defer/recover was too brittle, since one
too many functions causes recover() not to do the right thing in Go. Switching to this idiom, where
the defer caller also calls recover() ensures that we will get the right result from recover(), but
also lets us continue to use the helper functions to deduplicate the code that runs after recover().
  • Loading branch information
ZackLK committed Aug 22, 2022
1 parent 20adf51 commit 552add5
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 217 deletions.
6 changes: 3 additions & 3 deletions client/history/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ func (c *clientImpl) GetReplicationMessages(
for peer, req := range requestsByPeer {
peer, req := peer, req
g.Go(func() (e error) {
defer log.CapturePanic(c.logger, &e)
defer func() { log.CapturePanic(recover(), c.logger, &e) }()

requestContext, cancel := common.CreateChildContext(ctx, 0.05)
defer cancel()
Expand Down Expand Up @@ -939,7 +939,7 @@ func (c *clientImpl) CountDLQMessages(
for _, peer := range peers {
peer := peer
g.Go(func() (e error) {
defer log.CapturePanic(c.logger, &e)
defer func() { log.CapturePanic(recover(), c.logger, &e) }()

response, err := c.client.CountDLQMessages(ctx, request, append(opts, yarpc.WithShardKey(peer))...)
if err == nil {
Expand Down Expand Up @@ -1047,7 +1047,7 @@ func (c *clientImpl) NotifyFailoverMarkers(
for peer, req := range requestsByPeer {
peer, req := peer, req
g.Go(func() (e error) {
defer log.CapturePanic(c.logger, &e)
defer func() { log.CapturePanic(recover(), c.logger, &e) }()

ctx, cancel := c.createContext(ctx)
defer cancel()
Expand Down
9 changes: 5 additions & 4 deletions common/log/panic.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ import (
// If the panic value is not error then a default error is returned
// We have to use pointer is because in golang: "recover return nil if was not called directly by a deferred function."
// And we have to set the returned error otherwise our handler will return nil as error which is incorrect
// NOTE: this function MUST be called in a deferred function
func CapturePanic(logger Logger, retError *error) {
// revive:disable-next-line:defer Caller must call from a deferred function
if errPanic := recover(); errPanic != nil {
// errPanic MUST be the result from calling recover, which MUST be done in a single level deep
// deferred function. The usual way of calling this is:
// - defer func() { log.CapturePanic(recover(), logger, &err) }()
func CapturePanic(errPanic interface{}, logger Logger, retError *error) {
if errPanic != nil {
err, ok := errPanic.(error)
if !ok {
err = fmt.Errorf("panic object is not error: %#v", errPanic)
Expand Down
46 changes: 22 additions & 24 deletions common/persistence/sql/sqlExecutionStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,9 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
ctx context.Context,
request *p.InternalGetWorkflowExecutionRequest,
) (resp *p.InternalGetWorkflowExecutionResponse, e error) {
recoverPanic := func(err *error) {
// revive:disable-next-line:defer Func is being called using defer().
if r := recover(); r != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", r, debug.Stack())
recoverPanic := func(recovered interface{}, err *error) {
if recovered != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", recovered, debug.Stack())
}
}

Expand All @@ -291,55 +290,55 @@ func (m *sqlExecutionStore) GetWorkflowExecution(
g, ctx := errgroup.WithContext(ctx)

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
executions, e = m.getExecutions(ctx, request, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
activityInfos, e = getActivityInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
timerInfos, e = getTimerInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
childExecutionInfos, e = getChildExecutionInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
requestCancelInfos, e = getRequestCancelInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
signalInfos, e = getSignalInfoMap(
ctx, m.db, m.shardID, domainID, wfID, runID, m.parser)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
bufferedEvents, e = getBufferedEvents(
ctx, m.db, m.shardID, domainID, wfID, runID)
return e
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
signalsRequested, e = getSignalsRequested(
ctx, m.db, m.shardID, domainID, wfID, runID)
return e
Expand Down Expand Up @@ -619,10 +618,9 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
ctx context.Context,
request *p.DeleteWorkflowExecutionRequest,
) error {
recoverPanic := func(err *error) {
// revive:disable-next-line:defer Func is being called using defer().
if r := recover(); r != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", r, debug.Stack())
recoverPanic := func(recovered interface{}, err *error) {
if recovered != nil {
*err = fmt.Errorf("DB operation panicked: %v %s", recovered, debug.Stack())
}
}
domainID := serialization.MustParseUUID(request.DomainID)
Expand All @@ -631,7 +629,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
g, ctx := errgroup.WithContext(ctx)

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromExecutions(ctx, &sqlplugin.ExecutionsFilter{
ShardID: m.shardID,
DomainID: domainID,
Expand All @@ -642,7 +640,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromActivityInfoMaps(ctx, &sqlplugin.ActivityInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -653,7 +651,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromTimerInfoMaps(ctx, &sqlplugin.TimerInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -664,7 +662,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromChildExecutionInfoMaps(ctx, &sqlplugin.ChildExecutionInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -675,7 +673,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromRequestCancelInfoMaps(ctx, &sqlplugin.RequestCancelInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -686,7 +684,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromSignalInfoMaps(ctx, &sqlplugin.SignalInfoMapsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand All @@ -697,7 +695,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromBufferedEvents(ctx, &sqlplugin.BufferedEventsFilter{
ShardID: m.shardID,
DomainID: domainID,
Expand All @@ -708,7 +706,7 @@ func (m *sqlExecutionStore) DeleteWorkflowExecution(
})

g.Go(func() (e error) {
defer recoverPanic(&e)
defer func() { recoverPanic(recover(), &e) }()
_, e = m.db.DeleteFromSignalsRequestedSets(ctx, &sqlplugin.SignalsRequestedSetsFilter{
ShardID: int64(m.shardID),
DomainID: domainID,
Expand Down

0 comments on commit 552add5

Please sign in to comment.