diff --git a/adonis-typings/database.ts b/adonis-typings/database.ts index f07ec327..1229ad7b 100644 --- a/adonis-typings/database.ts +++ b/adonis-typings/database.ts @@ -336,7 +336,7 @@ declare module '@ioc:Adonis/Lucid/Database' { /** * Shared config options for all clients */ - type SharedConfigNode = { + export type SharedConfigNode = { useNullAsDefault?: boolean debug?: boolean asyncStackTraces?: boolean @@ -344,6 +344,7 @@ declare module '@ioc:Adonis/Lucid/Database' { healthCheck?: boolean migrations?: MigratorConfig seeders?: SeedersConfig + dontDrop?: string[] pool?: { afterCreate?: (conn: any, done: any) => void min?: number diff --git a/src/Dialects/Mysql.ts b/src/Dialects/Mysql.ts index d8cf15da..062dbac1 100644 --- a/src/Dialects/Mysql.ts +++ b/src/Dialects/Mysql.ts @@ -10,7 +10,7 @@ /// import { RawBuilder } from '../Database/StaticBuilder/Raw' -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, MysqlConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database' export class MysqlDialect implements DialectContract { public readonly name = 'mysql' @@ -31,7 +31,7 @@ export class MysqlDialect implements DialectContract { */ public readonly dateTimeFormat = 'yyyy-MM-dd HH:mm:ss' - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: MysqlConfig) {} /** * Truncate mysql table with option to cascade @@ -99,7 +99,10 @@ export class MysqlDialect implements DialectContract { public async dropAllTables() { let tables = await this.getAllTables() - if (!tables.length) return + /** + * Filter out tables that are not allowed to be dropped + */ + tables = tables.filter((table) => !(this.config.dontDrop || []).includes(table)) /** * Add backquote around table names to avoid syntax errors @@ -107,6 +110,10 @@ export class MysqlDialect implements DialectContract { */ tables = tables.map((table) => '`' + table + '`') + if (!tables.length) { + return + } + /** * Cascade and truncate */ diff --git a/src/Dialects/Pg.ts b/src/Dialects/Pg.ts index 39e41800..5de75d29 100644 --- a/src/Dialects/Pg.ts +++ b/src/Dialects/Pg.ts @@ -9,7 +9,7 @@ /// -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, PostgreConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database' export class PgDialect implements DialectContract { public readonly name = 'postgres' @@ -30,7 +30,7 @@ export class PgDialect implements DialectContract { */ public readonly dateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZZ" - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: PostgreConfig) {} /** * Returns an array of table names for one or many schemas. @@ -87,8 +87,18 @@ export class PgDialect implements DialectContract { * Drop all tables inside the database */ public async dropAllTables(schemas: string[]) { - const tables = await this.getAllTables(schemas) - if (!tables.length) return + let tables = await this.getAllTables(schemas) + + /** + * Filter out tables that are not allowed to be dropped + */ + tables = tables.filter( + (table) => !(this.config.dontDrop || ['spatial_ref_sys']).includes(table) + ) + + if (!tables.length) { + return + } await this.client.rawQuery(`DROP TABLE "${tables.join('", "')}" CASCADE;`) } diff --git a/src/Dialects/Redshift.ts b/src/Dialects/Redshift.ts index 296b59c0..3e4271ec 100644 --- a/src/Dialects/Redshift.ts +++ b/src/Dialects/Redshift.ts @@ -9,7 +9,7 @@ /// -import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database' +import { DialectContract, PostgreConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database' export class RedshiftDialect implements DialectContract { public readonly name = 'redshift' @@ -30,7 +30,7 @@ export class RedshiftDialect implements DialectContract { */ public readonly dateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZZ" - constructor(private client: QueryClientContract) {} + constructor(private client: QueryClientContract, private config: PostgreConfig) {} /** * Returns an array of table names for one or many schemas. @@ -95,8 +95,18 @@ export class RedshiftDialect implements DialectContract { * Drop all tables inside the database */ public async dropAllTables(schemas: string[]) { - const tables = await this.getAllTables(schemas) - if (!tables.length) return + let tables = await this.getAllTables(schemas) + + /** + * Filter out tables that are not allowed to be dropped + */ + tables = tables.filter( + (table) => !(this.config.dontDrop || ['spatial_ref_sys']).includes(table) + ) + + if (!tables.length) { + return + } await this.client.rawQuery(`DROP table ${tables.join(',')} CASCADE;`) } diff --git a/src/Dialects/index.ts b/src/Dialects/index.ts index 394071e2..b93fad1b 100644 --- a/src/Dialects/index.ts +++ b/src/Dialects/index.ts @@ -14,6 +14,7 @@ import { SqliteDialect } from './Sqlite' import { OracleDialect } from './Oracle' import { RedshiftDialect } from './Redshift' import { BetterSqliteDialect } from './BetterSqlite' +import { DialectContract, QueryClientContract, SharedConfigNode } from '@ioc:Adonis/Lucid/Database' export const dialects = { 'mssql': MssqlDialect, @@ -24,4 +25,8 @@ export const dialects = { 'redshift': RedshiftDialect, 'sqlite3': SqliteDialect, 'better-sqlite3': BetterSqliteDialect, +} as { + [key: string]: { + new (client: QueryClientContract, config: SharedConfigNode): DialectContract + } } diff --git a/src/QueryClient/index.ts b/src/QueryClient/index.ts index c307b792..a93d5771 100644 --- a/src/QueryClient/index.ts +++ b/src/QueryClient/index.ts @@ -44,7 +44,10 @@ export class QueryClient implements QueryClientContract { /** * The dialect in use */ - public dialect: DialectContract = new dialects[this.connection.dialectName](this) + public dialect: DialectContract = new dialects[this.connection.dialectName]( + this, + this.connection.config + ) /** * The profiler to be used for profiling queries diff --git a/test/database/drop-tables.spec.ts b/test/database/drop-tables.spec.ts index 2129f06f..eb85bcfc 100644 --- a/test/database/drop-tables.spec.ts +++ b/test/database/drop-tables.spec.ts @@ -65,7 +65,7 @@ test.group('Query client | drop tables', (group) => { await connection.disconnect() }) - test('dropAllTables should not throw when there are no tables', async ({ assert }) => { + test('drop all tables should not throw when there are no tables', async ({ assert }) => { await fs.fsExtra.ensureDir(join(fs.basePath, 'temp')) const connection = new Connection('primary', getConfig(), app.logger) connection.connect() @@ -81,4 +81,37 @@ test.group('Query client | drop tables', (group) => { await connection.disconnect() }) + + test('drop all tables except those defined in dontDrop', async ({ assert }) => { + await fs.fsExtra.ensureDir(join(fs.basePath, 'temp')) + const config = getConfig() + config.dontDrop = ['table_that_should_not_be_dropped'] + + const connection = new Connection('primary', config, app.logger) + connection.connect() + + await connection.client!.schema.createTableIfNotExists('temp_users', (table) => { + table.increments('id') + }) + + await connection.client!.schema.createTableIfNotExists('temp_posts', (table) => { + table.increments('id') + }) + + await connection.client!.schema.createTableIfNotExists( + 'table_that_should_not_be_dropped', + (table) => { + table.increments('id') + } + ) + + const client = new QueryClient('dual', connection, app.container.use('Adonis/Core/Event')) + await client.dialect.dropAllTables(['public']) + + assert.isFalse(await connection.client!.schema.hasTable('temp_users')) + assert.isFalse(await connection.client!.schema.hasTable('temp_posts')) + assert.isTrue(await connection.client!.schema.hasTable('table_that_should_not_be_dropped')) + + await connection.disconnect() + }).skip(!['pg', 'mysql'].includes(process.env.DB!)) })