diff --git a/lib/dialects/mssql/index.js b/lib/dialects/mssql/index.js index 29f28a4468..27654fd80f 100644 --- a/lib/dialects/mssql/index.js +++ b/lib/dialects/mssql/index.js @@ -333,14 +333,24 @@ class Client_MSSQL extends Client { _typeForBinding(binding) { const Driver = this._driver(); + if ( + this.connectionSettings.options && + this.connectionSettings.options.mapBinding + ) { + const result = this.connectionSettings.options.mapBinding(binding); + if (result) { + return [result.value, result.type]; + } + } + switch (typeof binding) { case 'string': - return Driver.TYPES.NVarChar; + return [binding, Driver.TYPES.NVarChar]; case 'boolean': - return Driver.TYPES.Bit; + return [binding, Driver.TYPES.Bit]; case 'number': { if (binding % 1 !== 0) { - return Driver.TYPES.Float; + return [binding, Driver.TYPES.Float]; } if (binding < SQL_INT4.MIN || binding > SQL_INT4.MAX) { @@ -350,25 +360,21 @@ class Client_MSSQL extends Client { ); } - return Driver.TYPES.BigInt; + return [binding, Driver.TYPES.BigInt]; } - return Driver.TYPES.Int; + return [binding, Driver.TYPES.Int]; } default: { - // if (binding === null || typeof binding === 'undefined') { - // return tedious.TYPES.Null; - // } - if (binding instanceof Date) { - return Driver.TYPES.DateTime; + return [binding, Driver.TYPES.DateTime]; } if (binding instanceof Buffer) { - return Driver.TYPES.VarBinary; + return [binding, Driver.TYPES.VarBinary]; } - return Driver.TYPES.NVarChar; + return [binding, Driver.TYPES.NVarChar]; } } } @@ -401,8 +407,8 @@ class Client_MSSQL extends Client { } // sets a request input parameter. Detects bigints and decimals and sets type appropriately. - _setReqInput(req, i, binding) { - const tediousType = this._typeForBinding(binding); + _setReqInput(req, i, inputBinding) { + const [binding, tediousType] = this._typeForBinding(inputBinding); const bindingName = 'p'.concat(i); let options; diff --git a/test/integration2/dialects/mssql.spec.js b/test/integration2/dialects/mssql.spec.js index 25c488d42a..1dfe3e25dc 100644 --- a/test/integration2/dialects/mssql.spec.js +++ b/test/integration2/dialects/mssql.spec.js @@ -1,4 +1,5 @@ const { expect } = require('chai'); +const { TYPES } = require('tedious'); const { getAllDbs, getKnexForDb } = require('../util/knex-instance-provider'); async function fetchDefaultConstraint(knex, table, column) { @@ -31,6 +32,8 @@ describe('MSSQL dialect', () => { beforeEach(async () => { await knex.schema.createTable('test', function () { this.increments('id').primary(); + this.specificType('varchar', 'varchar(100)'); + this.string('nvarchar'); }); }); @@ -362,6 +365,31 @@ describe('MSSQL dialect', () => { return result ? result.comment : undefined; } }); + + describe('supports mapBinding config', async () => { + it('can remap types', async () => { + const query = knex('test') + .where('varchar', { value: 'testing', type: TYPES.VarChar }) + .select('id'); + const { bindings } = query.toSQL().toNative(); + expect(bindings[0].type, TYPES.VarChar); + expect(bindings[0].value, 'testing'); + + // verify the query runs successfully + await query; + }); + it('undefined mapBinding result falls back to default implementation', async () => { + const query = knex('test') + .where('nvarchar', 'testing') + .select('id'); + + const { bindings } = query.toSQL().toNative(); + expect(bindings[0], 'testing'); + + // verify the query runs successfully + await query; + }); + }); }); }); }); diff --git a/test/integration2/util/knex-instance-provider.js b/test/integration2/util/knex-instance-provider.js index f5a68e5c60..c287bd3b86 100644 --- a/test/integration2/util/knex-instance-provider.js +++ b/test/integration2/util/knex-instance-provider.js @@ -173,6 +173,13 @@ const testConfigs = { server: 'localhost', port: 21433, database: 'knex_test', + options: { + mapBinding(value) { + if (value && value.type) { + return { value: value.value, type: value.type }; + } + }, + }, }, pool: pool, migrations, diff --git a/types/index.d.ts b/types/index.d.ts index 69dc3c8df4..c352a6593b 100644 --- a/types/index.d.ts +++ b/types/index.d.ts @@ -2887,6 +2887,7 @@ export declare namespace Knex { multiSubnetFailover?: boolean; packetSize?: number; trustServerCertificate?: boolean; + mapBinding?: (value: any) => ({ value: any, type: any } | undefined); }>; pool?: Readonly<{ min?: number;