Skip to content

Commit

Permalink
Defaults (#3559)
Browse files Browse the repository at this point in the history
* Fix DEFAULT in binding

Add test for PostgreSQL

Spotless

Move mixin to sqldelight

Use mixin instead

Don't add DEFAULT bindings to query

Fix DEFAULT in binding

* Remove isDefault api

* Fix grammar

* Add mysql async test

* Add postgresql async test

* Fix wildards

* Fix code style

* Add default to mysql too

* Downgrade r2dbc spi due binary incompatible with drivers

Fix docker test dependencies

* Fix DEFAULT in binding

Add test for PostgreSQL

Spotless

Move mixin to sqldelight

Use mixin instead

Don't add DEFAULT bindings to query

Fix DEFAULT in binding

* Remove isDefault api

* Fix grammar

* Add mysql async test

* Add postgresql async test

* Fix wildards

* Fix code style

* Add default to mysql too

* Downgrade r2dbc spi due binary incompatible with drivers

Fix docker test dependencies

* Remove useless check for default

Co-authored-by: hfhbd <hfhbd@users.noreply.github.com>
  • Loading branch information
hfhbd and hfhbd committed Oct 3, 2022
1 parent 0dfca55 commit 328daf8
Show file tree
Hide file tree
Showing 30 changed files with 418 additions and 70 deletions.
Expand Up @@ -79,7 +79,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -111,7 +111,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -110,15 +110,16 @@ type_name ::= (
implements = "com.alecstrong.sql.psi.core.psi.SqlTypeName"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
}

identity_clause ::= 'IDENTITY'

generated_clause ::= GENERATED ( (ALWAYS AS <<expr '-1'>> 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
generated_clause ::= GENERATED ( (ALWAYS AS LP <<expr '-1'>> RP 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlGeneratedClauseImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlGeneratedClause"
override = true
Expand Down
@@ -0,0 +1,11 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : BindParameterMixin(node) {
override fun replaceWith(isAsync: Boolean, index: Int): String = when {
isAsync -> "$$index"
else -> "?"
}
}
Expand Up @@ -15,6 +15,7 @@ class PostgreSqlFixturesTest(name: String, fixtureRoot: File) : FixturesTest(nam
"?1" to "?",
"?2" to "?",
"BLOB" to "TEXT",
"id TEXT GENERATED ALWAYS AS (2) UNIQUE NOT NULL" to "id TEXT GENERATED ALWAYS AS (2) STORED UNIQUE NOT NULL",
)

override fun setupDialect() {
Expand Down
Expand Up @@ -27,6 +27,7 @@ int_data_type ::= 'INTEGER'
real_data_type ::= 'REAL'

bind_parameter ::= ( '?' [digit] | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -48,7 +48,7 @@ class R2dbcDriver(private val connection: Connection) : SqlDriver {

return QueryResult.AsyncValue {
val result = prepared.execute().awaitSingle()
return@AsyncValue result.rowsUpdated.awaitFirstOrNull() ?: 0
return@AsyncValue result.rowsUpdated.awaitFirstOrNull()?.toLong() ?: 0
}
}

Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Expand Up @@ -76,7 +76,7 @@ testhelp = { module = "co.touchlab:testhelp", version.ref = "testhelp" }
burst = { module = "com.squareup.burst:burst-junit4", version = "1.2.0" }
testParameterInjector = { module = "com.google.testparameterinjector:test-parameter-injector", version = "1.8" }

r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "1.0.0.RELEASE" }
r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "0.9.1.RELEASE" }

[plugins]
android-library = { id = "com.android.library", version.ref = "agp" }
Expand Down
@@ -0,0 +1,15 @@
package app.cash.sqldelight.dialect.grammar.mixins

import com.alecstrong.sql.psi.core.psi.SqlBindParameter
import com.alecstrong.sql.psi.core.psi.SqlCompositeElementImpl
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : SqlCompositeElementImpl(node), SqlBindParameter {
/**
* Overwrite, if the user provided sql parameter should be overwritten by sqldelight with [replaceWith].
*
* Some sql dialects support other bind parameter besides `?`, but sqldelight should still replace the
* user provided parameter with [replaceWith] for a homogen generated code.
*/
open fun replaceWith(isAsync: Boolean, index: Int): String = "?"
}
Expand Up @@ -19,6 +19,7 @@ import app.cash.sqldelight.core.lang.util.rawSqlText
import app.cash.sqldelight.core.lang.util.sqFile
import app.cash.sqldelight.core.psi.SqlDelightStmtClojureStmtList
import app.cash.sqldelight.dialect.api.IntermediateType
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlBinaryEqualityExpr
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
Expand Down Expand Up @@ -174,36 +175,39 @@ abstract class QueryGenerator(
precedingArrays.add(type.name)
argumentCounts.add("${type.name}.size")
} else {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
val bindParameter = bindArg?.bindParameter as? BindParameterMixin
if (bindParameter == null || bindParameter.text != "DEFAULT") {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}
}

// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))
// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))

// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindArg != null) {
replacements.add(bindArg.range to "?")
// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindParameter != null) {
replacements.add(bindArg.range to bindParameter.replaceWith(generateAsync, index = nonArrayBindArgsCount))
}
}
}
}
Expand Down Expand Up @@ -293,7 +297,7 @@ abstract class QueryGenerator(
"""
if (result%L == 0L) throw %T(%S)
""".trimIndent(),
if (generateAsync) ".await()" else ".value",
if (generateAsync) "" else ".value",
ClassName("app.cash.sqldelight.db", "OptimisticLockException"),
"UPDATE on ${query.tablesAffected.single().name} failed because optimistic lock ${optimisticLock.name} did not match",
)
Expand Down
Expand Up @@ -32,6 +32,7 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.ARGUMENT
import app.cash.sqldelight.dialect.api.PrimitiveType.BOOLEAN
import app.cash.sqldelight.dialect.api.PrimitiveType.INTEGER
import app.cash.sqldelight.dialect.api.PrimitiveType.NULL
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlBindParameter
Expand Down Expand Up @@ -92,29 +93,32 @@ abstract class BindableQuery(
val namesSeen = mutableSetOf<String>()
var maxIndexSeen = 0
statement.findChildrenOfType<SqlBindExpr>().forEach { bindArg ->
bindArg.bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
val bindParameter = bindArg.bindParameter
if (bindParameter is BindParameterMixin && bindParameter.text != "DEFAULT") {
bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
bindArg.bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}
val index = ++maxIndexSeen
indexesSeen.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}

// If there are still naming conflicts (edge case where the name we generate is the same as
Expand Down
2 changes: 2 additions & 0 deletions sqldelight-gradle-plugin/build.gradle
Expand Up @@ -124,6 +124,8 @@ tasks.named('dockerTest') {
":sqlite-migrations:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-compiler:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-gradle-plugin:publishAllPublicationsToInstallLocallyRepository",
":drivers:r2dbc-driver:publishAllPublicationsToInstallLocallyRepository",
":extensions:async-extensions:publishAllPublicationsToInstallLocallyRepository",
)
}

Expand Down
Expand Up @@ -16,6 +16,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlSchemaDefinitions() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-schema"))
Expand All @@ -34,6 +43,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsPostgreSqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-postgresql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun `dialect accepts version catalog dependency`() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-catalog"))
Expand Down
Expand Up @@ -8,7 +8,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog (name, breed, is_good, id)
VALUES (?, ?, ?, ?);
VALUES (?, ?, DEFAULT, ?);

selectDogs:
SELECT *
Expand Down
Expand Up @@ -29,7 +29,7 @@ class HsqlTest {
}

@Test fun simpleSelect() {
database.dogQueries.insertDog("Tilda", "Pomeranian", true, 1)
database.dogQueries.insertDog("Tilda", "Pomeranian", 1)
assertThat(database.dogQueries.selectDogs().executeAsOne())
.isEqualTo(
Dog(
Expand Down
Expand Up @@ -7,7 +7,7 @@ apply plugin: 'app.cash.sqldelight'

sqldelight {
MyDatabase {
packageName = "app.cash.sqldelight.mysql.integration"
packageName = "app.cash.sqldelight.mysql.integration.async"
dialect("app.cash.sqldelight:mysql-dialect:${app.cash.sqldelight.VersionKt.VERSION}")
generateAsync = true
}
Expand All @@ -26,7 +26,7 @@ dependencies {
implementation "org.testcontainers:r2dbc:1.16.2"
implementation "dev.miku:r2dbc-mysql:0.8.2.RELEASE"
implementation "app.cash.sqldelight:r2dbc-driver:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:coroutines-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:async-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation libs.truth
implementation libs.kotlin.coroutines.core
implementation libs.kotlin.coroutines.test
Expand Down
@@ -1,3 +1,3 @@
apply from: "../settings.gradle"

rootProject.name = 'sqldelight-mysql-integration'
rootProject.name = 'sqldelight-mysql-integration-async'
Expand Up @@ -6,7 +6,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog
VALUES (?, ?, ?);
VALUES (?, ?, DEFAULT);

selectDogs:
SELECT *
Expand Down
@@ -1,5 +1,8 @@
package app.cash.sqldelight.mysql.integration
package app.cash.sqldelight.mysql.integration.async

import app.cash.sqldelight.async.coroutines.awaitAsList
import app.cash.sqldelight.async.coroutines.awaitAsOne
import app.cash.sqldelight.async.coroutines.awaitCreate
import app.cash.sqldelight.driver.r2dbc.R2dbcDriver
import com.google.common.truth.Truth.assertThat
import io.r2dbc.spi.ConnectionFactories
Expand All @@ -13,13 +16,13 @@ class MySqlTest {
val connection = factory.create().awaitSingle()
val driver = R2dbcDriver(connection)

val db = MyDatabase(driver).also { MyDatabase.Schema.create(driver) }
val db = MyDatabase(driver).also { MyDatabase.Schema.awaitCreate(driver) }
block(db)
}

@Test fun simpleSelect() = runTest { database ->
database.dogQueries.insertDog("Tilda", "Pomeranian", true)
assertThat(database.dogQueries.selectDogs().executeAsOne())
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.selectDogs().awaitAsOne())
.isEqualTo(
Dog(
name = "Tilda",
Expand All @@ -32,15 +35,15 @@ class MySqlTest {
@Test
fun simpleSelectWithIn() = runTest { database ->
with(database) {
dogQueries.insertDog("Tilda", "Pomeranian", true)
dogQueries.insertDog("Tucker", "Portuguese Water Dog", true)
dogQueries.insertDog("Cujo", "Pomeranian", false)
dogQueries.insertDog("Buddy", "Pomeranian", true)
dogQueries.insertDog("Tilda", "Pomeranian")
dogQueries.insertDog("Tucker", "Portuguese Water Dog")
dogQueries.insertDog("Cujo", "Pomeranian")
dogQueries.insertDog("Buddy", "Pomeranian")
assertThat(
dogQueries.selectDogsByBreedAndNames(
breed = "Pomeranian",
name = listOf("Tilda", "Buddy"),
).executeAsList(),
).awaitAsList(),
)
.containsExactly(
Dog(
Expand Down

0 comments on commit 328daf8

Please sign in to comment.