diff --git a/pubsub/pstest/fake.go b/pubsub/pstest/fake.go index 21a07ba2bf9e..76620df0b30e 100644 --- a/pubsub/pstest/fake.go +++ b/pubsub/pstest/fake.go @@ -42,6 +42,24 @@ import ( "google.golang.org/grpc/status" ) +// ReactorOptions is a map that Server uses to look up reactors. +// Key is the function name, value is array of reactor for the function. +type ReactorOptions map[string][]Reactor + +// Reactor is an interface to allow reaction function to a certain call. +type Reactor interface { + // React handles the message types and returns results. If "handled" is false, + // then the test server will ignore the results and continue to the next reactor + // or the original handler. + React(_ interface{}) (handled bool, ret interface{}, err error) +} + +// ServerReactorOption is options passed to the server for reactor creation. +type ServerReactorOption struct { + FuncName string + Reactor Reactor +} + // For testing. Note that even though changes to the now variable are atomic, a call // to the stored function can race with a change to that function. This could be a // problem if tests are run in parallel, or even if concurrent parts of the same test @@ -70,31 +88,37 @@ type GServer struct { pb.PublisherServer pb.SubscriberServer - mu sync.Mutex - topics map[string]*topic - subs map[string]*subscription - msgs []*Message // all messages ever published - msgsByID map[string]*Message - wg sync.WaitGroup - nextID int - streamTimeout time.Duration - timeNowFunc func() time.Time + mu sync.Mutex + topics map[string]*topic + subs map[string]*subscription + msgs []*Message // all messages ever published + msgsByID map[string]*Message + wg sync.WaitGroup + nextID int + streamTimeout time.Duration + timeNowFunc func() time.Time + reactorOptions ReactorOptions } // NewServer creates a new fake server running in the current process. -func NewServer() *Server { +func NewServer(opts ...ServerReactorOption) *Server { srv, err := testutil.NewServer() if err != nil { panic(fmt.Sprintf("pstest.NewServer: %v", err)) } + reactorOptions := ReactorOptions{} + for _, opt := range opts { + reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor) + } s := &Server{ srv: srv, Addr: srv.Addr, GServer: GServer{ - topics: map[string]*topic{}, - subs: map[string]*subscription{}, - msgsByID: map[string]*Message{}, - timeNowFunc: timeNow, + topics: map[string]*topic{}, + subs: map[string]*subscription{}, + msgsByID: map[string]*Message{}, + timeNowFunc: timeNow, + reactorOptions: reactorOptions, }, } pb.RegisterPublisherServer(srv.Gsrv, &s.GServer) @@ -237,6 +261,10 @@ func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil { + return ret.(*pb.Topic), err + } + if s.topics[t.Name] != nil { return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name) } @@ -249,6 +277,10 @@ func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topi s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil { + return ret.(*pb.Topic), err + } + if t := s.topics[req.Topic]; t != nil { return t.proto, nil } @@ -259,6 +291,10 @@ func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*p s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil { + return ret.(*pb.Topic), err + } + t := s.topics[req.Topic.Name] if t == nil { return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name) @@ -280,6 +316,10 @@ func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb. s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil { + return ret.(*pb.ListTopicsResponse), err + } + var names []string for n := range s.topics { if strings.HasPrefix(n, req.Project) { @@ -302,6 +342,10 @@ func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSub s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil { + return ret.(*pb.ListTopicSubscriptionsResponse), err + } + var names []string for name, sub := range s.subs { if sub.topic.proto.Name == req.Topic { @@ -323,6 +367,10 @@ func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*e s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil { + return ret.(*emptypb.Empty), err + } + t := s.topics[req.Topic] if t == nil { return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) @@ -336,6 +384,10 @@ func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*p s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil { + return ret.(*pb.Subscription), err + } + if ps.Name == "" { return nil, status.Errorf(codes.InvalidArgument, "missing name") } @@ -416,6 +468,11 @@ func checkMRD(pmrd *durpb.Duration) error { func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { s.mu.Lock() defer s.mu.Unlock() + + if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil { + return ret.(*pb.Subscription), err + } + sub, err := s.findSubscription(req.Subscription) if err != nil { return nil, err @@ -429,6 +486,11 @@ func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscripti } s.mu.Lock() defer s.mu.Unlock() + + if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil { + return ret.(*pb.Subscription), err + } + sub, err := s.findSubscription(req.Subscription.Name) if err != nil { return nil, err @@ -480,6 +542,10 @@ func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptions s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil { + return ret.(*pb.ListSubscriptionsResponse), err + } + var names []string for name := range s.subs { if strings.HasPrefix(name, req.Project) { @@ -501,6 +567,11 @@ func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptions func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) { s.mu.Lock() defer s.mu.Unlock() + + if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil { + return ret.(*emptypb.Empty), err + } + sub, err := s.findSubscription(req.Subscription) if err != nil { return nil, err @@ -514,6 +585,11 @@ func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscripti func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) { s.mu.Lock() defer s.mu.Unlock() + + if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil { + return ret.(*pb.DetachSubscriptionResponse), err + } + sub, err := s.findSubscription(req.Subscription) if err != nil { return nil, err @@ -526,6 +602,10 @@ func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.Publis s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil { + return ret.(*pb.PublishResponse), err + } + if req.Topic == "" { return nil, status.Errorf(codes.InvalidArgument, "missing topic") } @@ -646,6 +726,10 @@ func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*e s.mu.Lock() defer s.mu.Unlock() + if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil { + return ret.(*emptypb.Empty), err + } + sub, err := s.findSubscription(req.Subscription) if err != nil { return nil, err @@ -659,6 +743,11 @@ func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*e func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { s.mu.Lock() defer s.mu.Unlock() + + if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil { + return ret.(*emptypb.Empty), err + } + sub, err := s.findSubscription(req.Subscription) if err != nil { return nil, err @@ -676,6 +765,12 @@ func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadline func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) { s.mu.Lock() + + if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil { + s.mu.Unlock() + return ret.(*pb.PullResponse), err + } + sub, err := s.findSubscription(req.Subscription) if err != nil { s.mu.Unlock() @@ -751,6 +846,11 @@ func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekRespon // because the messages don't have any other synchronization. s.mu.Lock() defer s.mu.Unlock() + + if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil { + return ret.(*pb.SeekResponse), err + } + sub, err := s.findSubscription(req.Subscription) if err != nil { return nil, err @@ -1050,3 +1150,41 @@ func (s *subscription) modifyAckDeadline(id string, d time.Duration) { func secsToDur(secs int32) time.Duration { return time.Duration(secs) * time.Second } + +// runReactor looks up the reactors for a function, then launches them until handled=true +// or err is returned. If the reactor returns nil, the function returns defaultObj instead. +func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) { + if val, ok := s.reactorOptions[funcName]; ok { + for _, reactor := range val { + handled, ret, err := reactor.React(req) + // If handled=true, that means the reactor has successfully reacted to the request, + // so use the output directly. If err occurs, that means the request is invalidated + // by the reactor somehow. + if handled || err != nil { + if ret == nil { + ret = defaultObj + } + return true, ret, err + } + } + } + return false, nil, nil +} + +// ErrorInjectionReactor is a reactor to inject an error message +type ErrorInjectionReactor struct { + errMsg string +} + +// React simply returns an error with defined error message. +func (e *ErrorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) { + return true, nil, fmt.Errorf(e.errMsg) +} + +// WithErrorInjection creates a ServerReactorOption that injects error for a certain function. +func WithErrorInjection(funcName string, errMsg string) ServerReactorOption { + return ServerReactorOption{ + FuncName: funcName, + Reactor: &ErrorInjectionReactor{errMsg: errMsg}, + } +} diff --git a/pubsub/pstest/fake_test.go b/pubsub/pstest/fake_test.go index d03e2b5d7435..454a54a763fc 100644 --- a/pubsub/pstest/fake_test.go +++ b/pubsub/pstest/fake_test.go @@ -18,6 +18,8 @@ import ( "context" "fmt" "io" + "reflect" + "strings" "sync" "testing" "time" @@ -933,8 +935,8 @@ func mustUpdateSubscription(ctx context.Context, t *testing.T, sc pb.SubscriberC // client. Its final return is a cleanup function. // // Note: be sure to call cleanup! -func newFake(ctx context.Context, t *testing.T) (pb.PublisherClient, pb.SubscriberClient, *Server, func()) { - srv := NewServer() +func newFake(ctx context.Context, t *testing.T, opts ...ServerReactorOption) (pb.PublisherClient, pb.SubscriberClient, *Server, func()) { + srv := NewServer(opts...) conn, err := grpc.DialContext(ctx, srv.Addr, grpc.WithInsecure()) if err != nil { t.Fatal(err) @@ -944,3 +946,95 @@ func newFake(ctx context.Context, t *testing.T) (pb.PublisherClient, pb.Subscrib conn.Close() } } + +func TestErrorInjection(t *testing.T) { + testcases := []struct { + funcName string + param interface{} + }{ + { + funcName: "CreateTopic", + }, + { + funcName: "GetTopic", + }, + { + funcName: "UpdateTopic", + }, + { + funcName: "ListTopics", + }, + { + funcName: "ListTopicSubscriptions", + }, + { + funcName: "DeleteTopic", + }, + { + funcName: "CreateSubscription", + }, + { + funcName: "GetSubscription", + }, + { + funcName: "UpdateSubscription", + param: &pb.UpdateSubscriptionRequest{Subscription: &pb.Subscription{}}, + }, + { + funcName: "ListSubscriptions", + }, + { + funcName: "DeleteSubscription", + }, + { + funcName: "DetachSubscription", + }, + { + funcName: "Publish", + }, + { + funcName: "Acknowledge", + }, + { + funcName: "ModifyAckDeadline", + }, + { + funcName: "Pull", + }, + { + funcName: "Seek", + param: &pb.SeekRequest{Target: &pb.SeekRequest_Time{Time: ptypes.TimestampNow()}}, + }, + } + + for _, tc := range testcases { + ctx := context.TODO() + errMsg := "error-injection-" + tc.funcName + opts := []ServerReactorOption{ + WithErrorInjection(tc.funcName, errMsg), + } + _, _, server, cleanup := newFake(ctx, t, opts...) + defer cleanup() + + // We used reflection here to blindly look up the function by name and pass + // context and a typed nil, as all the functions under test will have such + // a function signature. + f := reflect.ValueOf(&server.GServer).MethodByName(tc.funcName) + if !f.IsValid() { + t.Fatalf("Method %v Not Found", tc.funcName) + } + // If param is provided, use the param, otherwise create a typed nil that matches the parameter type. + var req reflect.Value + if tc.param != nil { + req = reflect.ValueOf(tc.param) + } else { + req = reflect.New(f.Type().In(1).Elem()) + } + ret := reflect.ValueOf(&server.GServer).MethodByName(tc.funcName).Call([]reflect.Value{reflect.ValueOf(ctx), req}) + + got := ret[1].Interface().(error) + if got == nil || !strings.Contains(got.Error(), errMsg) { + t.Errorf("Got error does not contain the right key %v", got) + } + } +}