Skip to content

Commit

Permalink
feat: add a "dontDrop" property in shared config to define tables not…
Browse files Browse the repository at this point in the history
… to be dropped, close #820
  • Loading branch information
Julien-R44 committed Sep 17, 2022
1 parent 75630a7 commit 0e5df7e
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 14 deletions.
3 changes: 2 additions & 1 deletion adonis-typings/database.ts
Expand Up @@ -336,14 +336,15 @@ declare module '@ioc:Adonis/Lucid/Database' {
/**
* Shared config options for all clients
*/
type SharedConfigNode = {
export type SharedConfigNode = {
useNullAsDefault?: boolean
debug?: boolean
asyncStackTraces?: boolean
revision?: number
healthCheck?: boolean
migrations?: MigratorConfig
seeders?: SeedersConfig
dontDrop?: string[]
pool?: {
afterCreate?: (conn: any, done: any) => void
min?: number
Expand Down
13 changes: 10 additions & 3 deletions src/Dialects/Mysql.ts
Expand Up @@ -10,7 +10,7 @@
/// <reference path="../../adonis-typings/index.ts" />

import { RawBuilder } from '../Database/StaticBuilder/Raw'
import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database'
import { DialectContract, MysqlConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database'

export class MysqlDialect implements DialectContract {
public readonly name = 'mysql'
Expand All @@ -31,7 +31,7 @@ export class MysqlDialect implements DialectContract {
*/
public readonly dateTimeFormat = 'yyyy-MM-dd HH:mm:ss'

constructor(private client: QueryClientContract) {}
constructor(private client: QueryClientContract, private config: MysqlConfig) {}

/**
* Truncate mysql table with option to cascade
Expand Down Expand Up @@ -99,14 +99,21 @@ export class MysqlDialect implements DialectContract {
public async dropAllTables() {
let tables = await this.getAllTables()

if (!tables.length) return
/**
* Filter out tables that are not allowed to be dropped
*/
tables = tables.filter((table) => !(this.config.dontDrop || []).includes(table))

/**
* Add backquote around table names to avoid syntax errors
* in case of a table name with a reserved keyword
*/
tables = tables.map((table) => '`' + table + '`')

if (!tables.length) {
return
}

/**
* Cascade and truncate
*/
Expand Down
18 changes: 14 additions & 4 deletions src/Dialects/Pg.ts
Expand Up @@ -9,7 +9,7 @@

/// <reference path="../../adonis-typings/index.ts" />

import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database'
import { DialectContract, PostgreConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database'

export class PgDialect implements DialectContract {
public readonly name = 'postgres'
Expand All @@ -30,7 +30,7 @@ export class PgDialect implements DialectContract {
*/
public readonly dateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"

constructor(private client: QueryClientContract) {}
constructor(private client: QueryClientContract, private config: PostgreConfig) {}

/**
* Returns an array of table names for one or many schemas.
Expand Down Expand Up @@ -87,8 +87,18 @@ export class PgDialect implements DialectContract {
* Drop all tables inside the database
*/
public async dropAllTables(schemas: string[]) {
const tables = await this.getAllTables(schemas)
if (!tables.length) return
let tables = await this.getAllTables(schemas)

/**
* Filter out tables that are not allowed to be dropped
*/
tables = tables.filter(
(table) => !(this.config.dontDrop || ['spatial_ref_sys']).includes(table)
)

if (!tables.length) {
return
}

await this.client.rawQuery(`DROP TABLE "${tables.join('", "')}" CASCADE;`)
}
Expand Down
18 changes: 14 additions & 4 deletions src/Dialects/Redshift.ts
Expand Up @@ -9,7 +9,7 @@

/// <reference path="../../adonis-typings/index.ts" />

import { DialectContract, QueryClientContract } from '@ioc:Adonis/Lucid/Database'
import { DialectContract, PostgreConfig, QueryClientContract } from '@ioc:Adonis/Lucid/Database'

export class RedshiftDialect implements DialectContract {
public readonly name = 'redshift'
Expand All @@ -30,7 +30,7 @@ export class RedshiftDialect implements DialectContract {
*/
public readonly dateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"

constructor(private client: QueryClientContract) {}
constructor(private client: QueryClientContract, private config: PostgreConfig) {}

/**
* Returns an array of table names for one or many schemas.
Expand Down Expand Up @@ -95,8 +95,18 @@ export class RedshiftDialect implements DialectContract {
* Drop all tables inside the database
*/
public async dropAllTables(schemas: string[]) {
const tables = await this.getAllTables(schemas)
if (!tables.length) return
let tables = await this.getAllTables(schemas)

/**
* Filter out tables that are not allowed to be dropped
*/
tables = tables.filter(
(table) => !(this.config.dontDrop || ['spatial_ref_sys']).includes(table)
)

if (!tables.length) {
return
}

await this.client.rawQuery(`DROP table ${tables.join(',')} CASCADE;`)
}
Expand Down
5 changes: 5 additions & 0 deletions src/Dialects/index.ts
Expand Up @@ -14,6 +14,7 @@ import { SqliteDialect } from './Sqlite'
import { OracleDialect } from './Oracle'
import { RedshiftDialect } from './Redshift'
import { BetterSqliteDialect } from './BetterSqlite'
import { DialectContract, QueryClientContract, SharedConfigNode } from '@ioc:Adonis/Lucid/Database'

export const dialects = {
'mssql': MssqlDialect,
Expand All @@ -24,4 +25,8 @@ export const dialects = {
'redshift': RedshiftDialect,
'sqlite3': SqliteDialect,
'better-sqlite3': BetterSqliteDialect,
} as {
[key: string]: {
new (client: QueryClientContract, config: SharedConfigNode): DialectContract
}
}
5 changes: 4 additions & 1 deletion src/QueryClient/index.ts
Expand Up @@ -44,7 +44,10 @@ export class QueryClient implements QueryClientContract {
/**
* The dialect in use
*/
public dialect: DialectContract = new dialects[this.connection.dialectName](this)
public dialect: DialectContract = new dialects[this.connection.dialectName](
this,
this.connection.config
)

/**
* The profiler to be used for profiling queries
Expand Down
35 changes: 34 additions & 1 deletion test/database/drop-tables.spec.ts
Expand Up @@ -65,7 +65,7 @@ test.group('Query client | drop tables', (group) => {
await connection.disconnect()
})

test('dropAllTables should not throw when there are no tables', async ({ assert }) => {
test('drop all tables should not throw when there are no tables', async ({ assert }) => {
await fs.fsExtra.ensureDir(join(fs.basePath, 'temp'))
const connection = new Connection('primary', getConfig(), app.logger)
connection.connect()
Expand All @@ -81,4 +81,37 @@ test.group('Query client | drop tables', (group) => {

await connection.disconnect()
})

test('drop all tables except those defined in dontDrop', async ({ assert }) => {
await fs.fsExtra.ensureDir(join(fs.basePath, 'temp'))
const config = getConfig()
config.dontDrop = ['table_that_should_not_be_dropped']

const connection = new Connection('primary', config, app.logger)
connection.connect()

await connection.client!.schema.createTableIfNotExists('temp_users', (table) => {
table.increments('id')
})

await connection.client!.schema.createTableIfNotExists('temp_posts', (table) => {
table.increments('id')
})

await connection.client!.schema.createTableIfNotExists(
'table_that_should_not_be_dropped',
(table) => {
table.increments('id')
}
)

const client = new QueryClient('dual', connection, app.container.use('Adonis/Core/Event'))
await client.dialect.dropAllTables(['public'])

assert.isFalse(await connection.client!.schema.hasTable('temp_users'))
assert.isFalse(await connection.client!.schema.hasTable('temp_posts'))
assert.isTrue(await connection.client!.schema.hasTable('table_that_should_not_be_dropped'))

await connection.disconnect()
}).skip(!['pg', 'mysql'].includes(process.env.DB!))
})

0 comments on commit 0e5df7e

Please sign in to comment.