Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to follow redirects #1490

Merged
merged 1 commit into from Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion doc/ws.md
Expand Up @@ -193,8 +193,12 @@ This class represents a WebSocket. It extends the `EventEmitter`.
- `address` {String|url.Url|url.URL} The URL to which to connect.
- `protocols` {String|Array} The list of subprotocols.
- `options` {Object}
- `followRedirects` {Boolean} Whether or not to follow redirects. Defaults to
`false`.
- `handshakeTimeout` {Number} Timeout in milliseconds for the handshake
request.
request. This is reset after every redirection.
- `maxRedirects` {Number} The maximum number of redirects allowed. Defaults
to 10.
- `perMessageDeflate` {Boolean|Object} Enable/disable permessage-deflate.
- `protocolVersion` {Number} Value of the `Sec-WebSocket-Version` header.
- `origin` {String} Value of the `Origin` or `Sec-WebSocket-Origin` header
Expand Down
158 changes: 100 additions & 58 deletions lib/websocket.js
Expand Up @@ -46,20 +46,24 @@ class WebSocket extends EventEmitter {
this._closeTimer = null;
this._closeCode = 1006;
this._extensions = {};
this._isServer = true;
this._receiver = null;
this._sender = null;
this._socket = null;

if (address !== null) {
this._isServer = false;
this._redirects = 0;

if (Array.isArray(protocols)) {
protocols = protocols.join(', ');
} else if (typeof protocols === 'object' && protocols !== null) {
options = protocols;
protocols = undefined;
}

initAsClient.call(this, address, protocols, options);
initAsClient(this, address, protocols, options);
} else {
this._isServer = true;
}
}

Expand Down Expand Up @@ -417,22 +421,31 @@ module.exports = WebSocket;
/**
* Initialize a WebSocket client.
*
* @param {WebSocket} websocket The client to initialize
* @param {(String|url.Url|url.URL)} address The URL to which to connect
* @param {String} protocols The subprotocols
* @param {Object} options Connection options
* @param {(Boolean|Object)} options.perMessageDeflate Enable/disable permessage-deflate
* @param {Number} options.handshakeTimeout Timeout in milliseconds for the handshake request
* @param {Number} options.protocolVersion Value of the `Sec-WebSocket-Version` header
* @param {String} options.origin Value of the `Origin` or `Sec-WebSocket-Origin` header
* @param {(Boolean|Object)} options.perMessageDeflate Enable/disable
* permessage-deflate
* @param {Number} options.handshakeTimeout Timeout in milliseconds for the
* handshake request
* @param {Number} options.protocolVersion Value of the `Sec-WebSocket-Version`
* header
* @param {String} options.origin Value of the `Origin` or
* `Sec-WebSocket-Origin` header
* @param {Number} options.maxPayload The maximum allowed message size
* @param {Boolean} options.followRedirects Whether or not to follow redirects
* @param {Number} options.maxRedirects The maximum number of redirects allowed
* @private
*/
function initAsClient(address, protocols, options) {
options = Object.assign(
function initAsClient(websocket, address, protocols, options) {
const opts = Object.assign(
{
protocolVersion: protocolVersions[1],
maxPayload: 100 * 1024 * 1024,
perMessageDeflate: true,
maxPayload: 100 * 1024 * 1024
followRedirects: false,
maxRedirects: 10
},
options,
{
Expand All @@ -449,136 +462,159 @@ function initAsClient(address, protocols, options) {
}
);

if (!protocolVersions.includes(options.protocolVersion)) {
if (!protocolVersions.includes(opts.protocolVersion)) {
throw new RangeError(
`Unsupported protocol version: ${options.protocolVersion} ` +
`Unsupported protocol version: ${opts.protocolVersion} ` +
`(supported versions: ${protocolVersions.join(', ')})`
);
}

this._isServer = false;

var parsedUrl;

if (typeof address === 'object' && address.href !== undefined) {
parsedUrl = address;
this.url = address.href;
websocket.url = address.href;
} else {
//
// The WHATWG URL constructor is not available on Node.js < 6.13.0
//
parsedUrl = url.URL ? new url.URL(address) : url.parse(address);
this.url = address;
websocket.url = address;
}

const isUnixSocket = parsedUrl.protocol === 'ws+unix:';

if (!parsedUrl.host && (!isUnixSocket || !parsedUrl.pathname)) {
throw new Error(`Invalid URL: ${this.url}`);
throw new Error(`Invalid URL: ${websocket.url}`);
}

const isSecure =
parsedUrl.protocol === 'wss:' || parsedUrl.protocol === 'https:';
const defaultPort = isSecure ? 443 : 80;
const key = crypto.randomBytes(16).toString('base64');
const httpObj = isSecure ? https : http;
const get = isSecure ? https.get : http.get;
const path = parsedUrl.search
? `${parsedUrl.pathname || '/'}${parsedUrl.search}`
: parsedUrl.pathname || '/';
var perMessageDeflate;

options.createConnection = isSecure ? tlsConnect : netConnect;
options.defaultPort = options.defaultPort || defaultPort;
options.port = parsedUrl.port || defaultPort;
options.host = parsedUrl.hostname.startsWith('[')
opts.createConnection = isSecure ? tlsConnect : netConnect;
opts.defaultPort = opts.defaultPort || defaultPort;
opts.port = parsedUrl.port || defaultPort;
opts.host = parsedUrl.hostname.startsWith('[')
? parsedUrl.hostname.slice(1, -1)
: parsedUrl.hostname;
options.headers = Object.assign(
opts.headers = Object.assign(
{
'Sec-WebSocket-Version': options.protocolVersion,
'Sec-WebSocket-Version': opts.protocolVersion,
'Sec-WebSocket-Key': key,
Connection: 'Upgrade',
Upgrade: 'websocket'
},
options.headers
opts.headers
);
options.path = path;
options.timeout = options.handshakeTimeout;
opts.path = path;
opts.timeout = opts.handshakeTimeout;

if (options.perMessageDeflate) {
if (opts.perMessageDeflate) {
perMessageDeflate = new PerMessageDeflate(
options.perMessageDeflate !== true ? options.perMessageDeflate : {},
opts.perMessageDeflate !== true ? opts.perMessageDeflate : {},
false,
options.maxPayload
opts.maxPayload
);
options.headers['Sec-WebSocket-Extensions'] = extension.format({
opts.headers['Sec-WebSocket-Extensions'] = extension.format({
[PerMessageDeflate.extensionName]: perMessageDeflate.offer()
});
}
if (protocols) {
options.headers['Sec-WebSocket-Protocol'] = protocols;
opts.headers['Sec-WebSocket-Protocol'] = protocols;
}
if (options.origin) {
if (options.protocolVersion < 13) {
options.headers['Sec-WebSocket-Origin'] = options.origin;
if (opts.origin) {
if (opts.protocolVersion < 13) {
opts.headers['Sec-WebSocket-Origin'] = opts.origin;
} else {
options.headers.Origin = options.origin;
opts.headers.Origin = opts.origin;
}
}
if (parsedUrl.auth) {
options.auth = parsedUrl.auth;
opts.auth = parsedUrl.auth;
} else if (parsedUrl.username || parsedUrl.password) {
options.auth = `${parsedUrl.username}:${parsedUrl.password}`;
opts.auth = `${parsedUrl.username}:${parsedUrl.password}`;
}

if (isUnixSocket) {
const parts = path.split(':');

options.socketPath = parts[0];
options.path = parts[1];
opts.socketPath = parts[0];
opts.path = parts[1];
}

var req = (this._req = httpObj.get(options));
var req = (websocket._req = get(opts));

if (options.handshakeTimeout) {
if (opts.timeout) {
req.on('timeout', () => {
abortHandshake(this, req, 'Opening handshake has timed out');
abortHandshake(websocket, req, 'Opening handshake has timed out');
});
}

req.on('error', (err) => {
if (this._req.aborted) return;
if (websocket._req.aborted) return;

req = this._req = null;
this.readyState = WebSocket.CLOSING;
this.emit('error', err);
this.emitClose();
req = websocket._req = null;
websocket.readyState = WebSocket.CLOSING;
websocket.emit('error', err);
websocket.emitClose();
});

req.on('response', (res) => {
if (this.emit('unexpected-response', req, res)) return;
const location = res.headers.location;
const statusCode = res.statusCode;

if (
location &&
opts.followRedirects &&
statusCode >= 300 &&
statusCode < 400
) {
if (++websocket._redirects > opts.maxRedirects) {
abortHandshake(websocket, req, 'Maximum redirects exceeded');
return;
}

req.abort();

const addr = url.URL
? new url.URL(location, address)
: url.resolve(address, location);

abortHandshake(this, req, `Unexpected server response: ${res.statusCode}`);
initAsClient(websocket, addr, protocols, options);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the redirect address should be emitted in some way to allow tracking.

} else if (!websocket.emit('unexpected-response', req, res)) {
abortHandshake(
websocket,
req,
`Unexpected server response: ${res.statusCode}`
);
}
});

req.on('upgrade', (res, socket, head) => {
this.emit('upgrade', res);
websocket.emit('upgrade', res);

//
// The user may have closed the connection from a listener of the `upgrade`
// event.
//
if (this.readyState !== WebSocket.CONNECTING) return;
if (websocket.readyState !== WebSocket.CONNECTING) return;

req = this._req = null;
req = websocket._req = null;

const digest = crypto
.createHash('sha1')
.update(key + constants.GUID, 'binary')
.digest('base64');

if (res.headers['sec-websocket-accept'] !== digest) {
abortHandshake(this, socket, 'Invalid Sec-WebSocket-Accept header');
abortHandshake(websocket, socket, 'Invalid Sec-WebSocket-Accept header');
return;
}

Expand All @@ -595,11 +631,11 @@ function initAsClient(address, protocols, options) {
}

if (protError) {
abortHandshake(this, socket, protError);
abortHandshake(websocket, socket, protError);
return;
}

if (serverProt) this.protocol = serverProt;
if (serverProt) websocket.protocol = serverProt;

if (perMessageDeflate) {
try {
Expand All @@ -609,15 +645,21 @@ function initAsClient(address, protocols, options) {

if (extensions[PerMessageDeflate.extensionName]) {
perMessageDeflate.accept(extensions[PerMessageDeflate.extensionName]);
this._extensions[PerMessageDeflate.extensionName] = perMessageDeflate;
websocket._extensions[
PerMessageDeflate.extensionName
] = perMessageDeflate;
}
} catch (err) {
abortHandshake(this, socket, 'Invalid Sec-WebSocket-Extensions header');
abortHandshake(
websocket,
socket,
'Invalid Sec-WebSocket-Extensions header'
);
return;
}
}

this.setSocket(socket, head, options.maxPayload);
websocket.setSocket(socket, head, opts.maxPayload);
});
}

Expand Down
66 changes: 66 additions & 0 deletions test/websocket.test.js
Expand Up @@ -664,6 +664,72 @@ describe('WebSocket', function() {
ws.on('close', () => wss.close(done));
});
});

it('does not follow redirects by default', function(done) {
server.once('upgrade', (req, socket) => {
socket.end(
'HTTP/1.1 301 Moved Permanently\r\n' +
'Location: ws://localhost:8080\r\n' +
'\r\n'
);
});

const ws = new WebSocket(`ws://localhost:${server.address().port}`);

ws.on('open', () => done(new Error("Unexpected 'open' event")));
ws.on('error', (err) => {
assert.ok(err instanceof Error);
assert.strictEqual(err.message, 'Unexpected server response: 301');
assert.strictEqual(ws._redirects, 0);
ws.on('close', () => done());
});
});

it('honors the `followRedirects` option', function(done) {
const wss = new WebSocket.Server({ noServer: true, path: '/foo' });

server.once('upgrade', (req, socket) => {
socket.end('HTTP/1.1 302 Found\r\nLocation: /foo\r\n\r\n');
server.once('upgrade', (req, socket, head) => {
wss.handleUpgrade(req, socket, head, () => {});
});
});

const port = server.address().port;
const ws = new WebSocket(`ws://localhost:${port}`, {
followRedirects: true
});

ws.on('open', () => {
assert.strictEqual(ws.url, `ws://localhost:${port}/foo`);
assert.strictEqual(ws._redirects, 1);
ws.on('close', () => done());
ws.close();
});
});

it('honors the `maxRedirects` option', function(done) {
const onUpgrade = (req, socket) => {
socket.end('HTTP/1.1 302 Found\r\nLocation: /\r\n\r\n');
};

server.on('upgrade', onUpgrade);

const ws = new WebSocket(`ws://localhost:${server.address().port}`, {
followRedirects: true,
maxRedirects: 1
});

ws.on('open', () => done(new Error("Unexpected 'open' event")));
ws.on('error', (err) => {
assert.ok(err instanceof Error);
assert.strictEqual(err.message, 'Maximum redirects exceeded');
assert.strictEqual(ws._redirects, 2);

server.removeListener('upgrade', onUpgrade);
ws.on('close', () => done());
});
});
});

describe('Connection with query string', function() {
Expand Down