-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
middleware.ts
375 lines (352 loc) · 10.9 KB
/
middleware.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
// Copyright IBM Corp. 2020. All Rights Reserved.
// Node module: @loopback/express
// This file is licensed under the MIT License.
// License text available at https://opensource.org/licenses/MIT
import {
Binding,
BindingKey,
BindingScope,
BindingTemplate,
compareBindingsByTag,
Constructor,
Context,
ContextView,
createBindingFromClass,
extensionFilter,
extensionFor,
InvocationResult,
isProviderClass,
Provider,
transformValueOrPromise,
ValueOrPromise,
} from '@loopback/core';
import debugFactory from 'debug';
import {sortListOfGroups} from './group-sorter';
import {DEFAULT_MIDDLEWARE_GROUP, MIDDLEWARE_NAMESPACE} from './keys';
import {
createInterceptor,
defineInterceptorProvider,
toInterceptor,
} from './middleware-interceptor';
import {
DEFAULT_MIDDLEWARE_CHAIN,
ExpressMiddlewareFactory,
ExpressRequestHandler,
InvokeMiddlewareOptions,
Middleware,
MiddlewareBindingOptions,
MiddlewareChain,
MiddlewareContext,
} from './types';
// Namespaced `debug` logger ('loopback:middleware') used by every function in this module.
const debug = debugFactory('loopback:middleware');
/**
 * An adapter function to create a LoopBack middleware that invokes the list
 * of Express middleware handler functions in the order of their positions.
 *
 * @example
 * ```ts
 * toMiddleware(fn);
 * toMiddleware(fn1, fn2, fn3);
 * ```
 * @param firstHandler - An Express middleware handler
 * @param additionalHandlers - A list of Express middleware handler functions
 * @returns A LoopBack middleware function that wraps the list of Express
 * middleware
 */
export function toMiddleware(
  firstHandler: ExpressRequestHandler,
  ...additionalHandlers: ExpressRequestHandler[]
): Middleware {
  // A single handler is adapted directly; no chain is needed.
  if (additionalHandlers.length === 0) return toInterceptor(firstHandler);

  // Handlers are adapted once, here, and the adapted list is reused by every
  // invocation of the returned middleware.
  const middlewareList = [firstHandler, ...additionalHandlers].map(handler =>
    toInterceptor<MiddlewareContext>(handler),
  );

  // At this point there are always at least two handlers, so each call runs
  // them through a fresh chain bound to that call's middleware context.
  return (middlewareCtx, next) =>
    new MiddlewareChain(middlewareCtx, middlewareList).invokeInterceptors(next);
}
/**
 * Build a LoopBack middleware from an Express middleware factory and an
 * optional configuration object for that factory.
 *
 * @param middlewareFactory - Express middleware factory function
 * @param middlewareConfig - Configuration passed to the factory
 *
 * @returns A LoopBack middleware function wrapping the Express middleware
 * produced by the factory
 */
export function createMiddleware<CFG>(
  middlewareFactory: ExpressMiddlewareFactory<CFG>,
  middlewareConfig?: CFG,
): Middleware {
  const wrapped: Middleware = createInterceptor<CFG, MiddlewareContext>(
    middlewareFactory,
    middlewareConfig,
  );
  return wrapped;
}
/**
 * Register an Express middleware, produced by a factory, with the given
 * context.
 *
 * `options.injectConfiguration` defaults to `true`, and `options.chain`
 * defaults to `DEFAULT_MIDDLEWARE_CHAIN`. The defaults are applied to a copy,
 * so the caller's `options` object is left unchanged.
 *
 * - When configuration injection is enabled, a provider class is defined for
 *   the factory and registered.
 * - Otherwise the middleware function is created eagerly from
 *   `middlewareConfig` and registered as-is.
 *
 * @param ctx - Context object
 * @param middlewareFactory - Middleware module name or factory function
 * @param middlewareConfig - Middleware config
 * @param options - Options for registration
 *
 * @typeParam CFG - Configuration type
 */
export function registerExpressMiddleware<CFG>(
  ctx: Context,
  middlewareFactory: ExpressMiddlewareFactory<CFG>,
  middlewareConfig?: CFG,
  options: MiddlewareBindingOptions = {},
): Binding<Middleware> {
  const bindingOptions: MiddlewareBindingOptions = {
    injectConfiguration: true,
    ...options,
    chain: options.chain ?? DEFAULT_MIDDLEWARE_CHAIN,
  };

  if (bindingOptions.injectConfiguration) {
    return registerMiddleware(
      ctx,
      defineInterceptorProvider<CFG, MiddlewareContext>(
        middlewareFactory,
        middlewareConfig,
        bindingOptions,
      ),
      bindingOptions,
    );
  }

  return registerMiddleware(
    ctx,
    createMiddleware(middlewareFactory, middlewareConfig),
    bindingOptions,
  );
}
/**
 * Create a binding template that marks a binding as middleware.
 *
 * The template:
 * - registers the binding as an extension of `options.chain` (default
 *   `DEFAULT_MIDDLEWARE_CHAIN`);
 * - tags the binding's `group` (default `DEFAULT_MIDDLEWARE_GROUP`) unless it
 *   already has one;
 * - records `upstreamGroups` / `downstreamGroups` tags when provided, turning a
 *   single string into a one-element array.
 *
 * @param options - Options to configure the binding
 */
export function asMiddleware(
  options: MiddlewareBindingOptions = {},
): BindingTemplate {
  return function middlewareBindingTemplate(binding) {
    // Options are read each time the template is applied, not captured up front.
    const {chain, group, upstreamGroups, downstreamGroups} = options;

    binding.apply(extensionFor(chain ?? DEFAULT_MIDDLEWARE_CHAIN));

    // Keep any group tag that is already present on the binding.
    if (!binding.tagMap.group) {
      binding.tag({group: group ?? DEFAULT_MIDDLEWARE_GROUP});
    }

    if (upstreamGroups != null) {
      const upstreamList =
        typeof upstreamGroups === 'string' ? [upstreamGroups] : upstreamGroups;
      binding.tag({upstreamGroups: upstreamList});
    }

    if (downstreamGroups != null) {
      const downstreamList =
        typeof downstreamGroups === 'string'
          ? [downstreamGroups]
          : downstreamGroups;
      binding.tag({downstreamGroups: downstreamList});
    }
  };
}
/**
 * Add a middleware to the context.
 *
 * - A provider class is turned into a binding via `createMiddlewareBinding`,
 *   and that binding is added to `ctx`.
 * - A plain middleware function is bound under `options.key`, or under a key
 *   generated in `MIDDLEWARE_NAMESPACE` when no key is given. It is then
 *   configured with `asMiddleware(options)`.
 *
 * @param ctx - Context object
 * @param middleware - Middleware function or provider class
 * @param options - Middleware binding options
 * @returns The binding that was created
 */
export function registerMiddleware(
  ctx: Context,
  middleware: Middleware | Constructor<Provider<Middleware>>,
  options: MiddlewareBindingOptions,
) {
  const providerClass = middleware as Constructor<Provider<Middleware>>;
  if (!isProviderClass(providerClass)) {
    const bindingKey = options.key ?? BindingKey.generate(MIDDLEWARE_NAMESPACE);
    return ctx
      .bind(bindingKey)
      .to(middleware as Middleware)
      .apply(asMiddleware(options));
  }

  const providerBinding = createMiddlewareBinding(providerClass, options);
  ctx.add(providerBinding);
  return providerBinding;
}
/**
 * Create a binding for the middleware provider class.
 *
 * The binding uses `BindingScope.TRANSIENT` as its default scope and
 * `MIDDLEWARE_NAMESPACE` as its namespace, and it is keyed by `options.key`
 * when one is given. `options.chain` defaults to `DEFAULT_MIDDLEWARE_CHAIN`.
 * That default is applied to a local copy, so the caller's `options` object is
 * not mutated.
 *
 * @param middlewareProviderClass - Middleware provider class
 * @param options - Options to create middleware binding
 * @returns The (not yet registered) middleware binding
 */
export function createMiddlewareBinding(
  middlewareProviderClass: Constructor<Provider<Middleware>>,
  options: MiddlewareBindingOptions = {},
) {
  // Copy instead of assigning to `options.chain`, which would leak the
  // default back into an object owned by the caller.
  const bindingOptions: MiddlewareBindingOptions = {
    ...options,
    chain: options.chain ?? DEFAULT_MIDDLEWARE_CHAIN,
  };
  const binding = createBindingFromClass(middlewareProviderClass, {
    defaultScope: BindingScope.TRANSIENT,
    namespace: MIDDLEWARE_NAMESPACE,
    key: bindingOptions.key,
  }).apply(asMiddleware(bindingOptions));
  return binding;
}
/**
 * Discover and invoke registered middleware in a chain for the given extension
 * point.
 *
 * If `options.middlewareList` is supplied, it is used as the chain.
 * Otherwise, the middleware registered for the chain is looked up once through
 * a short-lived `MiddlewareView`, already sorted by group.
 *
 * @param middlewareCtx - Middleware context
 * @param options - Options to invoke the middleware chain
 */
export function invokeMiddleware(
  middlewareCtx: MiddlewareContext,
  options?: InvokeMiddlewareOptions,
): ValueOrPromise<InvocationResult> {
  const {request} = middlewareCtx;
  debug(
    'Invoke middleware chain for %s %s with options',
    request.method,
    request.originalUrl,
    options,
  );
  const chainKeys =
    options?.middlewareList ?? snapshotMiddlewareKeys(middlewareCtx, options);
  return new MiddlewareChain(middlewareCtx, chainKeys).invokeInterceptors(
    options?.next,
  );
}

/**
 * Open a `MiddlewareView`, read its sorted binding keys, and close it again.
 */
function snapshotMiddlewareKeys(
  middlewareCtx: MiddlewareContext,
  options?: InvokeMiddlewareOptions,
) {
  const view = new MiddlewareView(middlewareCtx, options);
  const sortedKeys = view.middlewareBindingKeys;
  view.close();
  return sortedKeys;
}
/**
 * Watch middleware binding keys for the given context and sort them by
 * group.
 *
 * The view selects the extensions of `options.chain` (default
 * `DEFAULT_MIDDLEWARE_CHAIN`). It computes the sorted key list at
 * construction and starts watching right away (`open()`). The list is
 * recomputed on every `refresh()`. Call `close()` to stop watching.
 *
 * @param ctx - Context object
 * @param options - Middleware options
 */
export class MiddlewareView extends ContextView {
  private options: InvokeMiddlewareOptions;
  private keys: string[];

  constructor(ctx: Context, options?: InvokeMiddlewareOptions) {
    // Find extensions for the given extension point binding
    const filter = extensionFilter(options?.chain ?? DEFAULT_MIDDLEWARE_CHAIN);
    super(ctx, filter);
    // Apply defaults after the spread so that an explicit `undefined` in
    // `options` cannot override them.
    this.options = {
      ...options,
      chain: options?.chain ?? DEFAULT_MIDDLEWARE_CHAIN,
      orderedGroups: options?.orderedGroups ?? [],
    };
    this.buildMiddlewareKeys();
    this.open();
  }

  refresh() {
    super.refresh();
    this.buildMiddlewareKeys();
  }

  /**
   * A list of binding keys sorted by group for registered middleware
   */
  get middlewareBindingKeys() {
    return this.keys;
  }

  /**
   * Recompute `this.keys`: collect group ordering constraints from the
   * bindings' `upstreamGroups`/`downstreamGroups` tags, merge them with
   * `orderedGroups`, optionally validate the result, and sort the binding keys
   * by their `group` tag.
   */
  private buildMiddlewareKeys() {
    const middlewareBindings = this.bindings;
    if (debug.enabled) {
      debug(
        'Middleware for extension point "%s":',
        this.options.chain,
        middlewareBindings.map(b => b.key),
      );
    }

    // Calculate orders from middleware dependencies:
    // [upstream, group] and [group, downstream] pairs.
    const ordersFromDependencies: string[][] = [];
    middlewareBindings.forEach(b => {
      const group: string = b.tagMap.group ?? DEFAULT_MIDDLEWARE_GROUP;
      const groupsBefore: string[] = b.tagMap.upstreamGroups ?? [];
      groupsBefore.forEach(d => ordersFromDependencies.push([d, group]));
      const groupsAfter: string[] = b.tagMap.downstreamGroups ?? [];
      groupsAfter.forEach(d => ordersFromDependencies.push([group, d]));
    });
    const order = sortListOfGroups(
      ...ordersFromDependencies,
      this.options.orderedGroups ?? [],
    );

    // Let callers reject an unacceptable group order
    if (typeof this.options.validate === 'function') {
      this.options.validate(order);
    }

    // Sort a copy: `this.bindings` is the view's cached binding list and must
    // not be reordered in place.
    this.keys = [...middlewareBindings]
      .sort(compareBindingsByTag('group', order))
      .map(b => b.key);
  }
}
/**
 * Invoke a list of Express middleware handler functions
 *
 * @example
 * ```ts
 * import cors from 'cors';
 * import helmet from 'helmet';
 * import morgan from 'morgan';
 * import {MiddlewareContext, invokeExpressMiddleware} from '@loopback/express';
 *
 * // ... Either an instance of `MiddlewareContext` is passed in or a new one
 * // can be instantiated from Express request and response objects
 *
 * const middlewareCtx = new MiddlewareContext(request, response);
 * const finished = await invokeExpressMiddleware(
 *   middlewareCtx,
 *   cors(),
 *   helmet(),
 *   morgan('combined'));
 *
 * if (finished) {
 *   // Http response is sent by one of the middleware
 * } else {
 *   // Http response is yet to be produced
 * }
 * ```
 * @param middlewareCtx - Middleware context
 * @param handlers - A list of Express middleware handler functions
 * @returns `true` when the middleware result is the response object (the
 * response was produced), `false` otherwise
 * @throws Error when no handler is supplied
 */
export function invokeExpressMiddleware(
  middlewareCtx: MiddlewareContext,
  ...handlers: ExpressRequestHandler[]
): ValueOrPromise<boolean> {
  if (!handlers.length) {
    throw new Error('No Express middleware handler function is provided.');
  }
  const [head, ...tail] = handlers;
  const combined = toMiddleware(head, ...tail);

  const {method, originalUrl} = middlewareCtx.request;
  debug('Invoke Express middleware for %s %s', method, originalUrl);

  // Nothing runs after this chain, so `next` is a no-op.
  const noopNext = () => undefined;
  const outcome = combined(middlewareCtx, noopNext);
  return transformValueOrPromise(
    outcome,
    value => value === middlewareCtx.response,
  );
}
/**
 * An adapter function to create an Express middleware handler to discover and
 * invoke registered LoopBack-style middleware in the context.
 *
 * For each request, a `MiddlewareContext` is created from `req`, `res` and
 * `ctx`, and the registered middleware chain is invoked through it. If the
 * chain's result is the response object itself, the response is treated as
 * produced and `next()` is not called. Otherwise Express continues with
 * `next()`. Any error, including one thrown synchronously, is passed to
 * `next(err)`.
 *
 * @param ctx - Context object to discover registered middleware
 */
export function toExpressMiddleware(ctx: Context): ExpressRequestHandler {
  return (req, res, next) => {
    const middlewareCtx = new MiddlewareContext(req, res, ctx);
    // An async function already yields a promise that rejects on sync or
    // async failure. Wrapping it in `new Promise` is unnecessary.
    const handle = async () => {
      const result = await invokeMiddleware(middlewareCtx);
      if (result !== res) next();
    };
    handle().catch(next);
  };
}