From 63195b99ab9a611a0e5c9cb85a3a49390f4c6855 Mon Sep 17 00:00:00 2001 From: Richard Park <51494936+richardpark-msft@users.noreply.github.com> Date: Fri, 15 Apr 2022 17:01:38 -0700 Subject: [PATCH] [azservicebus] Allow link creation to be cancelled (#17598) AcceptNextSessionFor(Queue|Subscription) can block for a long time (server dependent) if there are no available sessions. HOWEVER, it was intended to be cancellable, which wasn't working. This is a simpler workaround until we get context support plumbed through go-amqp itself. Fixes #17565 --- sdk/messaging/azservicebus/CHANGELOG.md | 3 ++ sdk/messaging/azservicebus/client_test.go | 27 +++++++++- .../azservicebus/internal/amqpLinks.go | 5 +- .../azservicebus/internal/amqpLinks_test.go | 33 ++++++------ .../azservicebus/internal/amqp_test_utils.go | 22 ++++++-- .../internal/amqpwrap/amqpwrap.go | 10 ++-- .../azservicebus/internal/namespace.go | 6 +-- sdk/messaging/azservicebus/receiver.go | 52 +++++++++++++++---- .../azservicebus/receiver_unit_test.go | 30 +++++++++++ sdk/messaging/azservicebus/sender.go | 3 +- .../azservicebus/session_receiver.go | 3 +- 11 files changed, 151 insertions(+), 43 deletions(-) diff --git a/sdk/messaging/azservicebus/CHANGELOG.md b/sdk/messaging/azservicebus/CHANGELOG.md index 2591a37dbee1..9f93cbadcb27 100644 --- a/sdk/messaging/azservicebus/CHANGELOG.md +++ b/sdk/messaging/azservicebus/CHANGELOG.md @@ -12,6 +12,9 @@ ### Bugs Fixed +- Fixing issue where the AcceptNextSessionForQueue and AcceptNextSessionForSubscription + couldn't be cancelled, forcing the user to wait for the service to timeout. (#17598) + ### Other Changes ## 0.4.0 (2022-04-06) diff --git a/sdk/messaging/azservicebus/client_test.go b/sdk/messaging/azservicebus/client_test.go index a77118e7e967..ede056561f31 100644 --- a/sdk/messaging/azservicebus/client_test.go +++ b/sdk/messaging/azservicebus/client_test.go @@ -13,7 +13,9 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/admin" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test" "github.com/Azure/azure-sdk-for-go/sdk/messaging/internal/sas" @@ -204,7 +206,7 @@ func TestNewClientNewReceiverNotFound(t *testing.T) { assertRPCNotFound(t, err) } -func TestNewClientNewSessionReceiverNotFound(t *testing.T) { +func TestClientNewSessionReceiverNotFound(t *testing.T) { connectionString := test.GetConnectionString(t) client, err := NewClientFromConnectionString(connectionString, nil) require.NoError(t, err) @@ -258,6 +260,29 @@ func TestClientCloseVsClosePermanently(t *testing.T) { require.Nil(t, sessionReceiver) } +func TestClientNewSessionReceiverCancel(t *testing.T) { + // Both the session APIs create the receiver immediately however AcceptNextSession() has a quirk + // where it takes an excessively long time. + connectionString := test.GetConnectionString(t) + + queue, cleanup := createQueue(t, connectionString, &admin.QueueProperties{ + RequiresSession: to.Ptr(true), + }) + + defer cleanup() + + client, err := NewClientFromConnectionString(connectionString, nil) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // non-cancelled version + receiver, err := client.AcceptNextSessionForQueue(ctx, queue, nil) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.Nil(t, receiver) +} + func TestNewClientUnitTests(t *testing.T) { t.Run("WithTokenCredential", func(t *testing.T) { fakeTokenCredential := struct{ azcore.TokenCredential }{} diff --git a/sdk/messaging/azservicebus/internal/amqpLinks.go b/sdk/messaging/azservicebus/internal/amqpLinks.go index 1ab0d11b1eeb..87ad26949754 100644 --- a/sdk/messaging/azservicebus/internal/amqpLinks.go +++ b/sdk/messaging/azservicebus/internal/amqpLinks.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" @@ -85,7 +86,7 @@ type AMQPLinksImpl struct { RPCLink RPCLink // the AMQP session for either the 'sender' or 'receiver' link - session AMQPSessionCloser + session amqpwrap.AMQPSession // these are populated by your `createLinkFunc` when you construct // the amqpLinks @@ -104,7 +105,7 @@ type AMQPLinksImpl struct { // CreateLinkFunc creates the links, using the given session. Typically you'll only create either an // *amqp.Sender or a *amqp.Receiver. AMQPLinks handles it either way. -type CreateLinkFunc func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) +type CreateLinkFunc func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) type NewAMQPLinksArgs struct { NS NamespaceForAMQPLinks diff --git a/sdk/messaging/azservicebus/internal/amqpLinks_test.go b/sdk/messaging/azservicebus/internal/amqpLinks_test.go index bc7f0a7c5027..bc0c35c4e545 100644 --- a/sdk/messaging/azservicebus/internal/amqpLinks_test.go +++ b/sdk/messaging/azservicebus/internal/amqpLinks_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/test" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" @@ -77,7 +78,7 @@ func TestAMQPLinksBasic(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return newLinksForAMQPLinksTest(entityPath, session) }, GetRecoveryKindFunc: GetRecoveryKind, @@ -112,7 +113,7 @@ func TestAMQPLinksLive(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { createLinksCalled++ return newLinksForAMQPLinksTest(entityPath, session) }, @@ -185,7 +186,7 @@ func TestAMQPLinksLiveRecoverLink(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { createLinksCalled++ return newLinksForAMQPLinksTest(entityPath, session) }, @@ -223,7 +224,7 @@ func TestAMQPLinksLiveRace(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { createLinksCalled++ return newLinksForAMQPLinksTest(entityPath, session) }, @@ -275,7 +276,7 @@ func TestAMQPLinksLiveRaceLink(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { createLinksCalled++ return newLinksForAMQPLinksTest(entityPath, session) }, @@ -319,7 +320,7 @@ func TestAMQPLinksRetry(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { createLinksCalled++ return newLinksForAMQPLinksTest(entityPath, session) }, @@ -361,7 +362,7 @@ func TestAMQPLinksMultipleWithSameConnection(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { createLinksCalled++ return newLinksForAMQPLinksTest(entityPath, session) }, @@ -377,7 +378,7 @@ func TestAMQPLinksMultipleWithSameConnection(t *testing.T) { links2 := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: entityPath, - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { createLinksCalled2++ return newLinksForAMQPLinksTest(entityPath, session) }, @@ -456,7 +457,7 @@ func TestAMQPLinksCloseIfNeeded(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: "entityPath", - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return sender, receiver, nil }, GetRecoveryKindFunc: GetRecoveryKind, @@ -486,7 +487,7 @@ func TestAMQPLinksCloseIfNeeded(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: "entityPath", - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return sender, receiver, nil }, GetRecoveryKindFunc: GetRecoveryKind, @@ -515,7 +516,7 @@ func TestAMQPLinksCloseIfNeeded(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: "entityPath", - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return sender, receiver, nil }, GetRecoveryKindFunc: GetRecoveryKind, @@ -544,7 +545,7 @@ func TestAMQPLinksCloseIfNeeded(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: "entityPath", - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return sender, receiver, nil }, GetRecoveryKindFunc: GetRecoveryKind, @@ -607,7 +608,7 @@ func TestAMQPLinksRetriesUnit(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: "entityPath", - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return sender, receiver, nil }, GetRecoveryKindFunc: GetRecoveryKind, @@ -651,7 +652,7 @@ func TestAMQPLinks_Logging(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: "entityPath", - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return nil, receiver, nil }, GetRecoveryKindFunc: GetRecoveryKind, @@ -684,7 +685,7 @@ func TestAMQPLinks_Logging(t *testing.T) { links := NewAMQPLinks(NewAMQPLinksArgs{ NS: ns, EntityPath: "entityPath", - CreateLinkFunc: func(ctx context.Context, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { + CreateLinkFunc: func(ctx context.Context, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { return nil, receiver, nil }, GetRecoveryKindFunc: GetRecoveryKind, }) @@ -710,7 +711,7 @@ func TestAMQPLinks_Logging(t *testing.T) { }) } -func newLinksForAMQPLinksTest(entityPath string, session AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { +func newLinksForAMQPLinksTest(entityPath string, session amqpwrap.AMQPSession) (AMQPSenderCloser, AMQPReceiverCloser, error) { receiveMode := amqp.ModeSecond opts := []amqp.LinkOption{ diff --git a/sdk/messaging/azservicebus/internal/amqp_test_utils.go b/sdk/messaging/azservicebus/internal/amqp_test_utils.go index 9895f1c30ca0..7493686d9931 100644 --- a/sdk/messaging/azservicebus/internal/amqp_test_utils.go +++ b/sdk/messaging/azservicebus/internal/amqp_test_utils.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" @@ -18,7 +19,7 @@ type FakeNS struct { recovered uint64 clientRevisions []uint64 RPCLink RPCLink - Session AMQPSessionCloser + Session amqpwrap.AMQPSession AMQPLinks *FakeAMQPLinks CloseCalled int @@ -30,7 +31,10 @@ type FakeAMQPSender struct { } type FakeAMQPSession struct { - AMQPSessionCloser + amqpwrap.AMQPSession + + NewReceiverFn func(opts ...amqp.LinkOption) (AMQPReceiverCloser, error) + closed int } @@ -54,7 +58,8 @@ type FakeAMQPLinks struct { type FakeAMQPReceiver struct { AMQPReceiver - Closed int + Closed int + CloseFn func(ctx context.Context) error DrainCalled int DrainCreditImpl func(ctx context.Context) error @@ -139,6 +144,11 @@ func (r *FakeAMQPReceiver) Prefetched(ctx context.Context) (*amqp.Message, error func (r *FakeAMQPReceiver) Close(ctx context.Context) error { r.Closed++ + + if r.CloseFn != nil { + return r.CloseFn(ctx) + } + return nil } @@ -189,6 +199,10 @@ func (s *FakeAMQPSender) Close(ctx context.Context) error { return nil } +func (s *FakeAMQPSession) NewReceiver(opts ...amqp.LinkOption) (AMQPReceiverCloser, error) { + return s.NewReceiverFn(opts...) +} + func (s *FakeAMQPSession) Close(ctx context.Context) error { s.closed++ return nil @@ -207,7 +221,7 @@ func (ns *FakeNS) GetEntityAudience(entityPath string) string { return fmt.Sprintf("audience: %s", entityPath) } -func (ns *FakeNS) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64, error) { +func (ns *FakeNS) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) { return ns.Session, ns.recovered + 100, nil } diff --git a/sdk/messaging/azservicebus/internal/amqpwrap/amqpwrap.go b/sdk/messaging/azservicebus/internal/amqpwrap/amqpwrap.go index 2ec0f80c01eb..d1f979fd566f 100644 --- a/sdk/messaging/azservicebus/internal/amqpwrap/amqpwrap.go +++ b/sdk/messaging/azservicebus/internal/amqpwrap/amqpwrap.go @@ -78,22 +78,22 @@ func (w *AMQPClientWrapper) NewSession(opts ...amqp.SessionOption) (AMQPSession, } return &AMQPSessionWrapper{ - inner: sess, + Inner: sess, }, nil } type AMQPSessionWrapper struct { - inner *amqp.Session + Inner *amqp.Session } func (w *AMQPSessionWrapper) Close(ctx context.Context) error { - return w.inner.Close(ctx) + return w.Inner.Close(ctx) } func (w *AMQPSessionWrapper) NewReceiver(opts ...amqp.LinkOption) (AMQPReceiverCloser, error) { - return w.inner.NewReceiver(opts...) + return w.Inner.NewReceiver(opts...) } func (w *AMQPSessionWrapper) NewSender(opts ...amqp.LinkOption) (AMQPSenderCloser, error) { - return w.inner.NewSender(opts...) + return w.Inner.NewSender(opts...) } diff --git a/sdk/messaging/azservicebus/internal/namespace.go b/sdk/messaging/azservicebus/internal/namespace.go index 901ffbff5d2b..825f46edac95 100644 --- a/sdk/messaging/azservicebus/internal/namespace.go +++ b/sdk/messaging/azservicebus/internal/namespace.go @@ -70,7 +70,7 @@ type NamespaceWithNewAMQPLinks interface { // NamespaceForAMQPLinks is the Namespace surface needed for the internals of AMQPLinks. type NamespaceForAMQPLinks interface { NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) - NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64, error) + NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) NewRPCLink(ctx context.Context, managementPath string) (RPCLink, error) GetEntityAudience(entityPath string) string Recover(ctx context.Context, clientRevision uint64) (bool, error) @@ -192,7 +192,7 @@ func (ns *Namespace) newClient(ctx context.Context) (*amqp.Client, error) { // NewAMQPSession creates a new AMQP session with the internally cached *amqp.Client. // Returns a closeable AMQP session and the current client revision. -func (ns *Namespace) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uint64, error) { +func (ns *Namespace) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) { client, clientRevision, err := ns.GetAMQPClientImpl(ctx) if err != nil { @@ -205,7 +205,7 @@ func (ns *Namespace) NewAMQPSession(ctx context.Context) (AMQPSessionCloser, uin return nil, 0, err } - return session, clientRevision, err + return &amqpwrap.AMQPSessionWrapper{Inner: session}, clientRevision, err } // NewRPCLink creates a new amqp-common *rpc.Link with the internally cached *amqp.Client. diff --git a/sdk/messaging/azservicebus/receiver.go b/sdk/messaging/azservicebus/receiver.go index 208abf7db2af..e413ecaf5dec 100644 --- a/sdk/messaging/azservicebus/receiver.go +++ b/sdk/messaging/azservicebus/receiver.go @@ -12,10 +12,10 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" - "github.com/devigned/tab" ) // ReceiveMode represents the lock style to use for a receiver - either @@ -120,7 +120,7 @@ type newReceiverArgs struct { entity entity cleanupOnClose func() getRecoveryKindFunc func(err error) internal.RecoveryKind - newLinkFn func(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) + newLinkFn func(ctx context.Context, session amqpwrap.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) } func newReceiver(args newReceiverArgs, options *ReceiverOptions) (*Receiver, error) { @@ -164,7 +164,7 @@ func newReceiver(args newReceiverArgs, options *ReceiverOptions) (*Receiver, err return receiver, nil } -func (r *Receiver) newReceiverLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { +func (r *Receiver) newReceiverLink(ctx context.Context, session amqpwrap.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { linkOptions := createLinkOptions(r.receiveMode, r.entityPath) link, err := createReceiverLink(ctx, session, linkOptions) return nil, link, err @@ -508,15 +508,47 @@ func (e *entity) SetSubQueue(subQueue SubQueue) error { return fmt.Errorf("unknown SubQueue %d", subQueue) } -func createReceiverLink(ctx context.Context, session internal.AMQPSession, linkOptions []amqp.LinkOption) (internal.AMQPReceiverCloser, error) { - amqpReceiver, err := session.NewReceiver(linkOptions...) - - if err != nil { - tab.For(ctx).Error(err) - return nil, err +func createReceiverLink(ctx context.Context, session amqpwrap.AMQPSession, linkOptions []amqp.LinkOption) (internal.AMQPReceiverCloser, error) { + // If you're doing an AcceptNextSession it's possible for this call to take a long time before timing out + // on its own (it's by design - it's waiting for any empty session to become available). + type ret = struct { + Receiver internal.AMQPReceiverCloser + Err error } - return amqpReceiver, nil + done := make(chan ret) + + go func(ctx context.Context) { + defer close(done) + + tmpReceiver, tmpErr := session.NewReceiver(linkOptions...) + + if tmpErr != nil { + done <- ret{Err: tmpErr} + return + } + + select { + case <-ctx.Done(): + // `createReceiverLink` will have already returned with a cancellation based error, + // so this goroutine just needs to make sure we close this link that nobody is going + // to use. + _ = tmpReceiver.Close(context.Background()) + return + default: + done <- ret{Receiver: tmpReceiver} + } + }(ctx) + + select { + case data := <-done: + return data.Receiver, data.Err + case <-ctx.Done(): + // we'll early exit if cancelled - the goroutine above + // will just close the no-longer-needed link if/when it + // returns successfully. + return nil, ctx.Err() + } } func createLinkOptions(mode ReceiveMode, entityPath string) []amqp.LinkOption { diff --git a/sdk/messaging/azservicebus/receiver_unit_test.go b/sdk/messaging/azservicebus/receiver_unit_test.go index 9ae75928d545..a4e7ef1af7ff 100644 --- a/sdk/messaging/azservicebus/receiver_unit_test.go +++ b/sdk/messaging/azservicebus/receiver_unit_test.go @@ -299,6 +299,36 @@ func TestReceiver_ReceiveMessages_SomeMessagesAndError(t *testing.T) { require.Equal(t, 1, fakeAMQPLinks.CloseIfNeededCalled, "prefetch is called") } +func TestReceiver_CanCancelLinkCreation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + receiverWasClosedCh := make(chan struct{}) + + fakeReceiver := &internal.FakeAMQPReceiver{ + CloseFn: func(ctx context.Context) error { + close(receiverWasClosedCh) + return nil + }, + } + + session := &internal.FakeAMQPSession{ + NewReceiverFn: func(opts ...amqp.LinkOption) (internal.AMQPReceiverCloser, error) { + // simulate the client cancelling while we're stuck attempting to get the + // session receiver link. + cancel() + return fakeReceiver, nil + }, + } + + receiver, err := createReceiverLink(ctx, session, []amqp.LinkOption{}) + require.Nil(t, receiver) + require.ErrorIs(t, err, context.Canceled) + + // also, the receiver we returned should be closed as part of the gourtine + // unwinding. + <-receiverWasClosedCh +} + func TestReceiverCancellationUnitTests(t *testing.T) { t.Run("ImmediatelyCancelled", func(t *testing.T) { r := &Receiver{ diff --git a/sdk/messaging/azservicebus/sender.go b/sdk/messaging/azservicebus/sender.go index 1ab0bbec6da2..0da61f689600 100644 --- a/sdk/messaging/azservicebus/sender.go +++ b/sdk/messaging/azservicebus/sender.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/tracing" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" @@ -146,7 +147,7 @@ func (s *Sender) scheduleAMQPMessages(ctx context.Context, messages []*amqp.Mess return sequenceNumbers, err } -func (sender *Sender) createSenderLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { +func (sender *Sender) createSenderLink(ctx context.Context, session amqpwrap.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { amqpSender, err := session.NewSender( amqp.LinkSenderSettle(amqp.ModeMixed), amqp.LinkReceiverSettle(amqp.ModeFirst), diff --git a/sdk/messaging/azservicebus/session_receiver.go b/sdk/messaging/azservicebus/session_receiver.go index 150cc65e9e01..2bdee141d539 100644 --- a/sdk/messaging/azservicebus/session_receiver.go +++ b/sdk/messaging/azservicebus/session_receiver.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus/internal/utils" "github.com/Azure/go-amqp" ) @@ -77,7 +78,7 @@ func newSessionReceiver(ctx context.Context, sessionID *string, ns internal.Name return sessionReceiver, nil } -func (r *SessionReceiver) newLink(ctx context.Context, session internal.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { +func (r *SessionReceiver) newLink(ctx context.Context, session amqpwrap.AMQPSession) (internal.AMQPSenderCloser, internal.AMQPReceiverCloser, error) { const sessionFilterName = "com.microsoft:session-filter" const code = uint64(0x00000137000000C)