diff --git a/.changeset/seven-colts-flash.md b/.changeset/seven-colts-flash.md new file mode 100644 index 00000000..4c37bc2b --- /dev/null +++ b/.changeset/seven-colts-flash.md @@ -0,0 +1,5 @@ +--- +"agent-base": patch +--- + +Synchronously update internal sockets length so `http.Agent` pooling is used diff --git a/packages/agent-base/src/index.ts b/packages/agent-base/src/index.ts index f3758fb9..fcf2b57c 100644 --- a/packages/agent-base/src/index.ts +++ b/packages/agent-base/src/index.ts @@ -1,6 +1,7 @@ import * as net from 'net'; import * as tls from 'tls'; import * as http from 'http'; +import { Agent as HttpsAgent } from 'https'; import type { Duplex } from 'stream'; export * from './helpers'; @@ -77,6 +78,65 @@ export abstract class Agent extends http.Agent { ); } + // In order to support async signatures in `connect()` and Node's native + // connection pooling in `http.Agent`, the array of sockets for each origin + // has to be updated synchronously. This is so the length of the array is + // accurate when `addRequest()` is next called. We achieve this by creating a + // fake socket and adding it to `sockets[origin]` and incrementing + // `totalSocketCount`. + private incrementSockets(name: string) { + // If `maxSockets` and `maxTotalSockets` are both Infinity then there is no + // need to create a fake socket because Node.js native connection pooling + // will never be invoked. + if (this.maxSockets === Infinity && this.maxTotalSockets === Infinity) { + return null; + } + // All instances of `sockets` are expected TypeScript errors. The + // alternative is to add it as a private property of this class but that + // will break TypeScript subclassing. + if (!this.sockets[name]) { + // @ts-expect-error `sockets` is readonly in `@types/node` + this.sockets[name] = []; + } + const fakeSocket = new net.Socket({ writable: false }); + (this.sockets[name] as net.Socket[]).push(fakeSocket); + // @ts-expect-error `totalSocketCount` isn't defined in `@types/node` + this.totalSocketCount++; + return fakeSocket; + } + + private decrementSockets(name: string, socket: null | net.Socket) { + if (!this.sockets[name] || socket === null) { + return; + } + const sockets = this.sockets[name] as net.Socket[]; + const index = sockets.indexOf(socket); + if (index !== -1) { + sockets.splice(index, 1); + // @ts-expect-error `totalSocketCount` isn't defined in `@types/node` + this.totalSocketCount--; + if (sockets.length === 0) { + // @ts-expect-error `sockets` is readonly in `@types/node` + delete this.sockets[name]; + } + } + } + + // In order to properly update the socket pool, we need to call `getName()` on + // the core `https.Agent` if it is a secureEndpoint. + getName(options: AgentConnectOpts): string { + const secureEndpoint = + typeof options.secureEndpoint === 'boolean' + ? options.secureEndpoint + : this.isSecureEndpoint(options); + if (secureEndpoint) { + // @ts-expect-error `getName()` isn't defined in `@types/node` + return HttpsAgent.prototype.getName.call(this, options); + } + // @ts-expect-error `getName()` isn't defined in `@types/node` + return super.getName(options); + } + createSocket( req: http.ClientRequest, options: AgentConnectOpts, @@ -86,17 +146,26 @@ export abstract class Agent extends http.Agent { ...options, secureEndpoint: this.isSecureEndpoint(options), }; + const name = this.getName(connectOpts); + const fakeSocket = this.incrementSockets(name); Promise.resolve() .then(() => this.connect(req, connectOpts)) - .then((socket) => { - if (socket instanceof http.Agent) { - // @ts-expect-error `addRequest()` isn't defined in `@types/node` - return socket.addRequest(req, connectOpts); + .then( + (socket) => { + this.decrementSockets(name, fakeSocket); + if (socket instanceof http.Agent) { + // @ts-expect-error `addRequest()` isn't defined in `@types/node` + return socket.addRequest(req, connectOpts); + } + this[INTERNAL].currentSocket = socket; + // @ts-expect-error `createSocket()` isn't defined in `@types/node` + super.createSocket(req, options, cb); + }, + (err) => { + this.decrementSockets(name, fakeSocket); + cb(err); } - this[INTERNAL].currentSocket = socket; - // @ts-expect-error `createSocket()` isn't defined in `@types/node` - super.createSocket(req, options, cb); - }, cb); + ); } createConnection(): Duplex { diff --git a/packages/agent-base/test/test.ts b/packages/agent-base/test/test.ts index 6f49a84d..deeed875 100644 --- a/packages/agent-base/test/test.ts +++ b/packages/agent-base/test/test.ts @@ -79,6 +79,7 @@ describe('Agent (TypeScript)', () => { ) { gotCallback = true; assert(opts.secureEndpoint === false); + assert.equal(this.getName(opts), `127.0.0.1:${port}:`); return net.connect(opts); } } @@ -308,6 +309,60 @@ describe('Agent (TypeScript)', () => { server2.close(); } }); + + it('should support `keepAlive: true` with `maxSockets`', async () => { + let reqCount = 0; + let connectCount = 0; + + class MyAgent extends Agent { + async connect( + _req: http.ClientRequest, + opts: AgentConnectOpts + ) { + connectCount++; + assert(opts.secureEndpoint === false); + await sleep(10); + return net.connect(opts); + } + } + const agent = new MyAgent({ keepAlive: true, maxSockets: 1 }); + + const server = http.createServer(async (req, res) => { + expect(req.headers.connection).toEqual('keep-alive'); + reqCount++; + await sleep(10); + res.end(); + }); + const addr = await listen(server); + + try { + const resPromise = req(new URL('/foo', addr), { agent }); + const res2Promise = req(new URL('/another', addr), { + agent, + }); + + const res = await resPromise; + expect(reqCount).toEqual(1); + expect(connectCount).toEqual(1); + expect(res.headers.connection).toEqual('keep-alive'); + + res.resume(); + const s1 = res.socket; + await once(s1, 'free'); + + const res2 = await res2Promise; + expect(reqCount).toEqual(2); + expect(connectCount).toEqual(1); + expect(res2.headers.connection).toEqual('keep-alive'); + assert(res2.socket === s1); + + res2.resume(); + await once(res2.socket, 'free'); + } finally { + agent.destroy(); + server.close(); + } + }); }); describe('"https" module', () => { @@ -322,6 +377,10 @@ describe('Agent (TypeScript)', () => { ): net.Socket { gotCallback = true; assert(opts.secureEndpoint === true); + assert.equal( + this.getName(opts), + `127.0.0.1:${port}::::::::false:::::::::::::` + ); return tls.connect(opts); } } @@ -509,5 +568,63 @@ describe('Agent (TypeScript)', () => { server.close(); } }); + + it('should support `keepAlive: true` with `maxSockets`', async () => { + let reqCount = 0; + let connectCount = 0; + + class MyAgent extends Agent { + async connect( + _req: http.ClientRequest, + opts: AgentConnectOpts + ) { + connectCount++; + assert(opts.secureEndpoint === true); + await sleep(10); + return tls.connect(opts); + } + } + const agent = new MyAgent({ keepAlive: true, maxSockets: 1 }); + + const server = https.createServer(sslOptions, async (req, res) => { + expect(req.headers.connection).toEqual('keep-alive'); + reqCount++; + await sleep(10); + res.end(); + }); + const addr = await listen(server); + + try { + const resPromise = req(new URL('/foo', addr), { + agent, + rejectUnauthorized: false, + }); + const res2Promise = req(new URL('/another', addr), { + agent, + rejectUnauthorized: false, + }); + + const res = await resPromise; + expect(reqCount).toEqual(1); + expect(connectCount).toEqual(1); + expect(res.headers.connection).toEqual('keep-alive'); + + res.resume(); + const s1 = res.socket; + await once(s1, 'free'); + + const res2 = await res2Promise; + expect(reqCount).toEqual(2); + expect(connectCount).toEqual(1); + expect(res2.headers.connection).toEqual('keep-alive'); + assert(res2.socket === s1); + + res2.resume(); + await once(res2.socket, 'free'); + } finally { + agent.destroy(); + server.close(); + } + }); }); });