diff --git a/server/etcdserver/api/v3rpc/grpc.go b/server/etcdserver/api/v3rpc/grpc.go index ea3dd75705fd..409a1c39a988 100644 --- a/server/etcdserver/api/v3rpc/grpc.go +++ b/server/etcdserver/api/v3rpc/grpc.go @@ -36,7 +36,7 @@ const ( maxSendBytes = math.MaxInt32 ) -func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnaryServerInterceptor, gopts ...grpc.ServerOption) *grpc.Server { +func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptors []grpc.UnaryServerInterceptor, gopts ...grpc.ServerOption) *grpc.Server { var opts []grpc.ServerOption opts = append(opts, grpc.CustomCodec(&codec{})) if tls != nil { @@ -48,8 +48,8 @@ func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnarySer newUnaryInterceptor(s), grpc_prometheus.UnaryServerInterceptor, } - if interceptor != nil { - chainUnaryInterceptors = append(chainUnaryInterceptors, interceptor) + if interceptors != nil { + chainUnaryInterceptors = append(chainUnaryInterceptors, interceptors...) 
} chainStreamInterceptors := []grpc.StreamServerInterceptor{ diff --git a/server/etcdserver/apply/apply.go b/server/etcdserver/apply/apply.go index 9fe77e91f4c7..2af0ef86a497 100644 --- a/server/etcdserver/apply/apply.go +++ b/server/etcdserver/apply/apply.go @@ -169,7 +169,7 @@ func (a *applierV3backend) Range(ctx context.Context, txn mvcc.TxnRead, r *pb.Ra } func (a *applierV3backend) Txn(ctx context.Context, rt *pb.TxnRequest) (*pb.TxnResponse, *traceutil.Trace, error) { - return mvcctxn.Txn(ctx, a.lg, rt, a.txnModeWriteWithSharedBuffer, a.kv, a.lessor) + return mvcctxn.Txn(ctx, a.lg, rt, a.txnModeWriteWithSharedBuffer, a.kv, a.lessor, mvcctxn.PanicErrHandler) } func (a *applierV3backend) Compaction(compaction *pb.CompactionRequest) (*pb.CompactionResponse, <-chan struct{}, *traceutil.Trace, error) { diff --git a/server/etcdserver/txn/txn.go b/server/etcdserver/txn/txn.go index 36782d34b61c..7e77000a7080 100644 --- a/server/etcdserver/txn/txn.go +++ b/server/etcdserver/txn/txn.go @@ -29,6 +29,17 @@ import ( "go.uber.org/zap" ) +type CriticalErrHandler func(lg *zap.Logger, err error) error + +func PanicErrHandler(lg *zap.Logger, err error) error { + lg.Panic("unexpected error during txnWrite", zap.Error(err)) + return err +} + +func PassthroughErrHandler(_ *zap.Logger, err error) error { + return err +} + func Put(ctx context.Context, lg *zap.Logger, lessor lease.Lessor, kv mvcc.KV, txnWrite mvcc.TxnWrite, p *pb.PutRequest) (resp *pb.PutResponse, trace *traceutil.Trace, err error) { resp = &pb.PutResponse{} resp.Header = &pb.ResponseHeader{} @@ -217,7 +228,7 @@ func Range(ctx context.Context, lg *zap.Logger, kv mvcc.KV, txnRead mvcc.TxnRead return resp, nil } -func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWithSharedBuffer bool, kv mvcc.KV, lessor lease.Lessor) (*pb.TxnResponse, *traceutil.Trace, error) { +func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWithSharedBuffer bool, kv mvcc.KV, lessor 
lease.Lessor, errHandler CriticalErrHandler) (*pb.TxnResponse, *traceutil.Trace, error) { trace := traceutil.Get(ctx) if trace.IsEmpty() { trace = traceutil.New("transaction", lg) @@ -265,7 +276,7 @@ func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWit txnWrite.End() txnWrite = kv.Write(trace) } - applyTxn(ctx, lg, kv, lessor, txnWrite, rt, txnPath, txnResp) + _, err := applyTxn(ctx, lg, kv, lessor, txnWrite, rt, txnPath, txnResp, errHandler) rev := txnWrite.Rev() if len(txnWrite.Changes()) != 0 { rev++ @@ -277,7 +288,7 @@ func Txn(ctx context.Context, lg *zap.Logger, rt *pb.TxnRequest, txnModeWriteWit traceutil.Field{Key: "number_of_response", Value: len(txnResp.Responses)}, traceutil.Field{Key: "response_revision", Value: txnResp.Header.Revision}, ) - return txnResp, trace, nil + return txnResp, trace, err } // newTxnResp allocates a txn response for a txn request given a path. @@ -311,7 +322,7 @@ func newTxnResp(rt *pb.TxnRequest, txnPath []bool) (txnResp *pb.TxnResponse, txn return txnResp, txnCount } -func applyTxn(ctx context.Context, lg *zap.Logger, kv mvcc.KV, lessor lease.Lessor, txnWrite mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int) { +func applyTxn(ctx context.Context, lg *zap.Logger, kv mvcc.KV, lessor lease.Lessor, txnWrite mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse, errHandler CriticalErrHandler) (txns int, err error) { trace := traceutil.Get(ctx) reqs := rt.Success if !txnPath[0] { @@ -328,7 +339,7 @@ func applyTxn(ctx context.Context, lg *zap.Logger, kv mvcc.KV, lessor lease.Less traceutil.Field{Key: "range_end", Value: string(tv.RequestRange.RangeEnd)}) resp, err := Range(ctx, lg, kv, txnWrite, tv.RequestRange) if err != nil { - lg.Panic("unexpected error during txnWrite", zap.Error(err)) + return 0, errHandler(lg, err) } respi.(*pb.ResponseOp_ResponseRange).ResponseRange = resp trace.StopSubTrace() @@ -339,26 +350,30 @@ func applyTxn(ctx context.Context, lg 
*zap.Logger, kv mvcc.KV, lessor lease.Less traceutil.Field{Key: "req_size", Value: tv.RequestPut.Size()}) resp, _, err := Put(ctx, lg, lessor, kv, txnWrite, tv.RequestPut) if err != nil { - lg.Panic("unexpected error during txnWrite", zap.Error(err)) + return 0, errHandler(lg, err) } respi.(*pb.ResponseOp_ResponsePut).ResponsePut = resp trace.StopSubTrace() case *pb.RequestOp_RequestDeleteRange: resp, err := DeleteRange(kv, txnWrite, tv.RequestDeleteRange) if err != nil { - lg.Panic("unexpected error during txnWrite", zap.Error(err)) + return 0, errHandler(lg, err) } respi.(*pb.ResponseOp_ResponseDeleteRange).ResponseDeleteRange = resp case *pb.RequestOp_RequestTxn: resp := respi.(*pb.ResponseOp_ResponseTxn).ResponseTxn - applyTxns := applyTxn(ctx, lg, kv, lessor, txnWrite, tv.RequestTxn, txnPath[1:], resp) + applyTxns, err := applyTxn(ctx, lg, kv, lessor, txnWrite, tv.RequestTxn, txnPath[1:], resp, errHandler) + if err != nil { + // no need to call errHandler() since it was already called recursively + return 0, err + } txns += applyTxns + 1 txnPath = txnPath[applyTxns+1:] default: // empty union } } - return txns + return txns, nil } //--------------------------------------------------------- diff --git a/server/etcdserver/v3_server.go b/server/etcdserver/v3_server.go index 63a190e6ed69..ae2911600115 100644 --- a/server/etcdserver/v3_server.go +++ b/server/etcdserver/v3_server.go @@ -182,7 +182,9 @@ func (s *EtcdServer) Txn(ctx context.Context, r *pb.TxnRequest) (*pb.TxnResponse }(time.Now()) get := func() { - resp, _, err = txn.Txn(ctx, s.Logger(), r, s.Cfg.ExperimentalTxnModeWriteWithSharedBuffer, s.KV(), s.lessor) + // ctx is passed into txn, it can be cancelled and trigger an error. + // This is ok for readonly serializable txn, we use PassthroughErrHandler to bubble up the error. 
+ resp, _, err = txn.Txn(ctx, s.Logger(), r, s.Cfg.ExperimentalTxnModeWriteWithSharedBuffer, s.KV(), s.lessor, txn.PassthroughErrHandler) } if serr := s.doSerialize(ctx, chk, get); serr != nil { return nil, serr diff --git a/server/storage/mvcc/kvstore_txn.go b/server/storage/mvcc/kvstore_txn.go index 604fac78cb3e..5508804c928b 100644 --- a/server/storage/mvcc/kvstore_txn.go +++ b/server/storage/mvcc/kvstore_txn.go @@ -16,6 +16,7 @@ package mvcc import ( "context" + "fmt" "go.etcd.io/etcd/api/v3/mvccpb" "go.etcd.io/etcd/pkg/v3/traceutil" @@ -94,7 +95,7 @@ func (tr *storeTxnRead) rangeKeys(ctx context.Context, key, end []byte, curRev i for i, revpair := range revpairs[:len(kvs)] { select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, fmt.Errorf("range context cancelled: %w", ctx.Err()) default: } revToBytes(revpair, revBytes) diff --git a/tests/framework/integration/cluster.go b/tests/framework/integration/cluster.go index a59499701862..6ae50796d7d2 100644 --- a/tests/framework/integration/cluster.go +++ b/tests/framework/integration/cluster.go @@ -170,6 +170,9 @@ type ClusterConfig struct { ExperimentalMaxLearners int StrictReconfigCheck bool CorruptCheckTime time.Duration + // GrpcInterceptors allows to add additional interceptors to GrpcServer for testing + // For example can be used to cancel context on demand + GrpcInterceptors []grpc.UnaryServerInterceptor } type Cluster struct { @@ -284,6 +287,7 @@ func (c *Cluster) mustNewMember(t testutil.TB) *Member { ExperimentalMaxLearners: c.Cfg.ExperimentalMaxLearners, StrictReconfigCheck: c.Cfg.StrictReconfigCheck, CorruptCheckTime: c.Cfg.CorruptCheckTime, + GrpcInterceptors: c.Cfg.GrpcInterceptors, }) m.DiscoveryURL = c.Cfg.DiscoveryURL return m @@ -574,6 +578,7 @@ type Member struct { Closed bool GrpcServerRecorder *grpc_testing.GrpcRecorder + GrpcInterceptors []grpc.UnaryServerInterceptor } func (m *Member) GRPCURL() string { return m.GrpcURL } @@ -605,6 +610,7 @@ type MemberConfig struct { 
ExperimentalMaxLearners int StrictReconfigCheck bool CorruptCheckTime time.Duration + GrpcInterceptors []grpc.UnaryServerInterceptor } // MustNewMember return an inited member with the given name. If peerTLS is @@ -718,6 +724,7 @@ func MustNewMember(t testutil.TB, mcfg MemberConfig) *Member { } m.V2Deprecation = config.V2_DEPR_DEFAULT m.GrpcServerRecorder = &grpc_testing.GrpcRecorder{} + m.GrpcInterceptors = append(mcfg.GrpcInterceptors, m.GrpcServerRecorder.UnaryInterceptor()) m.Logger = memberLogger(t, mcfg.Name) m.StrictReconfigCheck = mcfg.StrictReconfigCheck if err := m.listenGRPC(); err != nil { @@ -938,7 +945,7 @@ func (m *Member) Launch() error { return err } } - m.GrpcServer = v3rpc.Server(m.Server, tlscfg, m.GrpcServerRecorder.UnaryInterceptor(), m.GrpcServerOpts...) + m.GrpcServer = v3rpc.Server(m.Server, tlscfg, m.GrpcInterceptors, m.GrpcServerOpts...) m.ServerClient = v3client.New(m.Server) lockpb.RegisterLockServer(m.GrpcServer, v3lock.NewLockServer(m.ServerClient)) epb.RegisterElectionServer(m.GrpcServer, v3election.NewElectionServer(m.ServerClient)) diff --git a/tests/integration/v3_grpc_test.go b/tests/integration/v3_grpc_test.go index 78b5c6c66cf7..941346165584 100644 --- a/tests/integration/v3_grpc_test.go +++ b/tests/integration/v3_grpc_test.go @@ -17,6 +17,7 @@ package integration import ( "bytes" "context" + "errors" "fmt" "math/rand" "os" @@ -1952,3 +1953,77 @@ func waitForRestart(t *testing.T, kvc pb.KVClient) { t.Fatalf("timed out waiting for restart: %v", err) } } + +func TestV3ReadonlyTxnCancelledContext(t *testing.T) { + integration.BeforeTest(t) + clus := integration.NewCluster(t, &integration.ClusterConfig{ + Size: 1, + // Context should be cancelled on the second check that happens inside rangeKeys + GrpcInterceptors: []grpc.UnaryServerInterceptor{injectMockContextForTxn(newMockContext(2))}, + }) + defer clus.Terminate(t) + + kvc := integration.ToGRPC(clus.RandClient()).KV + pr := &pb.PutRequest{Key: []byte("abc"), Value: 
[]byte("def")} + _, err := kvc.Put(context.TODO(), pr) + if err != nil { + t.Fatal(err) + } + + txnget := &pb.RequestOp{Request: &pb.RequestOp_RequestRange{RequestRange: &pb.RangeRequest{Key: []byte("abc")}}} + txn := &pb.TxnRequest{Success: []*pb.RequestOp{txnget}} + _, err = kvc.Txn(context.TODO(), txn) + if err == nil || !strings.Contains(err.Error(), "range context cancelled: mock context error") { + t.Fatal(err) + } +} + +type mockCtx struct { + calledDone int + doneAfter int + + donec chan struct{} +} + +func newMockContext(doneAfter int) context.Context { + return &mockCtx{ + calledDone: 0, + doneAfter: doneAfter, + donec: make(chan struct{}), + } +} + +func (*mockCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (ctx *mockCtx) Done() <-chan struct{} { + ctx.calledDone++ + if ctx.calledDone == ctx.doneAfter { + close(ctx.donec) + } + return ctx.donec +} + +func (*mockCtx) Err() error { + return errors.New("mock context error") +} + +func (*mockCtx) Value(interface{}) interface{} { + return nil +} + +func (*mockCtx) String() string { + return "mock Context" +} + +func injectMockContextForTxn(mctx context.Context) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + switch req.(type) { + case *pb.TxnRequest: + return handler(mctx, req) + default: + return handler(ctx, req) + } + } +}