diff --git a/packages/auth/src/mfa/mfa_session.ts b/packages/auth/src/mfa/mfa_session.ts index f9b8452d06d..dbb1f4995fc 100644 --- a/packages/auth/src/mfa/mfa_session.ts +++ b/packages/auth/src/mfa/mfa_session.ts @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +import { AuthInternal } from '../model/auth'; import { MultiFactorSession } from '../model/public_types'; export const enum MultiFactorSessionType { @@ -31,11 +32,12 @@ interface SerializedMultiFactorSession { export class MultiFactorSessionImpl implements MultiFactorSession { private constructor( readonly type: MultiFactorSessionType, - readonly credential: string + readonly credential: string, + readonly auth?: AuthInternal, ) {} - static _fromIdtoken(idToken: string): MultiFactorSessionImpl { - return new MultiFactorSessionImpl(MultiFactorSessionType.ENROLL, idToken); + static _fromIdtoken(idToken: string, auth?: AuthInternal): MultiFactorSessionImpl { + return new MultiFactorSessionImpl(MultiFactorSessionType.ENROLL, idToken, auth); } static _fromMfaPendingCredential( diff --git a/packages/auth/src/mfa/mfa_user.test.ts b/packages/auth/src/mfa/mfa_user.test.ts index a21b0d27b51..2c0da80a11a 100644 --- a/packages/auth/src/mfa/mfa_user.test.ts +++ b/packages/auth/src/mfa/mfa_user.test.ts @@ -84,6 +84,12 @@ describe('core/mfa/mfa_user/MultiFactorUser', () => { expect(mfaSession.type).to.eq(MultiFactorSessionType.ENROLL); expect(mfaSession.credential).to.eq('access-token'); }); + it('should contain a reference to auth', async () => { + const mfaSession = (await mfaUser.getSession()) as MultiFactorSessionImpl; + expect(mfaSession.type).to.eq(MultiFactorSessionType.ENROLL); + expect(mfaSession.credential).to.eq('access-token'); + expect(mfaSession.auth).to.eq(auth); + }); }); describe('enroll', () => { diff --git a/packages/auth/src/mfa/mfa_user.ts b/packages/auth/src/mfa/mfa_user.ts index 55bd15f830b..b536d7e4fbe 100644 --- a/packages/auth/src/mfa/mfa_user.ts +++ b/packages/auth/src/mfa/mfa_user.ts @@ -49,7 +49,7 @@ export class MultiFactorUserImpl implements MultiFactorUser { } async getSession(): Promise { - return MultiFactorSessionImpl._fromIdtoken(await this.user.getIdToken()); + return MultiFactorSessionImpl._fromIdtoken(await this.user.getIdToken(), this.user.auth); } async enroll( diff --git a/packages/auth/src/platform_browser/mfa/assertions/phone.test.ts b/packages/auth/src/platform_browser/mfa/assertions/phone.test.ts index 22e64593cfd..80a4a3b6574 100644 --- a/packages/auth/src/platform_browser/mfa/assertions/phone.test.ts +++ b/packages/auth/src/platform_browser/mfa/assertions/phone.test.ts @@ -58,7 +58,7 @@ describe('platform_browser/mfa/phone', () => { describe('enroll', () => { beforeEach(() => { - session = MultiFactorSessionImpl._fromIdtoken('enrollment-id-token'); + session = MultiFactorSessionImpl._fromIdtoken('enrollment-id-token', auth); }); it('should finalize the MFA enrollment', async () => { @@ -75,6 +75,7 @@ describe('platform_browser/mfa/phone', () => { sessionInfo: 'verification-id' } }); + expect(session.auth).to.eql(auth); }); context('with display name', () => { @@ -97,6 +98,7 @@ describe('platform_browser/mfa/phone', () => { sessionInfo: 'verification-id' } }); + expect(session.auth).to.eql(auth); }); }); }); @@ -119,6 +121,7 @@ describe('platform_browser/mfa/phone', () => { sessionInfo: 'verification-id' } }); + expect(session.auth).to.eql(undefined); }); }); });