diff --git a/bench/parser.benchmark.js b/bench/parser.benchmark.js index ce295ccd9..dd97701af 100644 --- a/bench/parser.benchmark.js +++ b/bench/parser.benchmark.js @@ -29,14 +29,14 @@ const pingFrame1 = Buffer.concat( ); const textFrame = Buffer.from('819461616161' + '61'.repeat(20), 'hex'); -const pingFrame2 = Buffer.from('8900', 'hex'); +const pingFrame2 = Buffer.from('8980146e915a', 'hex'); const binaryFrame1 = createBinaryFrame(125); const binaryFrame2 = createBinaryFrame(65535); const binaryFrame3 = createBinaryFrame(200 * 1024); const binaryFrame4 = createBinaryFrame(1024 * 1024); const suite = new benchmark.Suite(); -const receiver = new Receiver(); +const receiver = new Receiver('nodebuffer', {}, true); suite.add('ping frame (5 bytes payload)', { defer: true, diff --git a/lib/receiver.js b/lib/receiver.js index 3a8b92a6b..57daa725d 100644 --- a/lib/receiver.js +++ b/lib/receiver.js @@ -30,14 +30,17 @@ class Receiver extends Writable { * * @param {String} binaryType The type for binary data * @param {Object} extensions An object containing the negotiated extensions + * @param {Boolean} isServer Specifies whether to operate in client or server + * mode * @param {Number} maxPayload The maximum allowed message length */ - constructor(binaryType, extensions, maxPayload) { + constructor(binaryType, extensions, isServer, maxPayload) { super(); this._binaryType = binaryType || BINARY_TYPES[0]; this[kWebSocket] = undefined; this._extensions = extensions || {}; + this._isServer = !!isServer; this._maxPayload = maxPayload | 0; this._bufferedBytes = 0; @@ -225,6 +228,16 @@ class Receiver extends Writable { if (!this._fin && !this._fragmented) this._fragmented = this._opcode; this._masked = (buf[1] & 0x80) === 0x80; + if (this._isServer) { + if (!this._masked) { + this._loop = false; + return error(RangeError, 'MASK must be set', true, 1002); + } + } else if (this._masked) { + this._loop = false; + return error(RangeError, 'MASK must be clear', true, 1002); + } + if (this._payloadLength === 126) this._state = GET_PAYLOAD_LENGTH_16; else if (this._payloadLength === 127) this._state = GET_PAYLOAD_LENGTH_64; else return this.haveLength(); diff --git a/lib/websocket.js b/lib/websocket.js index cd8929bdc..c157cfd2c 100644 --- a/lib/websocket.js +++ b/lib/websocket.js @@ -141,6 +141,7 @@ class WebSocket extends EventEmitter { const receiver = new Receiver( this._binaryType, this._extensions, + this._isServer, maxPayload ); diff --git a/test/receiver.test.js b/test/receiver.test.js index 7d00c76cc..a70cc8dbe 100644 --- a/test/receiver.test.js +++ b/test/receiver.test.js @@ -48,7 +48,7 @@ describe('Receiver', () => { }); it('parses a masked text message', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); receiver.on('message', (data) => { assert.strictEqual(data, '5:::{"name":"echo"}'); @@ -61,7 +61,7 @@ describe('Receiver', () => { }); it('parses a masked text message longer than 125 B', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = 'A'.repeat(200); const list = Sender.frame(Buffer.from(msg), { @@ -84,7 +84,7 @@ describe('Receiver', () => { }); it('parses a really long masked text message', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = 'A'.repeat(64 * 1024); const list = Sender.frame(Buffer.from(msg), { @@ -106,7 +106,7 @@ describe('Receiver', () => { }); it('parses a 300 B fragmented masked text message', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = 'A'.repeat(300); const fragment1 = msg.substr(0, 150); @@ -139,7 +139,7 @@ describe('Receiver', () => { }); it('parses a ping message', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = 'Hello'; const list = Sender.frame(Buffer.from(msg), { @@ -172,7 +172,7 @@ describe('Receiver', () => { }); it('parses a 300 B fragmented masked text message with a ping in the middle (1/2)', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = 'A'.repeat(300); const pingMessage = 'Hello'; @@ -221,7 +221,7 @@ describe('Receiver', () => { }); it('parses a 300 B fragmented masked text message with a ping in the middle (2/2)', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = 'A'.repeat(300); const pingMessage = 'Hello'; @@ -280,7 +280,7 @@ describe('Receiver', () => { }); it('parses a 100 B masked binary message', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = crypto.randomBytes(100); const list = Sender.frame(msg, { @@ -302,7 +302,7 @@ describe('Receiver', () => { }); it('parses a 256 B masked binary message', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = crypto.randomBytes(256); const list = Sender.frame(msg, { @@ -324,7 +324,7 @@ describe('Receiver', () => { }); it('parses a 200 KiB masked binary message', (done) => { - const receiver = new Receiver(); + const receiver = new Receiver(undefined, {}, true); const msg = crypto.randomBytes(200 * 1024); const list = Sender.frame(msg, { @@ -439,7 +439,7 @@ describe('Receiver', () => { }); it('resets `totalPayloadLength` only on final frame (unfragmented)', (done) => { - const receiver = new Receiver(undefined, {}, 10); + const receiver = new Receiver(undefined, {}, false, 10); receiver.on('message', (data) => { assert.strictEqual(receiver._totalPayloadLength, 0); @@ -452,7 +452,7 @@ describe('Receiver', () => { }); it('resets `totalPayloadLength` only on final frame (fragmented)', (done) => { - const receiver = new Receiver(undefined, {}, 10); + const receiver = new Receiver(undefined, {}, false, 10); receiver.on('message', (data) => { assert.strictEqual(receiver._totalPayloadLength, 0); @@ -467,7 +467,7 @@ describe('Receiver', () => { }); it('resets `totalPayloadLength` only on final frame (fragmented + ping)', (done) => { - const receiver = new Receiver(undefined, {}, 10); + const receiver = new Receiver(undefined, {}, false, 10); let data; receiver.on('ping', (buf) => { @@ -680,6 +680,40 @@ describe('Receiver', () => { receiver.write(Buffer.from([0x09, 0x00])); }); + it('emits an error if a frame has the MASK bit off (server mode)', (done) => { + const receiver = new Receiver(undefined, {}, true); + + receiver.on('error', (err) => { + assert.ok(err instanceof RangeError); + assert.strictEqual( + err.message, + 'Invalid WebSocket frame: MASK must be set' + ); + assert.strictEqual(err[kStatusCode], 1002); + done(); + }); + + receiver.write(Buffer.from([0x81, 0x02, 0x68, 0x69])); + }); + + it('emits an error if a frame has the MASK bit on (client mode)', (done) => { + const receiver = new Receiver(undefined, {}, false); + + receiver.on('error', (err) => { + assert.ok(err instanceof RangeError); + assert.strictEqual( + err.message, + 'Invalid WebSocket frame: MASK must be clear' + ); + assert.strictEqual(err[kStatusCode], 1002); + done(); + }); + + receiver.write( + Buffer.from([0x81, 0x82, 0x56, 0x3a, 0xac, 0x80, 0x3e, 0x53]) + ); + }); + it('emits an error if a control frame has a payload bigger than 125 B', (done) => { const receiver = new Receiver(); @@ -811,7 +845,7 @@ describe('Receiver', () => { }); it('emits an error if a frame payload length is bigger than `maxPayload`', (done) => { - const receiver = new Receiver(undefined, {}, 20 * 1024); + const receiver = new Receiver(undefined, {}, true, 20 * 1024); const msg = crypto.randomBytes(200 * 1024); const list = Sender.frame(msg, { @@ -843,6 +877,7 @@ describe('Receiver', () => { { 'permessage-deflate': perMessageDeflate }, + false, 25 ); const buf = Buffer.from('A'.repeat(50)); @@ -871,6 +906,7 @@ describe('Receiver', () => { { 'permessage-deflate': perMessageDeflate }, + false, 25 ); const buf = Buffer.from('A'.repeat(15)); diff --git a/test/websocket-server.test.js b/test/websocket-server.test.js index b9c460ce4..7503010a2 100644 --- a/test/websocket-server.test.js +++ b/test/websocket-server.test.js @@ -9,6 +9,7 @@ const http = require('http'); const net = require('net'); const fs = require('fs'); +const Sender = require('../lib/sender'); const WebSocket = require('..'); describe('WebSocketServer', () => { @@ -744,7 +745,15 @@ describe('WebSocketServer', () => { } }); - req.write(Buffer.from([0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f])); + const list = Sender.frame(Buffer.from('Hello'), { + fin: true, + rsv1: false, + opcode: 0x01, + mask: true, + readOnly: false + }); + + req.write(Buffer.concat(list)); req.end(); }); diff --git a/test/websocket.test.js b/test/websocket.test.js index b5b265248..229b78715 100644 --- a/test/websocket.test.js +++ b/test/websocket.test.js @@ -1293,7 +1293,7 @@ describe('WebSocket', () => { }); }); - it('can send text data with `mask` option set to `false`', (done) => { + it('honors the `mask` option', (done) => { const wss = new WebSocket.Server({ port: 0 }, () => { const ws = new WebSocket(`ws://localhost:${wss.address().port}`); @@ -1301,30 +1301,29 @@ describe('WebSocket', () => { }); wss.on('connection', (ws) => { - ws.on('message', (message) => { - assert.strictEqual(message, 'hi'); - wss.close(done); - }); - }); - }); - - it('can send binary data with `mask` option set to `false`', (done) => { - const array = new Float32Array(5); + const chunks = []; - for (let i = 0; i < array.length; ++i) { - array[i] = i / 2; - } - - const wss = new WebSocket.Server({ port: 0 }, () => { - const ws = new WebSocket(`ws://localhost:${wss.address().port}`); + ws._socket.prependListener('data', (chunk) => { + chunks.push(chunk); + }); - ws.on('open', () => ws.send(array, { mask: false })); - }); + ws.on('error', (err) => { + assert.ok(err instanceof RangeError); + assert.strictEqual( + err.message, + 'Invalid WebSocket frame: MASK must be set' + ); + assert.ok( + Buffer.concat(chunks) + .slice(0, 2) + .equals(Buffer.from('8102', 'hex')) + ); - wss.on('connection', (ws) => { - ws.on('message', (message) => { - assert.ok(message.equals(Buffer.from(array.buffer))); - wss.close(done); + ws.on('close', (code, reason) => { + assert.strictEqual(code, 1002); + assert.strictEqual(reason, ''); + wss.close(done); + }); }); }); });