Skip to content

Commit

Permalink
Use named args and improve types
Browse files Browse the repository at this point in the history
  • Loading branch information
FrederikBolding committed Jul 20, 2022
1 parent 823b7b5 commit eaa5c25
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 91 deletions.
18 changes: 8 additions & 10 deletions packages/controllers/src/services/AbstractExecutionService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ export type ExecutionServiceArgs = {
terminationTimeout?: number;
};

export type SnapRpcHookArgs = {
origin: string;
handler: HandlerType;
request: Record<string, unknown>;
};

// The snap is the callee
export type SnapRpcHook = (
origin: string,
handler: HandlerType,
request: Record<string, unknown>,
) => Promise<unknown>;
export type SnapRpcHook = (options: SnapRpcHookArgs) => Promise<unknown>;

export type JobStreams = {
command: Duplex;
Expand Down Expand Up @@ -382,11 +384,7 @@ export abstract class AbstractExecutionService<WorkerType>
}

protected _createSnapHooks(snapId: string, workerId: string) {
const rpcHook = async (
origin: string,
handler: HandlerType,
request: Record<string, unknown>,
) => {
const rpcHook = async ({ origin, handler, request }: SnapRpcHookArgs) => {
return await this._command(workerId, {
id: nanoid(),
jsonrpc: '2.0',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,15 @@ describe('IframeExecutionService', () => {

assert(handler !== undefined);

const result = await handler('foo', HandlerType.onRpcRequest, {
jsonrpc: '2.0',
id: 1,
method: 'foobar',
params: [],
const result = await handler({
origin: 'foo',
handler: HandlerType.onRpcRequest,
request: {
jsonrpc: '2.0',
id: 1,
method: 'foobar',
params: [],
},
});

expect(result).toBe(blockNumber);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,15 @@ describe('NodeProcessExecutionService', () => {
assert(hook !== undefined);

await expect(
hook('fooOrigin', ON_RPC_REQUEST, {
jsonrpc: '2.0',
method: 'foo',
params: {},
id: 1,
hook({
origin: 'fooOrigin',
handler: ON_RPC_REQUEST,
request: {
jsonrpc: '2.0',
method: 'foo',
params: {},
id: 1,
},
}),
).rejects.toThrow('foobar');
await service.terminateAllSnaps();
Expand Down Expand Up @@ -190,11 +194,15 @@ describe('NodeProcessExecutionService', () => {
});

expect(
await hook('fooOrigin', ON_RPC_REQUEST, {
jsonrpc: '2.0',
method: '',
params: {},
id: 1,
await hook({
origin: 'fooOrigin',
handler: ON_RPC_REQUEST,
request: {
jsonrpc: '2.0',
method: '',
params: {},
id: 1,
},
}),
).toBe('foo');

Expand Down Expand Up @@ -261,11 +269,15 @@ describe('NodeProcessExecutionService', () => {

assert(handler !== undefined);

const result = await handler('foo', ON_RPC_REQUEST, {
jsonrpc: '2.0',
id: 1,
method: 'foobar',
params: [],
const result = await handler({
origin: 'foo',
handler: ON_RPC_REQUEST,
request: {
jsonrpc: '2.0',
id: 1,
method: 'foobar',
params: [],
},
});

expect(result).toBe(blockNumber);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,15 @@ describe('NodeThreadExecutionService', () => {
assert(hook !== undefined);

await expect(
hook('fooOrigin', ON_RPC_REQUEST, {
jsonrpc: '2.0',
method: 'foo',
params: {},
id: 1,
hook({
origin: 'fooOrigin',
handler: ON_RPC_REQUEST,
request: {
jsonrpc: '2.0',
method: 'foo',
params: {},
id: 1,
},
}),
).rejects.toThrow('foobar');
await service.terminateAllSnaps();
Expand Down Expand Up @@ -190,11 +194,15 @@ describe('NodeThreadExecutionService', () => {
});

expect(
await hook('fooOrigin', ON_RPC_REQUEST, {
jsonrpc: '2.0',
method: '',
params: {},
id: 1,
await hook({
origin: 'fooOrigin',
handler: ON_RPC_REQUEST,
request: {
jsonrpc: '2.0',
method: '',
params: {},
id: 1,
},
}),
).toBe('foo');

Expand Down Expand Up @@ -261,11 +269,15 @@ describe('NodeThreadExecutionService', () => {

assert(handler !== undefined);

const result = await handler('foo', ON_RPC_REQUEST, {
jsonrpc: '2.0',
id: 1,
method: 'foobar',
params: [],
const result = await handler({
origin: 'foo',
handler: ON_RPC_REQUEST,
request: {
jsonrpc: '2.0',
id: 1,
method: 'foobar',
params: [],
},
});

expect(result).toBe(blockNumber);
Expand Down
18 changes: 11 additions & 7 deletions packages/controllers/src/snaps/SnapController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ import {
} from '@metamask/snap-utils';
import { HandlerType } from '@metamask/execution-environments';
import { ExecutionService } from '../services/ExecutionService';
import { NodeThreadExecutionService, setupMultiplex } from '../services';
import {
NodeThreadExecutionService,
setupMultiplex,
SnapRpcHookArgs,
} from '../services';
import { delay } from '../utils';

import { SnapEndowments } from './endowments';
Expand Down Expand Up @@ -272,7 +276,7 @@ class ExecutionEnvironmentStub implements ExecutionService {
}

async getRpcRequestHandler(_snapId: string) {
return (_origin: any, _handler: any, request: Record<string, unknown>) => {
return ({ request }: SnapRpcHookArgs) => {
return new Promise((resolve) => {
const results = `${request.method}${request.id}`;
resolve(results);
Expand Down Expand Up @@ -1587,15 +1591,15 @@ describe('SnapController', () => {
});

expect(mockMessageHandler).toHaveBeenCalledTimes(1);
expect(mockMessageHandler).toHaveBeenCalledWith(
'foo.com',
HandlerType.onRpcRequest,
{
expect(mockMessageHandler).toHaveBeenCalledWith({
origin: 'foo.com',
handler: HandlerType.onRpcRequest,
request: {
id: 1,
method: 'bar',
jsonrpc: '2.0',
},
);
});
await service.terminateAllSnaps();
});

Expand Down
67 changes: 28 additions & 39 deletions packages/controllers/src/snaps/SnapController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ import {
ExecuteSnapAction,
ExecutionServiceEvents,
GetRpcRequestHandlerAction,
SnapRpcHook,
SnapRpcHookArgs,
TerminateAllSnapsAction,
TerminateSnapAction,
} from '..';
Expand Down Expand Up @@ -188,13 +190,7 @@ export interface SnapRuntimeData {
/**
* RPC handler designated for the Snap
*/
rpcHandler:
| null
| ((
origin: string,
handlerName: HandlerType,
request: Record<string, unknown>,
) => Promise<unknown>);
rpcHandler: null | SnapRpcHook;
}

/**
Expand Down Expand Up @@ -2058,12 +2054,12 @@ export class SnapController extends BaseController<
origin: string,
request: Record<string, unknown>,
): Promise<unknown> {
return this.handleRequest(
return this.handleRequest({
snapId,
origin,
HandlerType.onRpcRequest,
handler: HandlerType.onRpcRequest,
request,
);
});
}

/**
Expand All @@ -2079,36 +2075,37 @@ export class SnapController extends BaseController<
origin: string,
request: Record<string, unknown>,
): Promise<unknown> {
return this.handleRequest(
return this.handleRequest({
snapId,
origin,
HandlerType.getTransactionInsight,
handler: HandlerType.getTransactionInsight,
request,
);
});
}

/**
* Passes a JSON-RPC request object to the RPC handler function of a snap.
*
* @param snapId - The ID of the recipient snap.
* @param origin - The origin of the RPC request.
* @param handlerName - The handler to trigger on the snap for the request.
* @param request - The JSON-RPC request object.
* @param options - A bag of options.
* @param options.snapId - The ID of the recipient snap.
* @param options.origin - The origin of the RPC request.
* @param options.handler - The handler to trigger on the snap for the request.
* @param options.request - The JSON-RPC request object.
* @returns The result of the JSON-RPC request.
*/
private async handleRequest(
snapId: SnapId,
origin: string,
handlerName: HandlerType,
request: Record<string, unknown>,
): Promise<unknown> {
private async handleRequest({
snapId,
origin,
handler: handlerType,
request,
}: SnapRpcHookArgs & { snapId: SnapId }): Promise<unknown> {
const handler = await this.getRpcRequestHandler(snapId);
if (!handler) {
throw new Error(
`Snap RPC message handler not found for snap "${snapId}".`,
);
}
return handler(origin, handlerName, request);
return handler({ origin, handler: handlerType, request });
}

/**
Expand All @@ -2117,15 +2114,7 @@ export class SnapController extends BaseController<
* @param snapId - The id of the Snap whose message handler to get.
* @returns The RPC handler for the given snap.
*/
private async getRpcRequestHandler(
snapId: SnapId,
): Promise<
(
origin: string,
handler: HandlerType,
request: Record<string, unknown>,
) => Promise<unknown>
> {
private async getRpcRequestHandler(snapId: SnapId): Promise<SnapRpcHook> {
const runtime = this._getSnapRuntimeData(snapId);
const existingHandler = runtime?.rpcHandler;
if (existingHandler) {
Expand All @@ -2137,11 +2126,11 @@ export class SnapController extends BaseController<
// because otherwise we would lose context on the correct startPromise.
const startPromises = new Map<string, Promise<void>>();

const rpcHandler = async (
origin: string,
handlerName: HandlerType,
request: Record<string, unknown>,
) => {
const rpcHandler = async ({
origin,
handler: handlerType,
request,
}: SnapRpcHookArgs) => {
if (this.state.snaps[snapId].enabled === false) {
throw new Error(`Snap "${snapId}" is disabled.`);
}
Expand Down Expand Up @@ -2214,7 +2203,7 @@ export class SnapController extends BaseController<
try {
const result = await this._executeWithTimeout(
snapId,
handler(origin, handlerName, _request),
handler({ origin, handler: handlerType, request: _request }),
timer,
);
this._recordSnapRpcRequestFinish(snapId, request.id);
Expand Down

0 comments on commit eaa5c25

Please sign in to comment.