From 29fdffc08ebd9ddcaa860e53cf0b850b858e46bf Mon Sep 17 00:00:00 2001 From: Luigi Pinca Date: Sat, 19 Jan 2019 16:29:32 +0100 Subject: [PATCH] [feature] Add ability to follow redirects Fixes #812 --- doc/ws.md | 6 +- lib/websocket.js | 158 ++++++++++++++++++++++++++--------------- test/websocket.test.js | 66 +++++++++++++++++ 3 files changed, 171 insertions(+), 59 deletions(-) diff --git a/doc/ws.md b/doc/ws.md index 0b21101eb..c23f1d6c8 100644 --- a/doc/ws.md +++ b/doc/ws.md @@ -193,8 +193,12 @@ This class represents a WebSocket. It extends the `EventEmitter`. - `address` {String|url.Url|url.URL} The URL to which to connect. - `protocols` {String|Array} The list of subprotocols. - `options` {Object} + - `followRedirects` {Boolean} Whether or not to follow redirects. Defaults to + `false`. - `handshakeTimeout` {Number} Timeout in milliseconds for the handshake - request. + request. This is reset after every redirection. + - `maxRedirects` {Number} The maximum number of redirects allowed. Defaults + to 10. - `perMessageDeflate` {Boolean|Object} Enable/disable permessage-deflate. - `protocolVersion` {Number} Value of the `Sec-WebSocket-Version` header. - `origin` {String} Value of the `Origin` or `Sec-WebSocket-Origin` header diff --git a/lib/websocket.js b/lib/websocket.js index 1d7c4468b..5bd4d4b52 100644 --- a/lib/websocket.js +++ b/lib/websocket.js @@ -46,12 +46,14 @@ class WebSocket extends EventEmitter { this._closeTimer = null; this._closeCode = 1006; this._extensions = {}; - this._isServer = true; this._receiver = null; this._sender = null; this._socket = null; if (address !== null) { + this._isServer = false; + this._redirects = 0; + if (Array.isArray(protocols)) { protocols = protocols.join(', '); } else if (typeof protocols === 'object' && protocols !== null) { @@ -59,7 +61,9 @@ class WebSocket extends EventEmitter { protocols = undefined; } - initAsClient.call(this, address, protocols, options); + initAsClient(this, address, protocols, options); + } else { + this._isServer = true; } } @@ -417,22 +421,31 @@ module.exports = WebSocket; /** * Initialize a WebSocket client. * + * @param {WebSocket} websocket The client to initialize * @param {(String|url.Url|url.URL)} address The URL to which to connect * @param {String} protocols The subprotocols * @param {Object} options Connection options - * @param {(Boolean|Object)} options.perMessageDeflate Enable/disable permessage-deflate - * @param {Number} options.handshakeTimeout Timeout in milliseconds for the handshake request - * @param {Number} options.protocolVersion Value of the `Sec-WebSocket-Version` header - * @param {String} options.origin Value of the `Origin` or `Sec-WebSocket-Origin` header + * @param {(Boolean|Object)} options.perMessageDeflate Enable/disable + * permessage-deflate + * @param {Number} options.handshakeTimeout Timeout in milliseconds for the + * handshake request + * @param {Number} options.protocolVersion Value of the `Sec-WebSocket-Version` + * header + * @param {String} options.origin Value of the `Origin` or + * `Sec-WebSocket-Origin` header * @param {Number} options.maxPayload The maximum allowed message size + * @param {Boolean} options.followRedirects Whether or not to follow redirects + * @param {Number} options.maxRedirects The maximum number of redirects allowed * @private */ -function initAsClient(address, protocols, options) { - options = Object.assign( +function initAsClient(websocket, address, protocols, options) { + const opts = Object.assign( { protocolVersion: protocolVersions[1], + maxPayload: 100 * 1024 * 1024, perMessageDeflate: true, - maxPayload: 100 * 1024 * 1024 + followRedirects: false, + maxRedirects: 10 }, options, { @@ -449,128 +462,151 @@ function initAsClient(address, protocols, options) { } ); - if (!protocolVersions.includes(options.protocolVersion)) { + if (!protocolVersions.includes(opts.protocolVersion)) { throw new RangeError( - `Unsupported protocol version: ${options.protocolVersion} ` + + `Unsupported protocol version: ${opts.protocolVersion} ` + `(supported versions: ${protocolVersions.join(', ')})` ); } - this._isServer = false; - var parsedUrl; if (typeof address === 'object' && address.href !== undefined) { parsedUrl = address; - this.url = address.href; + websocket.url = address.href; } else { // // The WHATWG URL constructor is not available on Node.js < 6.13.0 // parsedUrl = url.URL ? new url.URL(address) : url.parse(address); - this.url = address; + websocket.url = address; } const isUnixSocket = parsedUrl.protocol === 'ws+unix:'; if (!parsedUrl.host && (!isUnixSocket || !parsedUrl.pathname)) { - throw new Error(`Invalid URL: ${this.url}`); + throw new Error(`Invalid URL: ${websocket.url}`); } const isSecure = parsedUrl.protocol === 'wss:' || parsedUrl.protocol === 'https:'; const defaultPort = isSecure ? 443 : 80; const key = crypto.randomBytes(16).toString('base64'); - const httpObj = isSecure ? https : http; + const get = isSecure ? https.get : http.get; const path = parsedUrl.search ? `${parsedUrl.pathname || '/'}${parsedUrl.search}` : parsedUrl.pathname || '/'; var perMessageDeflate; - options.createConnection = isSecure ? tlsConnect : netConnect; - options.defaultPort = options.defaultPort || defaultPort; - options.port = parsedUrl.port || defaultPort; - options.host = parsedUrl.hostname.startsWith('[') + opts.createConnection = isSecure ? tlsConnect : netConnect; + opts.defaultPort = opts.defaultPort || defaultPort; + opts.port = parsedUrl.port || defaultPort; + opts.host = parsedUrl.hostname.startsWith('[') ? parsedUrl.hostname.slice(1, -1) : parsedUrl.hostname; - options.headers = Object.assign( + opts.headers = Object.assign( { - 'Sec-WebSocket-Version': options.protocolVersion, + 'Sec-WebSocket-Version': opts.protocolVersion, 'Sec-WebSocket-Key': key, Connection: 'Upgrade', Upgrade: 'websocket' }, - options.headers + opts.headers ); - options.path = path; - options.timeout = options.handshakeTimeout; + opts.path = path; + opts.timeout = opts.handshakeTimeout; - if (options.perMessageDeflate) { + if (opts.perMessageDeflate) { perMessageDeflate = new PerMessageDeflate( - options.perMessageDeflate !== true ? options.perMessageDeflate : {}, + opts.perMessageDeflate !== true ? opts.perMessageDeflate : {}, false, - options.maxPayload + opts.maxPayload ); - options.headers['Sec-WebSocket-Extensions'] = extension.format({ + opts.headers['Sec-WebSocket-Extensions'] = extension.format({ [PerMessageDeflate.extensionName]: perMessageDeflate.offer() }); } if (protocols) { - options.headers['Sec-WebSocket-Protocol'] = protocols; + opts.headers['Sec-WebSocket-Protocol'] = protocols; } - if (options.origin) { - if (options.protocolVersion < 13) { - options.headers['Sec-WebSocket-Origin'] = options.origin; + if (opts.origin) { + if (opts.protocolVersion < 13) { + opts.headers['Sec-WebSocket-Origin'] = opts.origin; } else { - options.headers.Origin = options.origin; + opts.headers.Origin = opts.origin; } } if (parsedUrl.auth) { - options.auth = parsedUrl.auth; + opts.auth = parsedUrl.auth; } else if (parsedUrl.username || parsedUrl.password) { - options.auth = `${parsedUrl.username}:${parsedUrl.password}`; + opts.auth = `${parsedUrl.username}:${parsedUrl.password}`; } if (isUnixSocket) { const parts = path.split(':'); - options.socketPath = parts[0]; - options.path = parts[1]; + opts.socketPath = parts[0]; + opts.path = parts[1]; } - var req = (this._req = httpObj.get(options)); + var req = (websocket._req = get(opts)); - if (options.handshakeTimeout) { + if (opts.timeout) { req.on('timeout', () => { - abortHandshake(this, req, 'Opening handshake has timed out'); + abortHandshake(websocket, req, 'Opening handshake has timed out'); }); } req.on('error', (err) => { - if (this._req.aborted) return; + if (websocket._req.aborted) return; - req = this._req = null; - this.readyState = WebSocket.CLOSING; - this.emit('error', err); - this.emitClose(); + req = websocket._req = null; + websocket.readyState = WebSocket.CLOSING; + websocket.emit('error', err); + websocket.emitClose(); }); req.on('response', (res) => { - if (this.emit('unexpected-response', req, res)) return; + const location = res.headers.location; + const statusCode = res.statusCode; + + if ( + location && + opts.followRedirects && + statusCode >= 300 && + statusCode < 400 + ) { + if (++websocket._redirects > opts.maxRedirects) { + abortHandshake(websocket, req, 'Maximum redirects exceeded'); + return; + } + + req.abort(); + + const addr = url.URL + ? new url.URL(location, address) + : url.resolve(address, location); - abortHandshake(this, req, `Unexpected server response: ${res.statusCode}`); + initAsClient(websocket, addr, protocols, options); + } else if (!websocket.emit('unexpected-response', req, res)) { + abortHandshake( + websocket, + req, + `Unexpected server response: ${res.statusCode}` + ); + } }); req.on('upgrade', (res, socket, head) => { - this.emit('upgrade', res); + websocket.emit('upgrade', res); // // The user may have closed the connection from a listener of the `upgrade` // event. // - if (this.readyState !== WebSocket.CONNECTING) return; + if (websocket.readyState !== WebSocket.CONNECTING) return; - req = this._req = null; + req = websocket._req = null; const digest = crypto .createHash('sha1') @@ -578,7 +614,7 @@ function initAsClient(address, protocols, options) { .digest('base64'); if (res.headers['sec-websocket-accept'] !== digest) { - abortHandshake(this, socket, 'Invalid Sec-WebSocket-Accept header'); + abortHandshake(websocket, socket, 'Invalid Sec-WebSocket-Accept header'); return; } @@ -595,11 +631,11 @@ function initAsClient(address, protocols, options) { } if (protError) { - abortHandshake(this, socket, protError); + abortHandshake(websocket, socket, protError); return; } - if (serverProt) this.protocol = serverProt; + if (serverProt) websocket.protocol = serverProt; if (perMessageDeflate) { try { @@ -609,15 +645,21 @@ function initAsClient(address, protocols, options) { if (extensions[PerMessageDeflate.extensionName]) { perMessageDeflate.accept(extensions[PerMessageDeflate.extensionName]); - this._extensions[PerMessageDeflate.extensionName] = perMessageDeflate; + websocket._extensions[ + PerMessageDeflate.extensionName + ] = perMessageDeflate; } } catch (err) { - abortHandshake(this, socket, 'Invalid Sec-WebSocket-Extensions header'); + abortHandshake( + websocket, + socket, + 'Invalid Sec-WebSocket-Extensions header' + ); return; } } - this.setSocket(socket, head, options.maxPayload); + websocket.setSocket(socket, head, opts.maxPayload); }); } diff --git a/test/websocket.test.js b/test/websocket.test.js index 71f66044f..d77b54794 100644 --- a/test/websocket.test.js +++ b/test/websocket.test.js @@ -664,6 +664,72 @@ describe('WebSocket', function() { ws.on('close', () => wss.close(done)); }); }); + + it('does not follow redirects by default', function(done) { + server.once('upgrade', (req, socket) => { + socket.end( + 'HTTP/1.1 301 Moved Permanently\r\n' + + 'Location: ws://localhost:8080\r\n' + + '\r\n' + ); + }); + + const ws = new WebSocket(`ws://localhost:${server.address().port}`); + + ws.on('open', () => done(new Error("Unexpected 'open' event"))); + ws.on('error', (err) => { + assert.ok(err instanceof Error); + assert.strictEqual(err.message, 'Unexpected server response: 301'); + assert.strictEqual(ws._redirects, 0); + ws.on('close', () => done()); + }); + }); + + it('honors the `followRedirects` option', function(done) { + const wss = new WebSocket.Server({ noServer: true, path: '/foo' }); + + server.once('upgrade', (req, socket) => { + socket.end('HTTP/1.1 302 Found\r\nLocation: /foo\r\n\r\n'); + server.once('upgrade', (req, socket, head) => { + wss.handleUpgrade(req, socket, head, () => {}); + }); + }); + + const port = server.address().port; + const ws = new WebSocket(`ws://localhost:${port}`, { + followRedirects: true + }); + + ws.on('open', () => { + assert.strictEqual(ws.url, `ws://localhost:${port}/foo`); + assert.strictEqual(ws._redirects, 1); + ws.on('close', () => done()); + ws.close(); + }); + }); + + it('honors the `maxRedirects` option', function(done) { + const onUpgrade = (req, socket) => { + socket.end('HTTP/1.1 302 Found\r\nLocation: /\r\n\r\n'); + }; + + server.on('upgrade', onUpgrade); + + const ws = new WebSocket(`ws://localhost:${server.address().port}`, { + followRedirects: true, + maxRedirects: 1 + }); + + ws.on('open', () => done(new Error("Unexpected 'open' event"))); + ws.on('error', (err) => { + assert.ok(err instanceof Error); + assert.strictEqual(err.message, 'Maximum redirects exceeded'); + assert.strictEqual(ws._redirects, 2); + + server.removeListener('upgrade', onUpgrade); + ws.on('close', () => done()); + }); + }); }); describe('Connection with query string', function() {