Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(common): Make ValidationPipe aware of WebSocket context #13255

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
103 changes: 103 additions & 0 deletions integration/websockets/e2e/gateway-validation-pipe.spec.ts
@@ -0,0 +1,103 @@
import { INestApplication } from '@nestjs/common';
import { WsAdapter } from '@nestjs/platform-ws';
import { Test } from '@nestjs/testing';
import * as WebSocket from 'ws';
import { ValidationPipeGateway } from '../src/validation-pipe.gateway';
import { expect } from 'chai';
import { ApplicationGateway } from '../src/app.gateway';

async function createNestApp(...gateways): Promise<INestApplication> {
const testingModule = await Test.createTestingModule({
providers: gateways,
}).compile();
const app = testingModule.createNestApplication();
app.useWebSocketAdapter(new WsAdapter(app) as any);
return app;
}

const testBody = { ws: null, app: null };

async function prepareGatewayAndClientForResponseAction(
gateway: typeof ValidationPipeGateway | ApplicationGateway,
action: () => void,
) {
testBody.app = await createNestApp(gateway);
await testBody.app.listen(3000);

testBody.ws = new WebSocket('ws://localhost:8080');
await new Promise(resolve => testBody.ws.on('open', resolve));

testBody.ws.send(
JSON.stringify({
event: 'push',
data: {
stringProp: 123,
},
}),
);

action();
}

const UNCAUGHT_EXCEPTION = 'uncaughtException';

type WsExceptionWithWrappedValidationError = {
getError: () => {
response: {
message: string[];
};
};
};

function prepareToHandleExpectedUncaughtException() {
const listeners = process.listeners(UNCAUGHT_EXCEPTION);
process.removeAllListeners(UNCAUGHT_EXCEPTION);

process.on(
UNCAUGHT_EXCEPTION,
(err: WsExceptionWithWrappedValidationError) => {
expect(err.getError().response.message[0]).to.equal(
'stringProp must be a string',
);
reattachUncaughtExceptionListeners(listeners);
},
);
}

function reattachUncaughtExceptionListeners(
listeners: NodeJS.UncaughtExceptionListener[],
) {
process.removeAllListeners(UNCAUGHT_EXCEPTION);
for (const listener of listeners) {
process.on(UNCAUGHT_EXCEPTION, listener);
}
}

describe('WebSocketGateway with ValidationPipe', () => {
it(`should throw WsException`, async () => {
prepareToHandleExpectedUncaughtException();

await prepareGatewayAndClientForResponseAction(
ValidationPipeGateway,
() => {
testBody.ws.once('message', () => {});
},
);
});

it('should return message normally', async () => {
await new Promise<void>(resolve =>
prepareGatewayAndClientForResponseAction(ApplicationGateway, async () => {
testBody.ws.once('message', msg => {
expect(JSON.parse(msg).data.stringProp).to.equal(123);
resolve();
});
}),
);
});

afterEach(function (done) {
testBody.ws.close();
testBody.app.close().then(() => done());
});
});
40 changes: 40 additions & 0 deletions integration/websockets/src/validation-pipe.gateway.ts
@@ -0,0 +1,40 @@
import {
ArgumentsHost,
Catch,
UseFilters,
UsePipes,
ValidationPipe,
} from '@nestjs/common';
import {
BaseWsExceptionFilter,
MessageBody,
SubscribeMessage,
WebSocketGateway,
} from '@nestjs/websockets';
import { IsString } from 'class-validator';

class TestModel {
@IsString()
stringProp: string;
}

@Catch()
export class AllExceptionsFilter extends BaseWsExceptionFilter {
catch(exception: unknown, host: ArgumentsHost) {
throw exception;
}
}

@WebSocketGateway(8080)
@UsePipes(new ValidationPipe())
@UseFilters(new AllExceptionsFilter())
export class ValidationPipeGateway {
@SubscribeMessage('push')
onPush(@MessageBody() data: TestModel) {
console.log('received msg');
return {
event: 'push',
data,
};
}
}
1 change: 1 addition & 0 deletions packages/common/constants.ts
Expand Up @@ -45,3 +45,4 @@ export const INJECTABLE_WATERMARK = '__injectable__';
export const CONTROLLER_WATERMARK = '__controller__';
export const CATCH_WATERMARK = '__catch__';
export const ENTRY_PROVIDER_WATERMARK = '__entryProvider__';
export const GATEWAY_METADATA = 'websockets:is_gateway';
7 changes: 7 additions & 0 deletions packages/common/decorators/core/use-pipes.decorator.ts
Expand Up @@ -3,6 +3,7 @@ import { PipeTransform } from '../../interfaces/index';
import { extendArrayMetadata } from '../../utils/extend-metadata.util';
import { isFunction } from '../../utils/shared.utils';
import { validateEach } from '../../utils/validate-each.util';
import { isTargetAware } from '../../interfaces/features/target-aware-pipe.interface';

/**
* Decorator that binds pipes to the scope of the controller or method,
Expand Down Expand Up @@ -43,6 +44,12 @@ export function UsePipes(
return descriptor;
}
validateEach(target, pipes, isPipeValid, '@UsePipes', 'pipe');

const pipesWithSetTarget = pipes.filter(pipe => isTargetAware(pipe));
pipesWithSetTarget.forEach(pipeWithSetTarget =>
pipeWithSetTarget['setTarget'](target),
);

extendArrayMetadata(PIPES_METADATA, pipes, target);
return target;
};
Expand Down
1 change: 1 addition & 0 deletions packages/common/exceptions/index.ts
Expand Up @@ -20,3 +20,4 @@ export * from './gateway-timeout.exception';
export * from './im-a-teapot.exception';
export * from './precondition-failed.exception';
export * from './misdirected.exception';
export * from './ws-exception';
27 changes: 27 additions & 0 deletions packages/common/exceptions/ws-exception.ts
@@ -0,0 +1,27 @@
import { isObject, isString } from '../utils/shared.utils';

export class WsException extends Error {
constructor(private readonly error: string | object) {
super();
this.initMessage();
}

public initMessage() {
if (isString(this.error)) {
this.message = this.error;
} else if (
isObject(this.error) &&
isString((this.error as Record<string, any>).message)
) {
this.message = (this.error as Record<string, any>).message;
} else if (this.constructor) {
this.message = this.constructor.name
.match(/[A-Z][a-z]+|[0-9]+/g)
.join(' ');
}
}

public getError(): string | object {
return this.error;
}
}
12 changes: 12 additions & 0 deletions packages/common/interfaces/features/target-aware-pipe.interface.ts
@@ -0,0 +1,12 @@
/**
* Interface describing method to set the target of the pipe decorator
*/
export interface TargetAwarePipe {
isTargetAware: true;

setTarget(target: unknown): void;
}

export function isTargetAware(pipe: unknown): pipe is TargetAwarePipe {
return pipe['isTargetAware'];
}
41 changes: 38 additions & 3 deletions packages/common/pipes/validation.pipe.ts
Expand Up @@ -19,6 +19,10 @@ import {
} from '../utils/http-error-by-code.util';
import { loadPackage } from '../utils/load-package.util';
import { isNil, isUndefined } from '../utils/shared.utils';
import { GATEWAY_METADATA } from '../constants';
import { WsException } from '../exceptions';
import { HttpException } from '../exceptions';
import { TargetAwarePipe } from '../interfaces/features/target-aware-pipe.interface';

/**
* @publicApi
Expand All @@ -44,7 +48,7 @@ let classTransformer: TransformerPackage = {} as any;
* @publicApi
*/
@Injectable()
export class ValidationPipe implements PipeTransform<any> {
export class ValidationPipe implements PipeTransform<any>, TargetAwarePipe {
protected isTransformEnabled: boolean;
protected isDetailedOutputDisabled?: boolean;
protected validatorOptions: ValidatorOptions;
Expand All @@ -53,6 +57,9 @@ export class ValidationPipe implements PipeTransform<any> {
protected expectedType: Type<any>;
protected exceptionFactory: (errors: ValidationError[]) => any;
protected validateCustomDecorators: boolean;
protected isInGatewayMode = false;
protected target: unknown;
isTargetAware = true as const;

constructor(@Optional() options?: ValidationPipeOptions) {
options = options || {};
Expand Down Expand Up @@ -82,6 +89,10 @@ export class ValidationPipe implements PipeTransform<any> {
classTransformer = this.loadTransformer(options.transformerPackage);
}

public setTarget(target: unknown) {
this.target = target;
}

protected loadValidator(
validatorPackage?: ValidatorPackage,
): ValidatorPackage {
Expand All @@ -105,6 +116,15 @@ export class ValidationPipe implements PipeTransform<any> {
}

public async transform(value: any, metadata: ArgumentMetadata) {
if (
!this.isInGatewayMode &&
this.target &&
Reflect.getMetadata(GATEWAY_METADATA, this.target)
) {
this.isInGatewayMode = true;
this.exceptionFactory = this.createExceptionFactory();
}

if (this.expectedType) {
metadata = { ...metadata, metatype: this.expectedType };
}
Expand Down Expand Up @@ -165,12 +185,27 @@ export class ValidationPipe implements PipeTransform<any> {
}

public createExceptionFactory() {
let errorConstructorWrapper: (error: unknown) => unknown | WsException = (
error: unknown,
) => error;
if (this.isInGatewayMode) {
errorConstructorWrapper = (error: unknown) => {
if (error instanceof HttpException) {
return new WsException(error);
}
};
}

return (validationErrors: ValidationError[] = []) => {
if (this.isDetailedOutputDisabled) {
return new HttpErrorByCode[this.errorHttpStatusCode]();
return errorConstructorWrapper(
new HttpErrorByCode[this.errorHttpStatusCode](),
);
}
const errors = this.flattenValidationErrors(validationErrors);
return new HttpErrorByCode[this.errorHttpStatusCode](errors);
return errorConstructorWrapper(
new HttpErrorByCode[this.errorHttpStatusCode](errors),
);
};
}

Expand Down
2 changes: 1 addition & 1 deletion packages/websockets/context/ws-context-creator.ts
Expand Up @@ -24,7 +24,7 @@ import {
} from '@nestjs/core/interceptors';
import { PipesConsumer, PipesContextCreator } from '@nestjs/core/pipes';
import { MESSAGE_METADATA, PARAM_ARGS_METADATA } from '../constants';
import { WsException } from '../errors/ws-exception';
import { WsException } from '@nestjs/common';
import { WsParamsFactory } from '../factories/ws-params-factory';
import { ExceptionFiltersContext } from './exception-filters-context';
import { DEFAULT_CALLBACK_METADATA } from './ws-metadata-constants';
Expand Down
2 changes: 1 addition & 1 deletion packages/websockets/errors/index.ts
@@ -1 +1 @@
export * from './ws-exception';
export { WsException } from '@nestjs/common/exceptions/ws-exception';
31 changes: 6 additions & 25 deletions packages/websockets/errors/ws-exception.ts
@@ -1,27 +1,8 @@
import { isObject, isString } from '@nestjs/common/utils/shared.utils';
import { WsException as WSE } from '@nestjs/common/exceptions/ws-exception';

export class WsException extends Error {
constructor(private readonly error: string | object) {
super();
this.initMessage();
}
/**
* @deprecated WsException has been moved to @nestjs/common
*/
const WsException = WSE;

public initMessage() {
if (isString(this.error)) {
this.message = this.error;
} else if (
isObject(this.error) &&
isString((this.error as Record<string, any>).message)
) {
this.message = (this.error as Record<string, any>).message;
} else if (this.constructor) {
this.message = this.constructor.name
.match(/[A-Z][a-z]+|[0-9]+/g)
.join(' ');
}
}

public getError(): string | object {
return this.error;
}
}
export { WsException };
2 changes: 1 addition & 1 deletion packages/websockets/exceptions/base-ws-exception-filter.ts
@@ -1,7 +1,7 @@
import { ArgumentsHost, Logger, WsExceptionFilter } from '@nestjs/common';
import { isObject } from '@nestjs/common/utils/shared.utils';
import { MESSAGES } from '@nestjs/core/constants';
import { WsException } from '../errors/ws-exception';
import { WsException } from '@nestjs/common';

/**
* @publicApi
Expand Down
2 changes: 1 addition & 1 deletion packages/websockets/exceptions/ws-exceptions-handler.ts
Expand Up @@ -3,7 +3,7 @@ import { ArgumentsHost } from '@nestjs/common';
import { ExceptionFilterMetadata } from '@nestjs/common/interfaces/exceptions/exception-filter-metadata.interface';
import { selectExceptionFilterMetadata } from '@nestjs/common/utils/select-exception-filter-metadata.util';
import { InvalidExceptionFilterException } from '@nestjs/core/errors/exceptions/invalid-exception-filter.exception';
import { WsException } from '../errors/ws-exception';
import { WsException } from '@nestjs/common';
import { BaseWsExceptionFilter } from './base-ws-exception-filter';

/**
Expand Down