Skip to content

Commit

Permalink
fix: Wait for endpoint creation to identify user (#13353)
Browse files Browse the repository at this point in the history
* fix: Wait for endpoint creation to identify user

* Update unit test to assert underlying error
  • Loading branch information
cshfang committed May 10, 2024
1 parent 57a77b5 commit f17cdf0
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 33 deletions.
2 changes: 1 addition & 1 deletion packages/core/src/providers/pinpoint/index.ts
Expand Up @@ -7,4 +7,4 @@ export {
PinpointServiceOptions,
UpdateEndpointException,
} from './types';
export { resolveEndpointId } from './utils';
export { getEndpointId, resolveEndpointId } from './utils';
@@ -1,7 +1,10 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import { updateEndpoint } from '@aws-amplify/core/internals/providers/pinpoint';
import {
getEndpointId,
updateEndpoint,
} from '@aws-amplify/core/internals/providers/pinpoint';
import { assertIsInitialized } from '../../../../../src/pushNotifications/errors/errorHelpers';
import { identifyUser } from '../../../../../src/pushNotifications/providers/pinpoint/apis/identifyUser.native';
import { IdentifyUserInput } from '../../../../../src/pushNotifications/providers/pinpoint/types';
Expand All @@ -11,6 +14,7 @@ import {
} from '../../../../../src/pushNotifications/utils';
import {
getChannelType,
getInflightDeviceRegistration,
resolveConfig,
} from '../../../../../src/pushNotifications/providers/pinpoint/utils';
import {
Expand All @@ -32,11 +36,14 @@ describe('identifyUser (native)', () => {
// assert mocks
const mockAssertIsInitialized = assertIsInitialized as jest.Mock;
const mockGetChannelType = getChannelType as jest.Mock;
const mockUpdateEndpoint = updateEndpoint as jest.Mock;
const mockGetEndpointId = getEndpointId as jest.Mock;
const mockGetInflightDeviceRegistration =
getInflightDeviceRegistration as jest.Mock;
const mockGetPushNotificationUserAgentString =
getPushNotificationUserAgentString as jest.Mock;
const mockResolveConfig = resolveConfig as jest.Mock;
const mockResolveCredentials = resolveCredentials as jest.Mock;
const mockUpdateEndpoint = updateEndpoint as jest.Mock;

beforeAll(() => {
mockGetChannelType.mockReturnValue(channelType);
Expand All @@ -47,7 +54,9 @@ describe('identifyUser (native)', () => {

afterEach(() => {
mockAssertIsInitialized.mockReset();
mockGetEndpointId.mockReset();
mockUpdateEndpoint.mockReset();
mockGetInflightDeviceRegistration.mockClear();
});

it('must be initialized', async () => {
Expand Down Expand Up @@ -111,4 +120,24 @@ describe('identifyUser (native)', () => {
};
await expect(identifyUser(input)).rejects.toBeDefined();
});

it('awaits device registration promise when endpoint is not present', async () => {
const input: IdentifyUserInput = {
userId: 'user-id',
userProfile: {},
};
mockGetEndpointId.mockResolvedValue(undefined);
await identifyUser(input);
expect(mockGetInflightDeviceRegistration).toHaveBeenCalled();
});

it('does not await device registration promise when endpoint is present', async () => {
const input: IdentifyUserInput = {
userId: 'user-id',
userProfile: {},
};
mockGetEndpointId.mockResolvedValue('endpoint-id');
await identifyUser(input);
expect(mockGetInflightDeviceRegistration).not.toHaveBeenCalled();
});
});
Expand Up @@ -13,7 +13,11 @@ import {
resolveCredentials,
setToken,
} from '../../../../../src/pushNotifications/utils';
import { resolveConfig } from '../../../../../src/pushNotifications//providers/pinpoint/utils';
import {
rejectInflightDeviceRegistration,
resolveConfig,
resolveInflightDeviceRegistration,
} from '../../../../../src/pushNotifications//providers/pinpoint/utils';
import {
completionHandlerId,
credentials,
Expand Down Expand Up @@ -56,8 +60,12 @@ describe('initializePushNotifications (native)', () => {
const mockGetToken = getToken as jest.Mock;
const mockInitialize = initialize as jest.Mock;
const mockIsInitialized = isInitialized as jest.Mock;
const mockRejectInflightDeviceRegistration =
rejectInflightDeviceRegistration as jest.Mock;
const mockResolveCredentials = resolveCredentials as jest.Mock;
const mockResolveConfig = resolveConfig as jest.Mock;
const mockResolveInflightDeviceRegistration =
resolveInflightDeviceRegistration as jest.Mock;
const mockSetToken = setToken as jest.Mock;
const mockNotifyEventListeners = notifyEventListeners as jest.Mock;
const mockNotifyEventListenersAndAwaitHandlers =
Expand Down Expand Up @@ -114,6 +122,8 @@ describe('initializePushNotifications (native)', () => {
mockEventListenerRemover.remove.mockClear();
mockNotifyEventListeners.mockClear();
mockNotifyEventListenersAndAwaitHandlers.mockClear();
mockRejectInflightDeviceRegistration.mockClear();
mockResolveInflightDeviceRegistration.mockClear();
});

it('only enables once', () => {
Expand Down Expand Up @@ -236,29 +246,29 @@ describe('initializePushNotifications (native)', () => {

describe('token received', () => {
it('registers and calls token received listener', done => {
expect.assertions(6);
mockGetToken.mockReturnValue(undefined);
mockAddTokenEventListener.mockImplementation(
async (heardEvent, handler) => {
if (heardEvent === NativeEvent.TOKEN_RECEIVED) {
await handler(pushToken);
expect(mockAddTokenEventListener).toHaveBeenCalledWith(
NativeEvent.TOKEN_RECEIVED,
expect.any(Function),
);
expect(mockSetToken).toHaveBeenCalledWith(pushToken);
expect(mockNotifyEventListeners).toHaveBeenCalledWith(
'tokenReceived',
pushToken,
);
expect(mockUpdateEndpoint).toHaveBeenCalled();
expect(mockResolveInflightDeviceRegistration).toHaveBeenCalled();
expect(mockRejectInflightDeviceRegistration).not.toHaveBeenCalled();
done();
}
},
);
mockUpdateEndpoint.mockImplementation(() => {
expect(mockUpdateEndpoint).toHaveBeenCalled();
done();
});
initializePushNotifications();

expect(mockAddTokenEventListener).toHaveBeenCalledWith(
NativeEvent.TOKEN_RECEIVED,
expect.any(Function),
);
expect(mockSetToken).toHaveBeenCalledWith(pushToken);
expect(mockNotifyEventListeners).toHaveBeenCalledWith(
'tokenReceived',
pushToken,
);
});

it('should not be invoke token received listener with the same token twice', () => {
Expand Down Expand Up @@ -292,13 +302,18 @@ describe('initializePushNotifications (native)', () => {
});

it('throws if device registration fails', done => {
expect.assertions(3);
mockUpdateEndpoint.mockImplementation(() => {
throw new Error();
});
mockAddTokenEventListener.mockImplementation(
async (heardEvent, handler) => {
if (heardEvent === NativeEvent.TOKEN_RECEIVED) {
await expect(handler(pushToken)).rejects.toThrow();
expect(
mockResolveInflightDeviceRegistration,
).not.toHaveBeenCalled();
expect(mockRejectInflightDeviceRegistration).toHaveBeenCalled();
done();
}
},
Expand Down
@@ -0,0 +1,73 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import {
getInflightDeviceRegistration,
rejectInflightDeviceRegistration,
resolveInflightDeviceRegistration,
} from '../../../../../src/pushNotifications/providers/pinpoint/utils/inflightDeviceRegistration';
import { InflightDeviceRegistration } from '../../../../../src/pushNotifications/providers/pinpoint/types';

describe('inflightDeviceRegistration', () => {
describe('resolveInflightDeviceRegistration', () => {
let getInflightDeviceRegistration: () => InflightDeviceRegistration;
let resolveInflightDeviceRegistration: () => void;
jest.isolateModules(() => {
({
getInflightDeviceRegistration,
resolveInflightDeviceRegistration,
} = require('../../../../../src/pushNotifications/providers/pinpoint/utils/inflightDeviceRegistration'));
});

it('creates a pending promise on module load', () => {
expect(getInflightDeviceRegistration()).toBeDefined();
});

it('should resolve the promise', async () => {
const blockedFunction = jest.fn();
const promise = getInflightDeviceRegistration()?.then(() => {
blockedFunction();
});

expect(blockedFunction).not.toHaveBeenCalled();
resolveInflightDeviceRegistration();
await promise;
expect(blockedFunction).toHaveBeenCalled();
});

it('should have released the promise from memory', () => {
expect(getInflightDeviceRegistration()).toBeUndefined();
});
});

describe('rejectInflightDeviceRegistration', () => {
let getInflightDeviceRegistration: () => InflightDeviceRegistration;
let rejectInflightDeviceRegistration: (underlyingError: unknown) => void;
jest.isolateModules(() => {
({
getInflightDeviceRegistration,
rejectInflightDeviceRegistration,
} = require('../../../../../src/pushNotifications/providers/pinpoint/utils/inflightDeviceRegistration'));
});

it('creates a pending promise on module load', () => {
expect(getInflightDeviceRegistration()).toBeDefined();
});

it('should reject the promise', async () => {
const underlyingError = new Error('underlying-error');
const blockedFunction = jest.fn();
const promise = getInflightDeviceRegistration()?.then(() => {
blockedFunction();
});

expect(blockedFunction).not.toHaveBeenCalled();
rejectInflightDeviceRegistration(underlyingError);
await expect(promise).rejects.toMatchObject({
name: 'DeviceRegistrationFailed',
underlyingError,
});
expect(blockedFunction).not.toHaveBeenCalled();
});
});
});
Expand Up @@ -2,14 +2,21 @@
// SPDX-License-Identifier: Apache-2.0

import { PushNotificationAction } from '@aws-amplify/core/internals/utils';
import { updateEndpoint } from '@aws-amplify/core/internals/providers/pinpoint';
import {
getEndpointId,
updateEndpoint,
} from '@aws-amplify/core/internals/providers/pinpoint';

import { assertIsInitialized } from '../../../errors/errorHelpers';
import {
getPushNotificationUserAgentString,
resolveCredentials,
} from '../../../utils';
import { getChannelType, resolveConfig } from '../utils';
import {
getChannelType,
getInflightDeviceRegistration,
resolveConfig,
} from '../utils';
import { IdentifyUser } from '../types';

export const identifyUser: IdentifyUser = async ({
Expand All @@ -21,6 +28,10 @@ export const identifyUser: IdentifyUser = async ({
const { credentials, identityId } = await resolveCredentials();
const { appId, region } = resolveConfig();
const { address, optOut, userAttributes } = options ?? {};
if (!(await getEndpointId(appId, 'PushNotification'))) {
// if there is no cached endpoint id, wait for successful endpoint creation before continuing
await getInflightDeviceRegistration();
}
await updateEndpoint({
address,
channelType: getChannelType(),
Expand Down
Expand Up @@ -23,7 +23,9 @@ import {
import {
createMessageEventRecorder,
getChannelType,
rejectInflightDeviceRegistration,
resolveConfig,
resolveInflightDeviceRegistration,
} from '../utils';

const {
Expand Down Expand Up @@ -203,16 +205,24 @@ const addAnalyticsListeners = (): void => {
const registerDevice = async (address: string): Promise<void> => {
const { credentials, identityId } = await resolveCredentials();
const { appId, region } = resolveConfig();
await updateEndpoint({
address,
appId,
category: 'PushNotification',
credentials,
region,
channelType: getChannelType(),
identityId,
userAgentValue: getPushNotificationUserAgentString(
PushNotificationAction.InitializePushNotifications,
),
});
try {
await updateEndpoint({
address,
appId,
category: 'PushNotification',
credentials,
region,
channelType: getChannelType(),
identityId,
userAgentValue: getPushNotificationUserAgentString(
PushNotificationAction.InitializePushNotifications,
),
});
// always resolve inflight device registration promise here even though the promise is only awaited on by
// `identifyUser` when no endpoint is found in the cache
resolveInflightDeviceRegistration();
} catch (underlyingError) {
rejectInflightDeviceRegistration(underlyingError);
throw underlyingError;
}
};
Expand Up @@ -37,4 +37,8 @@ export {
OnTokenReceivedOutput,
} from './outputs';
export { IdentifyUserOptions } from './options';
export { ChannelType } from './pushNotifications';
export {
ChannelType,
InflightDeviceRegistration,
InflightDeviceRegistrationResolver,
} from './pushNotifications';
Expand Up @@ -3,4 +3,13 @@

import { updateEndpoint } from '@aws-amplify/core/internals/providers/pinpoint';

import { PushNotificationError } from '../../../errors';

export type ChannelType = Parameters<typeof updateEndpoint>[0]['channelType'];

export type InflightDeviceRegistration = Promise<void> | undefined;

export interface InflightDeviceRegistrationResolver {
resolve?(): void;
reject?(error: PushNotificationError): void;
}
Expand Up @@ -4,4 +4,9 @@
export { createMessageEventRecorder } from './createMessageEventRecorder';
export { getAnalyticsEvent } from './getAnalyticsEvent';
export { getChannelType } from './getChannelType';
export {
getInflightDeviceRegistration,
rejectInflightDeviceRegistration,
resolveInflightDeviceRegistration,
} from './inflightDeviceRegistration';
export { resolveConfig } from './resolveConfig';

0 comments on commit f17cdf0

Please sign in to comment.