diff --git a/adonis-typings/database.ts b/adonis-typings/database.ts index f07ec327..0fe1c446 100644 --- a/adonis-typings/database.ts +++ b/adonis-typings/database.ts @@ -336,7 +336,7 @@ declare module '@ioc:Adonis/Lucid/Database' { /** * Shared config options for all clients */ - type SharedConfigNode = { + export type SharedConfigNode = { useNullAsDefault?: boolean debug?: boolean asyncStackTraces?: boolean @@ -344,6 +344,7 @@ declare module '@ioc:Adonis/Lucid/Database' { healthCheck?: boolean migrations?: MigratorConfig seeders?: SeedersConfig + wipe?: { ignoreTables?: string[] } pool?: { afterCreate?: (conn: any, done: any) => void min?: number diff --git a/bin/test.ts b/bin/test.ts index c4c6521a..4f30a926 100644 --- a/bin/test.ts +++ b/bin/test.ts @@ -21,7 +21,7 @@ import { join } from 'path' configure({ ...processCliArgs(process.argv.slice(2)), ...{ - files: ['test/**/*.spec.ts', '!test/database/drop-table.spec.ts'], + files: ['test/**/*.spec.ts'], plugins: [assert(), runFailedTests()], reporters: [specReporter()], importer: (filePath: string) => import(filePath), diff --git a/src/Dialects/Mssql.ts b/src/Dialects/Mssql.ts index 19ca4d18..bc92ee91 100644 --- a/src/Dialects/Mssql.ts +++ b/src/Dialects/Mssql.ts @@ -10,7 +10,7 @@ /// import { RawBuilder } from '../Database/StaticBuilder/Raw' -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, MssqlConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database' export class MssqlDialect implements DialectContract { public readonly name = 'mssql' @@ -31,7 +31,7 @@ export class MssqlDialect implements DialectContract { */ public readonly dateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZZ" - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: MssqlConfig) {} /** * Returns an array of table names @@ -67,14 +67,21 @@ export class MssqlDialect implements DialectContract { public async dropAllTables() { await this.client.rawQuery(` DECLARE @sql NVARCHAR(MAX) = N''; - SELECT @sql += 'ALTER TABLE ' - + QUOTENAME(OBJECT_SCHEMA_NAME(parent_object_id)) + '.' + + QUOTENAME(OBJECT_NAME(parent_object_id)) - + ' DROP CONSTRAINT ' + QUOTENAME(name) + ';' - FROM sys.foreign_keys; - EXEC sp_executesql @sql; + SELECT @sql += 'ALTER TABLE ' + + QUOTENAME(OBJECT_SCHEMA_NAME(parent_object_id)) + '.' + + QUOTENAME(OBJECT_NAME(parent_object_id)) + + ' DROP CONSTRAINT ' + QUOTENAME(name) + ';' + FROM sys.foreign_keys; + EXEC sp_executesql @sql; `) - await this.client.rawQuery(`EXEC sp_MSforeachtable 'DROP TABLE \\?';`) + const ignoredTables = (this.config.wipe?.ignoreTables || []) + .map((table) => `"${table}"`) + .join(', ') + + await this.client.rawQuery(` + EXEC sp_MSforeachtable 'DROP TABLE \\?', + @whereand='AND o.Name NOT IN (${ignoredTables || '""'})' + `) } public async getAllViews(): Promise { diff --git a/src/Dialects/Mysql.ts b/src/Dialects/Mysql.ts index d8cf15da..122897d6 100644 --- a/src/Dialects/Mysql.ts +++ b/src/Dialects/Mysql.ts @@ -10,7 +10,7 @@ /// import { RawBuilder } from '../Database/StaticBuilder/Raw' -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, MysqlConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database' export class MysqlDialect implements DialectContract { public readonly name = 'mysql' @@ -31,7 +31,7 @@ export class MysqlDialect implements DialectContract { */ public readonly dateTimeFormat = 'yyyy-MM-dd HH:mm:ss' - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: MysqlConfig) {} /** * Truncate mysql table with option to cascade @@ -99,7 +99,10 @@ export class MysqlDialect implements DialectContract { public async dropAllTables() { let tables = await this.getAllTables() - if (!tables.length) return + /** + * Filter out tables that are not allowed to be dropped + */ + tables = tables.filter((table) => !(this.config.wipe?.ignoreTables || []).includes(table)) /** * Add backquote around table names to avoid syntax errors @@ -107,6 +110,10 @@ export class MysqlDialect implements DialectContract { */ tables = tables.map((table) => '`' + table + '`') + if (!tables.length) { + return + } + /** * Cascade and truncate */ diff --git a/src/Dialects/Pg.ts b/src/Dialects/Pg.ts index 39e41800..750a5008 100644 --- a/src/Dialects/Pg.ts +++ b/src/Dialects/Pg.ts @@ -9,7 +9,7 @@ /// -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, PostgreConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database' export class PgDialect implements DialectContract { public readonly name = 'postgres' @@ -30,7 +30,7 @@ export class PgDialect implements DialectContract { */ public readonly dateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZZ" - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: PostgreConfig) {} /** * Returns an array of table names for one or many schemas. @@ -87,8 +87,18 @@ export class PgDialect implements DialectContract { * Drop all tables inside the database */ public async dropAllTables(schemas: string[]) { - const tables = await this.getAllTables(schemas) - if (!tables.length) return + let tables = await this.getAllTables(schemas) + + /** + * Filter out tables that are not allowed to be dropped + */ + tables = tables.filter( + (table) => !(this.config.wipe?.ignoreTables || ['spatial_ref_sys']).includes(table) + ) + + if (!tables.length) { + return + } await this.client.rawQuery(`DROP TABLE "${tables.join('", "')}" CASCADE;`) } diff --git a/src/Dialects/Redshift.ts b/src/Dialects/Redshift.ts index 296b59c0..9a9c065c 100644 --- a/src/Dialects/Redshift.ts +++ b/src/Dialects/Redshift.ts @@ -9,7 +9,7 @@ /// -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, PostgreConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database' export class RedshiftDialect implements DialectContract { public readonly name = 'redshift' @@ -30,7 +30,7 @@ export class RedshiftDialect implements DialectContract { */ public readonly dateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZZ" - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: PostgreConfig) {} /** * Returns an array of table names for one or many schemas. @@ -95,8 +95,18 @@ export class RedshiftDialect implements DialectContract { * Drop all tables inside the database */ public async dropAllTables(schemas: string[]) { - const tables = await this.getAllTables(schemas) - if (!tables.length) return + let tables = await this.getAllTables(schemas) + + /** + * Filter out tables that are not allowed to be dropped + */ + tables = tables.filter( + (table) => !(this.config.wipe?.ignoreTables || ['spatial_ref_sys']).includes(table) + ) + + if (!tables.length) { + return + } await this.client.rawQuery(`DROP table ${tables.join(',')} CASCADE;`) } diff --git a/src/Dialects/SqliteBase.ts b/src/Dialects/SqliteBase.ts index 9a750ab3..128ec017 100644 --- a/src/Dialects/SqliteBase.ts +++ b/src/Dialects/SqliteBase.ts @@ -9,7 +9,7 @@ /// -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, QueryClientContract, SqliteConfig } from '@ioc:Adonis/Lucid/Database' export abstract class BaseSqliteDialect implements DialectContract { public abstract readonly name: 'sqlite3' | 'better-sqlite3' @@ -30,7 +30,7 @@ export abstract class BaseSqliteDialect implements DialectContract { */ public readonly dateTimeFormat = 'yyyy-MM-dd HH:mm:ss' - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: SqliteConfig) {} /** * Returns an array of table names @@ -81,9 +81,13 @@ export abstract class BaseSqliteDialect implements DialectContract { */ public async dropAllTables() { await this.client.rawQuery('PRAGMA writable_schema = 1;') - await this.client.rawQuery( - `delete from sqlite_master where type in ('table', 'index', 'trigger');` - ) + await this.client + .knexQuery() + .delete() + .from('sqlite_master') + .whereIn('type', ['table', 'index', 'trigger']) + .whereNotIn('name', this.config.wipe?.ignoreTables || []) + await this.client.rawQuery('PRAGMA writable_schema = 0;') await this.client.rawQuery('VACUUM;') } diff --git a/src/Dialects/index.ts b/src/Dialects/index.ts index 394071e2..b93fad1b 100644 --- a/src/Dialects/index.ts +++ b/src/Dialects/index.ts @@ -14,6 +14,7 @@ import { SqliteDialect } from './Sqlite' import { OracleDialect } from './Oracle' import { RedshiftDialect } from './Redshift' import { BetterSqliteDialect } from './BetterSqlite' +import { DialectContract, QueryClientContract, SharedConfigNode } from '@ioc:Adonis/Lucid/Database' export const dialects = { 'mssql': MssqlDialect, @@ -24,4 +25,8 @@ export const dialects = { 'redshift': RedshiftDialect, 'sqlite3': SqliteDialect, 'better-sqlite3': BetterSqliteDialect, +} as { + [key: string]: { + new (client: QueryClientContract, config: SharedConfigNode): DialectContract + } } diff --git a/src/QueryClient/index.ts b/src/QueryClient/index.ts index c307b792..a93d5771 100644 --- a/src/QueryClient/index.ts +++ b/src/QueryClient/index.ts @@ -44,7 +44,10 @@ export class QueryClient implements QueryClientContract { /** * The dialect in use */ - public dialect: DialectContract = new dialects[this.connection.dialectName](this) + public dialect: DialectContract = new dialects[this.connection.dialectName]( + this, + this.connection.config + ) /** * The profiler to be used for profiling queries diff --git a/test/database/drop-tables.spec.ts b/test/database/drop-tables.spec.ts index 2129f06f..1a2f9511 100644 --- a/test/database/drop-tables.spec.ts +++ b/test/database/drop-tables.spec.ts @@ -26,7 +26,7 @@ test.group('Query client | drop tables', (group) => { }) group.teardown(async () => { - await cleanup(['temp_posts', 'temp_users']) + await cleanup(['temp_posts', 'temp_users', 'table_that_should_not_be_dropped']) await cleanup() await fs.cleanup() }) @@ -65,7 +65,7 @@ test.group('Query client | drop tables', (group) => { await connection.disconnect() }) - test('dropAllTables should not throw when there are no tables', async ({ assert }) => { + test('drop all tables should not throw when there are no tables', async ({ assert }) => { await fs.fsExtra.ensureDir(join(fs.basePath, 'temp')) const connection = new Connection('primary', getConfig(), app.logger) connection.connect() @@ -81,4 +81,41 @@ test.group('Query client | drop tables', (group) => { await connection.disconnect() }) + + test('drop all tables except those defined in ignoreTables', async ({ assert }) => { + await fs.fsExtra.ensureDir(join(fs.basePath, 'temp')) + const config = getConfig() + config.wipe = {} + config.wipe.ignoreTables = ['table_that_should_not_be_dropped', 'ignore_me'] + + const connection = new Connection('primary', config, app.logger) + connection.connect() + + await connection.client!.schema.createTableIfNotExists('temp_users', (table) => { + table.increments('id') + }) + + await connection.client!.schema.createTableIfNotExists('temp_posts', (table) => { + table.increments('id') + }) + + await connection.client!.schema.createTableIfNotExists( + 'table_that_should_not_be_dropped', + (table) => table.increments('id') + ) + + await connection.client!.schema.createTableIfNotExists('ignore_me', (table) => + table.increments('id') + ) + + const client = new QueryClient('dual', connection, app.container.use('Adonis/Core/Event')) + await client.dialect.dropAllTables(['public']) + + assert.isFalse(await connection.client!.schema.hasTable('temp_users')) + assert.isFalse(await connection.client!.schema.hasTable('temp_posts')) + assert.isTrue(await connection.client!.schema.hasTable('table_that_should_not_be_dropped')) + assert.isTrue(await connection.client!.schema.hasTable('ignore_me')) + + await connection.disconnect() + }).pin() })