diff --git a/lib/Server.js b/lib/Server.js index 5538f6f454..5cdeccffaf 100644 --- a/lib/Server.js +++ b/lib/Server.js @@ -685,7 +685,14 @@ class Server { return; } - if (headers && (!this.checkHost(headers) || !this.checkOrigin(headers))) { + if (!headers) { + this.log.warn( + 'serverMode implementation must pass headers to the callback of onConnection(f) ' + + 'via f(connection, headers) in order for clients to pass a headers security check' + ); + } + + if (!headers || !this.checkHost(headers) || !this.checkOrigin(headers)) { this.sockWrite([connection], 'error', 'Invalid Host/Origin header'); this.socketServer.close(connection); diff --git a/lib/servers/SockJSServer.js b/lib/servers/SockJSServer.js index 49c42538d2..464bc3cbd1 100644 --- a/lib/servers/SockJSServer.js +++ b/lib/servers/SockJSServer.js @@ -57,7 +57,7 @@ module.exports = class SockJSServer extends BaseServer { connection.close(); } - // f should return the resulting connection and, optionally, the connection headers + // f should be passed the resulting connection and the connection headers onConnection(f) { this.socket.on('connection', (connection) => { f(connection, connection.headers); diff --git a/lib/servers/WebsocketServer.js b/lib/servers/WebsocketServer.js index 03c5a56e46..2e237d2856 100644 --- a/lib/servers/WebsocketServer.js +++ b/lib/servers/WebsocketServer.js @@ -27,7 +27,7 @@ module.exports = class WebsocketServer extends BaseServer { connection.close(); } - // f should return the resulting connection + // f should be passed the resulting connection and the connection headers onConnection(f) { this.wsServer.on('connection', (connection, req) => { f(connection, req.headers); diff --git a/test/server/__snapshots__/serverMode-option.test.js.snap b/test/server/__snapshots__/serverMode-option.test.js.snap index a7ad43ca4e..ef6ab793a5 100644 --- a/test/server/__snapshots__/serverMode-option.test.js.snap +++ b/test/server/__snapshots__/serverMode-option.test.js.snap @@ -7,3 +7,11 @@ Array [ "close", ] `; + +exports[`serverMode option without a header results in an error 1`] = ` +Array [ + "open", + "{\\"type\\":\\"error\\",\\"data\\":\\"Invalid Host/Origin header\\"}", + "close", +] +`; diff --git a/test/server/serverMode-option.test.js b/test/server/serverMode-option.test.js index 8a660ddab6..6f64331b2b 100644 --- a/test/server/serverMode-option.test.js +++ b/test/server/serverMode-option.test.js @@ -223,4 +223,95 @@ describe('serverMode option', () => { }, 5000); }); }); + + describe('without a header', () => { + let mockWarn; + beforeAll((done) => { + server = testServer.start( + config, + { + port, + serverMode: class MySockJSServer extends BaseServer { + constructor(serv) { + super(serv); + this.socket = sockjs.createServer({ + // Use provided up-to-date sockjs-client + sockjs_url: '/__webpack_dev_server__/sockjs.bundle.js', + // Limit useless logs + log: (severity, line) => { + if (severity === 'error') { + this.server.log.error(line); + } else { + this.server.log.debug(line); + } + }, + }); + + this.socket.installHandlers(this.server.listeningApp, { + prefix: this.server.sockPath, + }); + } + + send(connection, message) { + connection.write(message); + } + + close(connection) { + connection.close(); + } + + onConnection(f) { + this.socket.on('connection', (connection) => { + f(connection); + }); + } + + onConnectionClose(connection, f) { + connection.on('close', f); + } + }, + }, + done + ); + + mockWarn = jest.spyOn(server.log, 'warn').mockImplementation(() => {}); + }); + + it('results in an error', (done) => { + const data = []; + const client = new SockJS(`http://localhost:${port}/sockjs-node`); + + client.onopen = () => { + data.push('open'); + }; + + client.onmessage = (e) => { + data.push(e.data); + }; + + client.onclose = () => { + data.push('close'); + }; + + setTimeout(() => { + expect(data).toMatchSnapshot(); + const calls = mockWarn.mock.calls; + mockWarn.mockRestore(); + + let foundWarning = false; + const regExp = /serverMode implementation must pass headers to the callback of onConnection\(f\)/; + calls.every((call) => { + if (regExp.test(call)) { + foundWarning = true; + return false; + } + return true; + }); + + expect(foundWarning).toBeTruthy(); + + done(); + }, 5000); + }); + }); });