Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiler: Add Defaults support in bindings #3375

Merged
merged 23 commits into from Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -79,7 +79,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -111,7 +111,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -110,15 +110,16 @@ type_name ::= (
implements = "com.alecstrong.sql.psi.core.psi.SqlTypeName"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
}

identity_clause ::= 'IDENTITY'

generated_clause ::= GENERATED ( (ALWAYS AS <<expr '-1'>> 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
generated_clause ::= GENERATED ( (ALWAYS AS LP <<expr '-1'>> RP 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlGeneratedClauseImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlGeneratedClause"
override = true
Expand Down
@@ -0,0 +1,11 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : BindParameterMixin(node) {
override fun replaceWith(isAsync: Boolean, index: Int): String = when {
isAsync -> "$$index"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels like a driver implementation detail, that specifically for R2DBC we need indexed parameters.

We're getting closer and I'm very happy to have functioning tests for this stuff now, so this is really the only thing I'm still hung up on

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda like that approach more, doing it at the driver level

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But at driver level you only have strings. So "SELECT '?'" or "LIKE 'foo?%'" does not work. (Except using sql-psi at driver level during runtime again before executing each stmt)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alright, I buy that

last suggestion then: could we just always replace arguments for the postgres driver with their indexed argument, even if its not the async driver? So that parameter can be removed and we're just saying for this dialect this is the expected parameter format

Copy link
Collaborator Author

@hfhbd hfhbd Oct 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What parameter do you want to drop?
index => How do you want to do this? Hardcode a check in sqldelight-core to replace the parameter wildcard with a specific string for postgresql only?
isAsync => I have to check it

Copy link
Collaborator Author

@hfhbd hfhbd Oct 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As expected, using $1 does not work on JDBC and results in an org.postgresql.util.PSQLException, because you try to bind a parameter to an unknown one:
In JDBC, the question mark (?) is the placeholder for the positional parameters of a PreparedStatement. There are, however, a number of PostgreSQL® operators that contain a question mark. To keep such question marks in an SQL statement from being interpreted as positional parameters, use two question marks ( ?? ) as escape sequence. You can also use this escape sequence in a Statement , but that is not required. Specifically only in a Statement a single ( ? ) can be used as an operator.
https://jdbc.postgresql.org/documentation/query/#using-the-statement-or-preparedstatement-interface

But I removed some useless check/binding implementations because I already check for DEFAULT

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. I'm still wishing that this logic was in the driver somewhere instead of the dialect since it looks like a driver constraint, not a postgres one, but lets just merge to fix the underlying issue and then if it becomes a problem later we can revisit.

else -> "?"
}
}
Expand Up @@ -15,6 +15,7 @@ class PostgreSqlFixturesTest(name: String, fixtureRoot: File) : FixturesTest(nam
"?1" to "?",
"?2" to "?",
"BLOB" to "TEXT",
"id TEXT GENERATED ALWAYS AS (2) UNIQUE NOT NULL" to "id TEXT GENERATED ALWAYS AS (2) STORED UNIQUE NOT NULL",
)

override fun setupDialect() {
Expand Down
Expand Up @@ -27,6 +27,7 @@ int_data_type ::= 'INTEGER'
real_data_type ::= 'REAL'

bind_parameter ::= ( '?' [digit] | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -48,7 +48,7 @@ class R2dbcDriver(private val connection: Connection) : SqlDriver {

return QueryResult.AsyncValue {
val result = prepared.execute().awaitSingle()
return@AsyncValue result.rowsUpdated.awaitFirstOrNull() ?: 0
return@AsyncValue result.rowsUpdated.awaitFirstOrNull()?.toLong() ?: 0
}
}

Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Expand Up @@ -76,7 +76,7 @@ testhelp = { module = "co.touchlab:testhelp", version.ref = "testhelp" }
burst = { module = "com.squareup.burst:burst-junit4", version = "1.2.0" }
testParameterInjector = { module = "com.google.testparameterinjector:test-parameter-injector", version = "1.8" }

r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "1.0.0.RELEASE" }
r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "0.9.1.RELEASE" }

[plugins]
android-library = { id = "com.android.library", version.ref = "agp" }
Expand Down
@@ -0,0 +1,15 @@
package app.cash.sqldelight.dialect.grammar.mixins

import com.alecstrong.sql.psi.core.psi.SqlBindParameter
import com.alecstrong.sql.psi.core.psi.SqlCompositeElementImpl
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : SqlCompositeElementImpl(node), SqlBindParameter {
hfhbd marked this conversation as resolved.
Show resolved Hide resolved
/**
* Overwrite, if the user provided sql parameter should be overwritten by sqldelight with [replaceWith].
*
* Some sql dialects support other bind parameter besides `?`, but sqldelight should still replace the
* user provided parameter with [replaceWith] for a homogen generated code.
*/
open fun replaceWith(isAsync: Boolean, index: Int): String = "?"
}
Expand Up @@ -19,6 +19,7 @@ import app.cash.sqldelight.core.lang.util.rawSqlText
import app.cash.sqldelight.core.lang.util.sqFile
import app.cash.sqldelight.core.psi.SqlDelightStmtClojureStmtList
import app.cash.sqldelight.dialect.api.IntermediateType
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlBinaryEqualityExpr
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
Expand Down Expand Up @@ -174,36 +175,39 @@ abstract class QueryGenerator(
precedingArrays.add(type.name)
argumentCounts.add("${type.name}.size")
} else {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
val bindParameter = bindArg?.bindParameter as? BindParameterMixin
if (bindParameter == null || bindParameter.text != "DEFAULT") {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}
}

// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))
// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))

// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindArg != null) {
replacements.add(bindArg.range to "?")
hfhbd marked this conversation as resolved.
Show resolved Hide resolved
// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindParameter != null) {
replacements.add(bindArg.range to bindParameter.replaceWith(generateAsync, index = nonArrayBindArgsCount))
}
}
}
}
Expand Down Expand Up @@ -293,7 +297,7 @@ abstract class QueryGenerator(
"""
if (result%L == 0L) throw %T(%S)
""".trimIndent(),
if (generateAsync) ".await()" else ".value",
if (generateAsync) "" else ".value",
ClassName("app.cash.sqldelight.db", "OptimisticLockException"),
"UPDATE on ${query.tablesAffected.single().name} failed because optimistic lock ${optimisticLock.name} did not match",
)
Expand Down
Expand Up @@ -31,6 +31,7 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.ARGUMENT
import app.cash.sqldelight.dialect.api.PrimitiveType.BOOLEAN
import app.cash.sqldelight.dialect.api.PrimitiveType.INTEGER
import app.cash.sqldelight.dialect.api.PrimitiveType.NULL
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlBindParameter
Expand Down Expand Up @@ -91,29 +92,32 @@ abstract class BindableQuery(
val namesSeen = mutableSetOf<String>()
var maxIndexSeen = 0
statement.findChildrenOfType<SqlBindExpr>().forEach { bindArg ->
bindArg.bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
val bindParameter = bindArg.bindParameter
if (bindParameter is BindParameterMixin && bindParameter.text != "DEFAULT") {
bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
bindArg.bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}
val index = ++maxIndexSeen
indexesSeen.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}

// If there are still naming conflicts (edge case where the name we generate is the same as
Expand Down
2 changes: 2 additions & 0 deletions sqldelight-gradle-plugin/build.gradle
Expand Up @@ -124,6 +124,8 @@ tasks.named('dockerTest') {
":sqlite-migrations:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-compiler:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-gradle-plugin:publishAllPublicationsToInstallLocallyRepository",
":drivers:r2dbc-driver:publishAllPublicationsToInstallLocallyRepository",
":extensions:async-extensions:publishAllPublicationsToInstallLocallyRepository",
)
}

Expand Down
Expand Up @@ -16,6 +16,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlSchemaDefinitions() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-schema"))
Expand All @@ -34,6 +43,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsPostgreSqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-postgresql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun `dialect accepts version catalog dependency`() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-catalog"))
Expand Down
Expand Up @@ -8,7 +8,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog (name, breed, is_good, id)
VALUES (?, ?, ?, ?);
VALUES (?, ?, DEFAULT, ?);

selectDogs:
SELECT *
Expand Down
Expand Up @@ -29,7 +29,7 @@ class HsqlTest {
}

@Test fun simpleSelect() {
database.dogQueries.insertDog("Tilda", "Pomeranian", true, 1)
database.dogQueries.insertDog("Tilda", "Pomeranian", 1)
assertThat(database.dogQueries.selectDogs().executeAsOne())
.isEqualTo(
Dog(
Expand Down
Expand Up @@ -7,7 +7,7 @@ apply plugin: 'app.cash.sqldelight'

sqldelight {
MyDatabase {
packageName = "app.cash.sqldelight.mysql.integration"
packageName = "app.cash.sqldelight.mysql.integration.async"
dialect("app.cash.sqldelight:mysql-dialect:${app.cash.sqldelight.VersionKt.VERSION}")
generateAsync = true
}
Expand All @@ -26,7 +26,7 @@ dependencies {
implementation "org.testcontainers:r2dbc:1.16.2"
implementation "dev.miku:r2dbc-mysql:0.8.2.RELEASE"
implementation "app.cash.sqldelight:r2dbc-driver:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:coroutines-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:async-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation libs.truth
implementation libs.kotlin.coroutines.core
implementation libs.kotlin.coroutines.test
Expand Down
@@ -1,3 +1,3 @@
apply from: "../settings.gradle"

rootProject.name = 'sqldelight-mysql-integration'
rootProject.name = 'sqldelight-mysql-integration-async'
Expand Up @@ -6,7 +6,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog
VALUES (?, ?, ?);
VALUES (?, ?, DEFAULT);

selectDogs:
SELECT *
Expand Down
@@ -1,5 +1,8 @@
package app.cash.sqldelight.mysql.integration
package app.cash.sqldelight.mysql.integration.async

import app.cash.sqldelight.async.coroutines.awaitAsList
import app.cash.sqldelight.async.coroutines.awaitAsOne
import app.cash.sqldelight.async.coroutines.awaitCreate
import app.cash.sqldelight.driver.r2dbc.R2dbcDriver
import com.google.common.truth.Truth.assertThat
import io.r2dbc.spi.ConnectionFactories
Expand All @@ -13,13 +16,13 @@ class MySqlTest {
val connection = factory.create().awaitSingle()
val driver = R2dbcDriver(connection)

val db = MyDatabase(driver).also { MyDatabase.Schema.create(driver) }
val db = MyDatabase(driver).also { MyDatabase.Schema.awaitCreate(driver) }
block(db)
}

@Test fun simpleSelect() = runTest { database ->
database.dogQueries.insertDog("Tilda", "Pomeranian", true)
assertThat(database.dogQueries.selectDogs().executeAsOne())
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.selectDogs().awaitAsOne())
.isEqualTo(
Dog(
name = "Tilda",
Expand All @@ -32,15 +35,15 @@ class MySqlTest {
@Test
fun simpleSelectWithIn() = runTest { database ->
with(database) {
dogQueries.insertDog("Tilda", "Pomeranian", true)
dogQueries.insertDog("Tucker", "Portuguese Water Dog", true)
dogQueries.insertDog("Cujo", "Pomeranian", false)
dogQueries.insertDog("Buddy", "Pomeranian", true)
dogQueries.insertDog("Tilda", "Pomeranian")
dogQueries.insertDog("Tucker", "Portuguese Water Dog")
dogQueries.insertDog("Cujo", "Pomeranian")
dogQueries.insertDog("Buddy", "Pomeranian")
assertThat(
dogQueries.selectDogsByBreedAndNames(
breed = "Pomeranian",
name = listOf("Tilda", "Buddy"),
).executeAsList(),
).awaitAsList(),
)
.containsExactly(
Dog(
Expand Down