diff --git a/pubsublite/internal/wire/service.go b/pubsublite/internal/wire/service.go index 928d9a8df0d..565f39276d4 100644 --- a/pubsublite/internal/wire/service.go +++ b/pubsublite/internal/wire/service.go @@ -14,6 +14,7 @@ package wire import ( + "errors" "sync" ) @@ -56,6 +57,7 @@ type service interface { AddStatusChangeReceiver(serviceHandle, serviceStatusChangeFunc) RemoveStatusChangeReceiver(serviceHandle) Handle() serviceHandle + Status() serviceStatus Error() error } @@ -153,10 +155,7 @@ func (as *abstractService) unsafeUpdateStatus(targetStatus serviceStatus, err er return true } -type serviceHolder struct { - service service - lastStatus serviceStatus -} +var errChildServiceStarted = errors.New("pubsublite: dependent service must not be started") // compositeService can be embedded into other structs to manage child services. // It implements the service interface and can itself be a dependency of another @@ -169,8 +168,8 @@ type compositeService struct { waitStarted chan struct{} waitTerminated chan struct{} - dependencies []*serviceHolder - removed []*serviceHolder + dependencies []service + removed []service abstractService } @@ -188,7 +187,7 @@ func (cs *compositeService) Start() { if cs.abstractService.unsafeUpdateStatus(serviceStarting, nil) { for _, s := range cs.dependencies { - s.service.Start() + s.Start() } } } @@ -218,8 +217,14 @@ func (cs *compositeService) unsafeAddServices(services ...service) error { } for _, s := range services { + // Adding dependent services which have already started not currently + // supported. Requires updating logic to handle the compositeService state. + if s.Status() > serviceUninitialized { + return errChildServiceStarted + } + s.AddStatusChangeReceiver(cs.Handle(), cs.onServiceStatusChange) - cs.dependencies = append(cs.dependencies, &serviceHolder{service: s}) + cs.dependencies = append(cs.dependencies, s) if cs.status > serviceUninitialized { s.Start() } @@ -227,15 +232,15 @@ func (cs *compositeService) unsafeAddServices(services ...service) error { return nil } -func (cs *compositeService) unsafeRemoveService(service service) { +func (cs *compositeService) unsafeRemoveService(remove service) { removeIdx := -1 for i, s := range cs.dependencies { - if s.service.Handle() == service.Handle() { + if s.Handle() == remove.Handle() { // Move from the `dependencies` to the `removed` list. cs.removed = append(cs.removed, s) removeIdx = i - if s.lastStatus < serviceTerminating { - s.service.Stop() + if s.Status() < serviceTerminating { + s.Stop() } break } @@ -244,12 +249,13 @@ func (cs *compositeService) unsafeRemoveService(service service) { } func (cs *compositeService) unsafeInitiateShutdown(targetStatus serviceStatus, err error) { - for _, s := range cs.dependencies { - if s.lastStatus < serviceTerminating { - s.service.Stop() + if cs.unsafeUpdateStatus(targetStatus, err) { + for _, s := range cs.dependencies { + if s.Status() < serviceTerminating { + s.Stop() + } } } - cs.unsafeUpdateStatus(targetStatus, err) } func (cs *compositeService) unsafeUpdateStatus(targetStatus serviceStatus, err error) (ret bool) { @@ -257,7 +263,7 @@ func (cs *compositeService) unsafeUpdateStatus(targetStatus serviceStatus, err e if ret = cs.abstractService.unsafeUpdateStatus(targetStatus, err); ret { // Note: the waitStarted channel must be closed when the service fails to // start. - if previousStatus == serviceStarting { + if previousStatus < serviceActive && targetStatus >= serviceActive { close(cs.waitStarted) } if targetStatus == serviceTerminated { @@ -273,17 +279,15 @@ func (cs *compositeService) onServiceStatusChange(handle serviceHandle, status s removeIdx := -1 for i, s := range cs.removed { - if s.service.Handle() == handle { + if s.Handle() == handle { if status == serviceTerminated { - s.service.RemoveStatusChangeReceiver(cs.Handle()) + s.RemoveStatusChangeReceiver(cs.Handle()) removeIdx = i } break } } - if removeIdx >= 0 { - cs.removed = removeFromSlice(cs.removed, removeIdx) - } + cs.removed = removeFromSlice(cs.removed, removeIdx) // Note: we cannot rely on the service not being in the removed list above to // determine whether it is an active dependency. The notification may be for a @@ -291,10 +295,7 @@ func (cs *compositeService) onServiceStatusChange(handle serviceHandle, status s // changes are notified asynchronously and may be received out of order. isDependency := false for _, s := range cs.dependencies { - if s.service.Handle() == handle { - if status > s.lastStatus { - s.lastStatus = status - } + if s.Handle() == handle { isDependency = true break } @@ -307,13 +308,13 @@ func (cs *compositeService) onServiceStatusChange(handle serviceHandle, status s numTerminated := 0 for _, s := range cs.dependencies { - if shouldTerminate && s.lastStatus < serviceTerminating { - s.service.Stop() + if shouldTerminate && s.Status() < serviceTerminating { + s.Stop() } - if s.lastStatus >= serviceActive { + if s.Status() >= serviceActive { numStarted++ } - if s.lastStatus == serviceTerminated { + if s.Status() == serviceTerminated { numTerminated++ } } @@ -328,7 +329,7 @@ func (cs *compositeService) onServiceStatusChange(handle serviceHandle, status s } } -func removeFromSlice(services []*serviceHolder, removeIdx int) []*serviceHolder { +func removeFromSlice(services []service, removeIdx int) []service { lastIdx := len(services) - 1 if removeIdx < 0 || removeIdx > lastIdx { return services diff --git a/pubsublite/internal/wire/service_test.go b/pubsublite/internal/wire/service_test.go index 9a5b22c8dbb..0dd9b9be4fd 100644 --- a/pubsublite/internal/wire/service_test.go +++ b/pubsublite/internal/wire/service_test.go @@ -202,10 +202,10 @@ func newTestCompositeService(name string) *testCompositeService { return ts } -func (ts *testCompositeService) AddServices(services ...service) { +func (ts *testCompositeService) AddServices(services ...service) error { ts.mu.Lock() defer ts.mu.Unlock() - ts.unsafeAddServices(services...) + return ts.unsafeAddServices(services...) } func (ts *testCompositeService) RemoveService(service service) { @@ -231,7 +231,9 @@ func TestCompositeServiceNormalStop(t *testing.T) { child2 := newTestService("child2") child3 := newTestService("child3") parent := newTestCompositeService("parent") - parent.AddServices(child1, child2) + if err := parent.AddServices(child1, child2); err != nil { + t.Errorf("AddServices() got err: %v", err) + } t.Run("Starting", func(t *testing.T) { wantState := serviceUninitialized @@ -252,7 +254,9 @@ func TestCompositeServiceNormalStop(t *testing.T) { if child3.Status() != wantState { t.Errorf("child3: current service status: got %d, want %d", child3.Status(), wantState) } - parent.AddServices(child3) + if err := parent.AddServices(child3); err != nil { + t.Errorf("AddServices() got err: %v", err) + } child3.receiver.VerifyStatus(t, serviceStarting) }) @@ -300,7 +304,9 @@ func TestCompositeServiceErrorDuringStartup(t *testing.T) { child1 := newTestService("child1") child2 := newTestService("child2") parent := newTestCompositeService("parent") - parent.AddServices(child1, child2) + if err := parent.AddServices(child1, child2); err != nil { + t.Errorf("AddServices() got err: %v", err) + } t.Run("Starting", func(t *testing.T) { parent.Start() @@ -334,7 +340,9 @@ func TestCompositeServiceErrorWhileActive(t *testing.T) { child1 := newTestService("child1") child2 := newTestService("child2") parent := newTestCompositeService("parent") - parent.AddServices(child1, child2) + if err := parent.AddServices(child1, child2); err != nil { + t.Errorf("AddServices() got err: %v", err) + } t.Run("Starting", func(t *testing.T) { parent.Start() @@ -382,7 +390,9 @@ func TestCompositeServiceRemoveService(t *testing.T) { child1 := newTestService("child1") child2 := newTestService("child2") parent := newTestCompositeService("parent") - parent.AddServices(child1, child2) + if err := parent.AddServices(child1, child2); err != nil { + t.Errorf("AddServices() got err: %v", err) + } t.Run("Starting", func(t *testing.T) { parent.Start() @@ -452,16 +462,21 @@ func TestCompositeServiceTree(t *testing.T) { leaf1 := newTestService("leaf1") leaf2 := newTestService("leaf2") intermediate1 := newTestCompositeService("intermediate1") - intermediate1.AddServices(leaf1, leaf2) + if err := intermediate1.AddServices(leaf1, leaf2); err != nil { + t.Errorf("intermediate1.AddServices() got err: %v", err) + } leaf3 := newTestService("leaf3") leaf4 := newTestService("leaf4") intermediate2 := newTestCompositeService("intermediate2") - intermediate2.AddServices(leaf3, leaf4) + if err := intermediate2.AddServices(leaf3, leaf4); err != nil { + t.Errorf("intermediate2.AddServices() got err: %v", err) + } root := newTestCompositeService("root") - root.AddServices(intermediate1, intermediate2) - + if err := root.AddServices(intermediate1, intermediate2); err != nil { + t.Errorf("root.AddServices() got err: %v", err) + } wantErr := errors.New("fail") t.Run("Starting", func(t *testing.T) { @@ -528,3 +543,23 @@ func TestCompositeServiceTree(t *testing.T) { } }) } + +func TestCompositeServiceAddServicesErrors(t *testing.T) { + child1 := newTestService("child1") + parent := newTestCompositeService("parent") + if err := parent.AddServices(child1); err != nil { + t.Errorf("AddServices(child1) got err: %v", err) + } + + child2 := newTestService("child2") + child2.Start() + if gotErr, wantErr := parent.AddServices(child2), errChildServiceStarted; !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("AddServices(child2) got err: (%v), want err: (%v)", gotErr, wantErr) + } + + parent.Stop() + child3 := newTestService("child3") + if gotErr, wantErr := parent.AddServices(child3), ErrServiceStopped; !test.ErrorEqual(gotErr, wantErr) { + t.Errorf("AddServices(child3) got err: (%v), want err: (%v)", gotErr, wantErr) + } +}