Skip to content

Commit

Permalink
Add reactor options to pstest server.
Browse files Browse the repository at this point in the history
This allows users to define customized reactor to a sub command handler.
An example is to inject error in topic deletion. A generic ErrorInjection
is also provided to inject error to handlers.
  • Loading branch information
Jimmy Lin committed Sep 25, 2020
1 parent fa82905 commit 7c5fd6c
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 16 deletions.
166 changes: 152 additions & 14 deletions pubsub/pstest/fake.go
Expand Up @@ -42,6 +42,24 @@ import (
"google.golang.org/grpc/status"
)

// ReactorOptions is a map that Server uses to look up reactors.
// Key is the function name, value is array of reactor for the function.
type ReactorOptions map[string][]Reactor

// Reactor is an interface to allow reaction function to a certain call.
type Reactor interface {
// React handles the message types and returns results. If "handled" is false,
// then the test server will ignore the results and continue to the next reactor
// or the original handler.
React(_ interface{}) (handled bool, ret interface{}, err error)
}

// ServerReactorOption is options passed to the server for reactor creation.
type ServerReactorOption struct {
FuncName string
Reactor Reactor
}

// For testing. Note that even though changes to the now variable are atomic, a call
// to the stored function can race with a change to that function. This could be a
// problem if tests are run in parallel, or even if concurrent parts of the same test
Expand Down Expand Up @@ -70,31 +88,37 @@ type GServer struct {
pb.PublisherServer
pb.SubscriberServer

mu sync.Mutex
topics map[string]*topic
subs map[string]*subscription
msgs []*Message // all messages ever published
msgsByID map[string]*Message
wg sync.WaitGroup
nextID int
streamTimeout time.Duration
timeNowFunc func() time.Time
mu sync.Mutex
topics map[string]*topic
subs map[string]*subscription
msgs []*Message // all messages ever published
msgsByID map[string]*Message
wg sync.WaitGroup
nextID int
streamTimeout time.Duration
timeNowFunc func() time.Time
reactorOptions ReactorOptions
}

// NewServer creates a new fake server running in the current process.
func NewServer() *Server {
func NewServer(opts ...ServerReactorOption) *Server {
srv, err := testutil.NewServer()
if err != nil {
panic(fmt.Sprintf("pstest.NewServer: %v", err))
}
reactorOptions := ReactorOptions{}
for _, opt := range opts {
reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor)
}
s := &Server{
srv: srv,
Addr: srv.Addr,
GServer: GServer{
topics: map[string]*topic{},
subs: map[string]*subscription{},
msgsByID: map[string]*Message{},
timeNowFunc: timeNow,
topics: map[string]*topic{},
subs: map[string]*subscription{},
msgsByID: map[string]*Message{},
timeNowFunc: timeNow,
reactorOptions: reactorOptions,
},
}
pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
Expand Down Expand Up @@ -237,6 +261,10 @@ func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error)
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil {
return ret.(*pb.Topic), err
}

if s.topics[t.Name] != nil {
return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
}
Expand All @@ -249,6 +277,10 @@ func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topi
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil {
return ret.(*pb.Topic), err
}

if t := s.topics[req.Topic]; t != nil {
return t.proto, nil
}
Expand All @@ -259,6 +291,10 @@ func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*p
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil {
return ret.(*pb.Topic), err
}

t := s.topics[req.Topic.Name]
if t == nil {
return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
Expand All @@ -280,6 +316,10 @@ func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil {
return ret.(*pb.ListTopicsResponse), err
}

var names []string
for n := range s.topics {
if strings.HasPrefix(n, req.Project) {
Expand All @@ -302,6 +342,10 @@ func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSub
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil {
return ret.(*pb.ListTopicSubscriptionsResponse), err
}

var names []string
for name, sub := range s.subs {
if sub.topic.proto.Name == req.Topic {
Expand All @@ -323,6 +367,10 @@ func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*e
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

t := s.topics[req.Topic]
if t == nil {
return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
Expand All @@ -336,6 +384,10 @@ func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*p
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil {
return ret.(*pb.Subscription), err
}

if ps.Name == "" {
return nil, status.Errorf(codes.InvalidArgument, "missing name")
}
Expand Down Expand Up @@ -416,6 +468,11 @@ func checkMRD(pmrd *durpb.Duration) error {
func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil {
return ret.(*pb.Subscription), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -429,6 +486,11 @@ func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscripti
}
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil {
return ret.(*pb.Subscription), err
}

sub, err := s.findSubscription(req.Subscription.Name)
if err != nil {
return nil, err
Expand Down Expand Up @@ -480,6 +542,10 @@ func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptions
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil {
return ret.(*pb.ListSubscriptionsResponse), err
}

var names []string
for name := range s.subs {
if strings.HasPrefix(name, req.Project) {
Expand All @@ -501,6 +567,11 @@ func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptions
func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -514,6 +585,11 @@ func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscripti
func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil {
return ret.(*pb.DetachSubscriptionResponse), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -526,6 +602,10 @@ func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.Publis
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil {
return ret.(*pb.PublishResponse), err
}

if req.Topic == "" {
return nil, status.Errorf(codes.InvalidArgument, "missing topic")
}
Expand Down Expand Up @@ -646,6 +726,10 @@ func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*e
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -659,6 +743,11 @@ func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*e
func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -676,6 +765,12 @@ func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadline

func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
s.mu.Lock()

if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil {
s.mu.Unlock()
return ret.(*pb.PullResponse), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
s.mu.Unlock()
Expand Down Expand Up @@ -751,6 +846,11 @@ func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekRespon
// because the messages don't have any other synchronization.
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil {
return ret.(*pb.SeekResponse), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1050,3 +1150,41 @@ func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
func secsToDur(secs int32) time.Duration {
return time.Duration(secs) * time.Second
}

// runReactor looks up the reactors for a function, then launches them until handled=true
// or err is returned. If the reactor returns nil, the function returns defaultObj instead.
func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) {
if val, ok := s.reactorOptions[funcName]; ok {
for _, reactor := range val {
handled, ret, err := reactor.React(req)
// If handled=true, that means the reactor has successfully reacted to the request,
// so use the output directly. If err occurs, that means the request is invalidated
// by the reactor somehow.
if handled || err != nil {
if ret == nil {
ret = defaultObj
}
return true, ret, err
}
}
}
return false, nil, nil
}

// ErrorInjectionReactor is a reactor to inject an error message
type ErrorInjectionReactor struct {
errMsg string
}

// React simply returns an error with defined error message.
func (e *ErrorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) {
return true, nil, fmt.Errorf(e.errMsg)
}

// WithErrorInjection creates a ServerReactorOption that injects error for a certain function.
func WithErrorInjection(funcName string, errMsg string) ServerReactorOption {
return ServerReactorOption{
FuncName: funcName,
Reactor: &ErrorInjectionReactor{errMsg: errMsg},
}
}

0 comments on commit 7c5fd6c

Please sign in to comment.