Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use numbered indices in list arguments for async postgresql queries #3579

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -8,4 +8,6 @@ abstract class BindParameterMixin(node: ASTNode) : BindParameterMixin(node) {
isAsync -> "$$index"
else -> "?"
}

override fun useNumberedIndices(isAsync: Boolean) = isAsync
}
18 changes: 18 additions & 0 deletions runtime/src/commonMain/kotlin/app/cash/sqldelight/Transacter.kt
Expand Up @@ -336,6 +336,24 @@ abstract class BaseTransacterImpl(protected val driver: SqlDriver) {
append(')')
}
}

/**
* For internal use, creates a string in the format ($1, $2, $3) where there are [count] variables beginning at
* [offset] +1.
*/
protected fun createNumberedArguments(offset: Int, count: Int): String {
if (count == 0) return "()"

return buildString(3 * count + 1) {
append("(\$")
append(offset + 1)
repeat(count - 1) {
append(",\$")
append(offset + 1 + index)
}
append(')')
}
}
}

/**
Expand Down
Expand Up @@ -12,4 +12,6 @@ abstract class BindParameterMixin(node: ASTNode) : SqlCompositeElementImpl(node)
* user provided parameter with [replaceWith] for a homogen generated code.
*/
open fun replaceWith(isAsync: Boolean, index: Int): String = "?"

open fun useNumberedIndices(isAsync: Boolean) = false
}
Expand Up @@ -68,20 +68,21 @@ abstract class QueryGenerator(
}
val handledArrayArgs = mutableSetOf<BindableQuery.Argument>()
query.statement.findChildrenOfType<SqlStmt>().forEachIndexed { index, statement ->
val (block, additionalArrayArgs) = executeBlock(statement, handledArrayArgs, query.idForIndex(index))
val (block, additionalArrayArgs) = executeBlock(index, statement, handledArrayArgs, query.idForIndex(index))
handledArrayArgs.addAll(additionalArrayArgs)
result.add(block)
}
result.endControlFlow()
if (generateAsync && query is NamedQuery) { result.endControlFlow() }
} else {
result.add(executeBlock(query.statement, emptySet(), query.id).first)
result.add(executeBlock(0, query.statement, emptySet(), query.id).first)
}

return result.build()
}

private fun executeBlock(
blockIndex: Int,
statement: PsiElement,
handledArrayArgs: Set<BindableQuery.Argument>,
id: Int,
Expand Down Expand Up @@ -147,19 +148,33 @@ abstract class QueryGenerator(
val offset = (precedingArrays.map { "$it.size" } + "$nonArrayBindArgsCount")
.joinToString(separator = " + ").replace(" + 0", "")
if (bindArg?.isArrayParameter() == true) {
val bindParameter = bindArg.bindParameter as? BindParameterMixin
val useNumberedIndices = bindParameter?.useNumberedIndices(generateAsync) ?: false
needsFreshStatement = true

if (!handledArrayArgs.contains(argument) && seenArrayArguments.add(argument)) {
val declareVariable = seenArrayArguments.add(argument) && (useNumberedIndices || !handledArrayArgs.contains(argument))
val variableName = if (useNumberedIndices) {
"${type.name}Indexes$blockIndex"
} else {
"${type.name}Indexes"
}
if (declareVariable && useNumberedIndices) {
result.addStatement(
"""
|val $variableName = createNumberedArguments(offset = $offset, count = ${type.name}.size)
""".trimMargin(),
)
} else if (declareVariable) {
result.addStatement(
"""
|val ${type.name}Indexes = createArguments(count = ${type.name}.size)
|val $variableName = createArguments(count = ${type.name}.size)
""".trimMargin(),
)
}

// Replace the single bind argument with the array of bind arguments:
// WHERE id IN ${idIndexes}
replacements.add(bindArg.range to "\$${type.name}Indexes")
replacements.add(bindArg.range to "\$$variableName")

// Perform the necessary binds:
// id.forEachIndex { index, parameter ->
Expand Down
Expand Up @@ -260,4 +260,218 @@ class AsyncSelectQueryTypeTest {
""".trimMargin(),
)
}

@Test
fun `IN in async postgresql queries uses numbered indices`() {
val result = FixtureCompiler.parseSql(
"""
|CREATE TABLE data (
| id INTEGER PRIMARY KEY,
| value TEXT
|);
|
|selectForMultipleIds:
|SELECT *
|FROM data
|WHERE id IN ?;
""".trimMargin(),
tempFolder,
dialect = PostgreSqlDialect(),
generateAsync = true,
)

val query = result.namedQueries.first()
val generator = SelectQueryGenerator(query)

assertThat(generator.querySubtype().toString()).isEqualTo(
"""
|private inner class SelectForMultipleIdsQuery<out T : kotlin.Any>(
| public val id: kotlin.collections.Collection<kotlin.Int>,
| mapper: (app.cash.sqldelight.db.SqlCursor) -> T,
|) : app.cash.sqldelight.Query<T>(mapper) {
| public override fun addListener(listener: app.cash.sqldelight.Query.Listener): kotlin.Unit {
| driver.addListener(listener, arrayOf("data"))
| }
|
| public override fun removeListener(listener: app.cash.sqldelight.Query.Listener): kotlin.Unit {
| driver.removeListener(listener, arrayOf("data"))
| }
|
| public override fun <R> execute(mapper: (app.cash.sqldelight.db.SqlCursor) -> R): app.cash.sqldelight.db.QueryResult<R> {
| val idIndexes0 = createNumberedArguments(offset = 0, count = id.size)
| return driver.executeQuery(null, ""${'"'}
| |SELECT *
| |FROM data
| |WHERE id IN ${"$"}idIndexes0
| ""${'"'}.trimMargin(), mapper, id.size) {
| check(this is app.cash.sqldelight.driver.r2dbc.R2dbcPreparedStatement)
| id.forEachIndexed { index, id_ ->
| bindLong(index, id_.toLong())
| }
| }
| }
|
| public override fun toString(): kotlin.String = "Test.sq:selectForMultipleIds"
|}
|
""".trimMargin(),
)
}

@Test
fun `offset calculation for IN in async postgresql queries is fine`() {
val result = FixtureCompiler.parseSql(
"""
|CREATE TABLE data (
| id INTEGER PRIMARY KEY,
| value TEXT
|);
|
|selectForValueAndMultipleIds:
|SELECT *
|FROM data
|WHERE value = ? AND id IN ?;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add another ? after the array parameter just to make sure the index still works?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. It actually doesn't work as intended. It changes the ? behind the list array $2 but tries to bind at index id.size+1. I think the best way to fix this is by moving the non array arguments to the front and change the way the offset calculation and the code for the binding works accordingly.

""".trimMargin(),
tempFolder,
dialect = PostgreSqlDialect(),
generateAsync = true,
)

val query = result.namedQueries.first()
val generator = SelectQueryGenerator(query)

assertThat(generator.querySubtype().toString()).isEqualTo(
"""
|private inner class SelectForValueAndMultipleIdsQuery<out T : kotlin.Any>(
| public val value_: kotlin.String?,
| public val id: kotlin.collections.Collection<kotlin.Int>,
| mapper: (app.cash.sqldelight.db.SqlCursor) -> T,
|) : app.cash.sqldelight.Query<T>(mapper) {
| public override fun addListener(listener: app.cash.sqldelight.Query.Listener): kotlin.Unit {
| driver.addListener(listener, arrayOf("data"))
| }
|
| public override fun removeListener(listener: app.cash.sqldelight.Query.Listener): kotlin.Unit {
| driver.removeListener(listener, arrayOf("data"))
| }
|
| public override fun <R> execute(mapper: (app.cash.sqldelight.db.SqlCursor) -> R): app.cash.sqldelight.db.QueryResult<R> {
| val idIndexes0 = createNumberedArguments(offset = 1, count = id.size)
| return driver.executeQuery(null, ""${'"'}
| |SELECT *
| |FROM data
| |WHERE value ${"$"}{ if (value_ == null) "IS" else "=" } $1 AND id IN ${"$"}idIndexes0
| ""${'"'}.trimMargin(), mapper, 1 + id.size) {
| check(this is app.cash.sqldelight.driver.r2dbc.R2dbcPreparedStatement)
| bindString(0, value_)
| id.forEachIndexed { index, id_ ->
| bindLong(index + 1, id_.toLong())
| }
| }
| }
|
| public override fun toString(): kotlin.String = "Test.sq:selectForValueAndMultipleIds"
|}
|
""".trimMargin(),
)
}

@Test
fun `offset calculation for IN in async postgresql transactions with same list multiple times is fine`() {
val result = FixtureCompiler.parseSql(
"""
|CREATE TABLE data (
| id INTEGER PRIMARY KEY,
| value TEXT
|);
|
|deleteByValueAndMultipleIdsInTransaction {
| DELETE FROM data WHERE value= ? AND id IN :id;
| DELETE FROM data WHERE id IN :id;
|}
""".trimMargin(),
tempFolder,
dialect = PostgreSqlDialect(),
generateAsync = true,
)

val query = result.namedExecutes.first()
val generator = ExecuteQueryGenerator(query)

assertThat(generator.function().toString()).isEqualTo(
"""
|public suspend fun deleteByValueAndMultipleIdsInTransaction(value_: kotlin.String?, id: kotlin.collections.Collection<kotlin.Int>): kotlin.Unit {
| transaction {
| val idIndexes0 = createNumberedArguments(offset = 1, count = id.size)
| driver.execute(null, ""${'"'}DELETE FROM data WHERE value${"$"}{ if (value_ == null) " IS" else "=" } ${'$'}1 AND id IN ${"$"}idIndexes0""${'"'}, 1 + id.size) {
| check(this is app.cash.sqldelight.driver.r2dbc.R2dbcPreparedStatement)
| bindString(0, value_)
| id.forEachIndexed { index, id_ ->
| bindLong(index + 1, id_.toLong())
| }
| }.await()
| val idIndexes1 = createNumberedArguments(offset = 0, count = id.size)
| driver.execute(null, ""${'"'}DELETE FROM data WHERE id IN ${"$"}idIndexes1""${'"'}, id.size) {
| check(this is app.cash.sqldelight.driver.r2dbc.R2dbcPreparedStatement)
| id.forEachIndexed { index, id_ ->
| bindLong(index, id_.toLong())
| }
| }.await()
| }
| notifyQueries(1231275350) { emit ->
| emit("data")
| }
|}
|
""".trimMargin(),
)
}

@Test
fun `IN in async sqlite transactions does not use numbered indices`() {
val result = FixtureCompiler.parseSql(
"""
|CREATE TABLE data (
| id INTEGER PRIMARY KEY,
| value TEXT
|);
|
|deleteByValueAndMultipleIdsInTransaction {
| DELETE FROM data WHERE value= ? AND id IN :id;
| DELETE FROM data WHERE id IN :id;
|}
""".trimMargin(),
tempFolder,
generateAsync = true,
)

val query = result.namedExecutes.first()
val generator = ExecuteQueryGenerator(query)

assertThat(generator.function().toString()).isEqualTo(
"""
|public suspend fun deleteByValueAndMultipleIdsInTransaction(value_: kotlin.String?, id: kotlin.collections.Collection<kotlin.Long>): kotlin.Unit {
| transaction {
| val idIndexes = createArguments(count = id.size)
| driver.execute(null, ""${'"'}DELETE FROM data WHERE value${"$"}{ if (value_ == null) " IS" else "=" } ? AND id IN ${"$"}idIndexes""${'"'}, 1 + id.size) {
| bindString(0, value_)
| id.forEachIndexed { index, id_ ->
| bindLong(index + 1, id_)
| }
| }.await()
| driver.execute(null, ""${'"'}DELETE FROM data WHERE id IN ${"$"}idIndexes""${'"'}, id.size) {
| id.forEachIndexed { index, id_ ->
| bindLong(index, id_)
| }
| }.await()
| }
| notifyQueries(1231275350) { emit ->
| emit("data")
| }
|}
|
""".trimMargin(),
)
}
}