diff --git a/packages/firestore/src/platform/node/grpc_connection.ts b/packages/firestore/src/platform/node/grpc_connection.ts index 8edf2ea1801..f8b15440ffa 100644 --- a/packages/firestore/src/platform/node/grpc_connection.ts +++ b/packages/firestore/src/platform/node/grpc_connection.ts @@ -153,11 +153,11 @@ export class GrpcConnection implements Connection { path: string, request: Req, authToken: Token | null, - appCheckToken: Token | null + appCheckToken: Token | null, + expectedResponseCount?: number ): Promise { const results: Resp[] = []; const responseDeferred = new Deferred(); - logDebug( LOG_TAG, `RPC '${rpcName}' invoked (streaming) with request:`, @@ -172,13 +172,24 @@ export class GrpcConnection implements Connection { ); const jsonRequest = { ...request, database: this.databasePath }; const stream = stub[rpcName](jsonRequest, metadata); + let callbackFired = false; stream.on('data', (response: Resp) => { logDebug(LOG_TAG, `RPC ${rpcName} received result:`, response); results.push(response); + if ( + expectedResponseCount !== undefined && + results.length === expectedResponseCount + ) { + callbackFired = true; + responseDeferred.resolve(results); + } }); stream.on('end', () => { logDebug(LOG_TAG, `RPC '${rpcName}' completed.`); - responseDeferred.resolve(results); + if (!callbackFired) { + callbackFired = true; + responseDeferred.resolve(results); + } }); stream.on('error', (grpcError: grpc.ServiceError) => { logDebug(LOG_TAG, `RPC '${rpcName}' failed with error:`, grpcError); diff --git a/packages/firestore/src/remote/connection.ts b/packages/firestore/src/remote/connection.ts index 1a88c5c48c8..19a54d37c73 100644 --- a/packages/firestore/src/remote/connection.ts +++ b/packages/firestore/src/remote/connection.ts @@ -68,7 +68,8 @@ export interface Connection { path: string, request: Req, authToken: Token | null, - appCheckToken: Token | null + appCheckToken: Token | null, + expectedResponseCount?: number ): Promise; /** diff --git a/packages/firestore/src/remote/datastore.ts b/packages/firestore/src/remote/datastore.ts index e6ee7710a04..0d96661f6b1 100644 --- a/packages/firestore/src/remote/datastore.ts +++ b/packages/firestore/src/remote/datastore.ts @@ -120,7 +120,8 @@ class DatastoreImpl extends Datastore { invokeStreamingRPC( rpcName: string, path: string, - request: Req + request: Req, + expectedResponseCount?: number ): Promise { this.verifyInitialized(); return Promise.all([ @@ -133,7 +134,8 @@ class DatastoreImpl extends Datastore { path, request, authToken, - appCheckToken + appCheckToken, + expectedResponseCount ); }) .catch((error: FirestoreError) => { @@ -194,7 +196,7 @@ export async function invokeBatchGetDocumentsRpc( const response = await datastoreImpl.invokeStreamingRPC< ProtoBatchGetDocumentsRequest, ProtoBatchGetDocumentsResponse - >('BatchGetDocuments', path, request); + >('BatchGetDocuments', path, request, keys.length); const docs = new Map(); response.forEach(proto => { diff --git a/packages/firestore/src/remote/rest_connection.ts b/packages/firestore/src/remote/rest_connection.ts index 524e52c385a..9c318486620 100644 --- a/packages/firestore/src/remote/rest_connection.ts +++ b/packages/firestore/src/remote/rest_connection.ts @@ -104,7 +104,8 @@ export abstract class RestConnection implements Connection { path: string, request: Req, authToken: Token | null, - appCheckToken: Token | null + appCheckToken: Token | null, + expectedResponseCount?: number ): Promise { // The REST API automatically aggregates all of the streamed results, so we // can just use the normal invoke() method.