Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pubsub/pstest): Add reactor options to pstest server #2916

Merged
merged 1 commit into from Sep 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
166 changes: 152 additions & 14 deletions pubsub/pstest/fake.go
Expand Up @@ -42,6 +42,24 @@ import (
"google.golang.org/grpc/status"
)

// ReactorOptions is a map that Server uses to look up reactors.
// Key is the function name, value is array of reactor for the function.
type ReactorOptions map[string][]Reactor

// Reactor is an interface to allow reaction function to a certain call.
type Reactor interface {
// React handles the message types and returns results. If "handled" is false,
// then the test server will ignore the results and continue to the next reactor
// or the original handler.
React(_ interface{}) (handled bool, ret interface{}, err error)
}

// ServerReactorOption is options passed to the server for reactor creation.
type ServerReactorOption struct {
FuncName string
Reactor Reactor
}

// For testing. Note that even though changes to the now variable are atomic, a call
// to the stored function can race with a change to that function. This could be a
// problem if tests are run in parallel, or even if concurrent parts of the same test
Expand Down Expand Up @@ -70,31 +88,37 @@ type GServer struct {
pb.PublisherServer
pb.SubscriberServer

mu sync.Mutex
topics map[string]*topic
subs map[string]*subscription
msgs []*Message // all messages ever published
msgsByID map[string]*Message
wg sync.WaitGroup
nextID int
streamTimeout time.Duration
timeNowFunc func() time.Time
mu sync.Mutex
topics map[string]*topic
subs map[string]*subscription
msgs []*Message // all messages ever published
msgsByID map[string]*Message
wg sync.WaitGroup
nextID int
streamTimeout time.Duration
timeNowFunc func() time.Time
reactorOptions ReactorOptions
}

// NewServer creates a new fake server running in the current process.
func NewServer() *Server {
func NewServer(opts ...ServerReactorOption) *Server {
srv, err := testutil.NewServer()
if err != nil {
panic(fmt.Sprintf("pstest.NewServer: %v", err))
}
reactorOptions := ReactorOptions{}
for _, opt := range opts {
reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor)
}
s := &Server{
srv: srv,
Addr: srv.Addr,
GServer: GServer{
topics: map[string]*topic{},
subs: map[string]*subscription{},
msgsByID: map[string]*Message{},
timeNowFunc: timeNow,
topics: map[string]*topic{},
subs: map[string]*subscription{},
msgsByID: map[string]*Message{},
timeNowFunc: timeNow,
reactorOptions: reactorOptions,
},
}
pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
Expand Down Expand Up @@ -237,6 +261,10 @@ func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error)
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil {
return ret.(*pb.Topic), err
}

if s.topics[t.Name] != nil {
return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
}
Expand All @@ -249,6 +277,10 @@ func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topi
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil {
return ret.(*pb.Topic), err
}

if t := s.topics[req.Topic]; t != nil {
return t.proto, nil
}
Expand All @@ -259,6 +291,10 @@ func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*p
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil {
return ret.(*pb.Topic), err
}

t := s.topics[req.Topic.Name]
if t == nil {
return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
Expand All @@ -280,6 +316,10 @@ func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil {
return ret.(*pb.ListTopicsResponse), err
}

var names []string
for n := range s.topics {
if strings.HasPrefix(n, req.Project) {
Expand All @@ -302,6 +342,10 @@ func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSub
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil {
return ret.(*pb.ListTopicSubscriptionsResponse), err
}

var names []string
for name, sub := range s.subs {
if sub.topic.proto.Name == req.Topic {
Expand All @@ -323,6 +367,10 @@ func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*e
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

t := s.topics[req.Topic]
if t == nil {
return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
Expand All @@ -336,6 +384,10 @@ func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*p
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil {
return ret.(*pb.Subscription), err
}

if ps.Name == "" {
return nil, status.Errorf(codes.InvalidArgument, "missing name")
}
Expand Down Expand Up @@ -416,6 +468,11 @@ func checkMRD(pmrd *durpb.Duration) error {
func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil {
return ret.(*pb.Subscription), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -429,6 +486,11 @@ func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscripti
}
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil {
return ret.(*pb.Subscription), err
}

sub, err := s.findSubscription(req.Subscription.Name)
if err != nil {
return nil, err
Expand Down Expand Up @@ -480,6 +542,10 @@ func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptions
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil {
return ret.(*pb.ListSubscriptionsResponse), err
}

var names []string
for name := range s.subs {
if strings.HasPrefix(name, req.Project) {
Expand All @@ -501,6 +567,11 @@ func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptions
func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -514,6 +585,11 @@ func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscripti
func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil {
return ret.(*pb.DetachSubscriptionResponse), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -526,6 +602,10 @@ func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.Publis
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil {
return ret.(*pb.PublishResponse), err
}

if req.Topic == "" {
return nil, status.Errorf(codes.InvalidArgument, "missing topic")
}
Expand Down Expand Up @@ -646,6 +726,10 @@ func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*e
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -659,6 +743,11 @@ func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*e
func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil {
return ret.(*emptypb.Empty), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand All @@ -676,6 +765,12 @@ func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadline

func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
s.mu.Lock()

if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil {
s.mu.Unlock()
return ret.(*pb.PullResponse), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
s.mu.Unlock()
Expand Down Expand Up @@ -751,6 +846,11 @@ func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekRespon
// because the messages don't have any other synchronization.
s.mu.Lock()
defer s.mu.Unlock()

if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil {
return ret.(*pb.SeekResponse), err
}

sub, err := s.findSubscription(req.Subscription)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1050,3 +1150,41 @@ func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
func secsToDur(secs int32) time.Duration {
return time.Duration(secs) * time.Second
}

// runReactor looks up the reactors for a function, then launches them until handled=true
// or err is returned. If the reactor returns nil, the function returns defaultObj instead.
func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) {
if val, ok := s.reactorOptions[funcName]; ok {
for _, reactor := range val {
handled, ret, err := reactor.React(req)
// If handled=true, that means the reactor has successfully reacted to the request,
// so use the output directly. If err occurs, that means the request is invalidated
// by the reactor somehow.
if handled || err != nil {
if ret == nil {
ret = defaultObj
}
return true, ret, err
}
}
}
return false, nil, nil
}

// ErrorInjectionReactor is a reactor to inject an error message
type ErrorInjectionReactor struct {
errMsg string
}

// React simply returns an error with defined error message.
func (e *ErrorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) {
return true, nil, fmt.Errorf(e.errMsg)
}

// WithErrorInjection creates a ServerReactorOption that injects error for a certain function.
func WithErrorInjection(funcName string, errMsg string) ServerReactorOption {
return ServerReactorOption{
FuncName: funcName,
Reactor: &ErrorInjectionReactor{errMsg: errMsg},
}
}