Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defaults #3559

Merged
merged 24 commits into from Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -79,7 +79,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -111,7 +111,8 @@ column_constraint ::= [ CONSTRAINT {identifier} ] (
implements = "com.alecstrong.sql.psi.core.psi.SqlColumnConstraint"
override = true
}
bind_parameter ::= ( '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -110,15 +110,16 @@ type_name ::= (
implements = "com.alecstrong.sql.psi.core.psi.SqlTypeName"
override = true
}
bind_parameter ::= ( DEFAULT | '?' | ':' {identifier} ) {
bind_parameter ::= DEFAULT | ( '?' | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
}

identity_clause ::= 'IDENTITY'

generated_clause ::= GENERATED ( (ALWAYS AS <<expr '-1'>> 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
generated_clause ::= GENERATED ( (ALWAYS AS LP <<expr '-1'>> RP 'STORED') | ( (ALWAYS | BY DEFAULT) AS identity_clause ) ) {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlGeneratedClauseImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlGeneratedClause"
override = true
Expand Down
@@ -0,0 +1,11 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : BindParameterMixin(node) {
override fun replaceWith(isAsync: Boolean, index: Int): String = when {
isAsync -> "$$index"
else -> "?"
}
}
Expand Up @@ -15,6 +15,7 @@ class PostgreSqlFixturesTest(name: String, fixtureRoot: File) : FixturesTest(nam
"?1" to "?",
"?2" to "?",
"BLOB" to "TEXT",
"id TEXT GENERATED ALWAYS AS (2) UNIQUE NOT NULL" to "id TEXT GENERATED ALWAYS AS (2) STORED UNIQUE NOT NULL",
)

override fun setupDialect() {
Expand Down
Expand Up @@ -27,6 +27,7 @@ int_data_type ::= 'INTEGER'
real_data_type ::= 'REAL'

bind_parameter ::= ( '?' [digit] | ':' {identifier} ) {
mixin = "app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin"
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlBindParameterImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlBindParameter"
override = true
Expand Down
Expand Up @@ -48,7 +48,7 @@ class R2dbcDriver(private val connection: Connection) : SqlDriver {

return QueryResult.AsyncValue {
val result = prepared.execute().awaitSingle()
return@AsyncValue result.rowsUpdated.awaitFirstOrNull() ?: 0
return@AsyncValue result.rowsUpdated.awaitFirstOrNull()?.toLong() ?: 0
}
}

Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Expand Up @@ -76,7 +76,7 @@ testhelp = { module = "co.touchlab:testhelp", version.ref = "testhelp" }
burst = { module = "com.squareup.burst:burst-junit4", version = "1.2.0" }
testParameterInjector = { module = "com.google.testparameterinjector:test-parameter-injector", version = "1.8" }

r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "1.0.0.RELEASE" }
r2dbc = { module = "io.r2dbc:r2dbc-spi", version = "0.9.1.RELEASE" }

[plugins]
android-library = { id = "com.android.library", version.ref = "agp" }
Expand Down
@@ -0,0 +1,15 @@
package app.cash.sqldelight.dialect.grammar.mixins

import com.alecstrong.sql.psi.core.psi.SqlBindParameter
import com.alecstrong.sql.psi.core.psi.SqlCompositeElementImpl
import com.intellij.lang.ASTNode

abstract class BindParameterMixin(node: ASTNode) : SqlCompositeElementImpl(node), SqlBindParameter {
/**
* Overwrite, if the user provided sql parameter should be overwritten by sqldelight with [replaceWith].
*
* Some sql dialects support other bind parameter besides `?`, but sqldelight should still replace the
* user provided parameter with [replaceWith] for a homogen generated code.
*/
open fun replaceWith(isAsync: Boolean, index: Int): String = "?"
}
Expand Up @@ -19,6 +19,7 @@ import app.cash.sqldelight.core.lang.util.rawSqlText
import app.cash.sqldelight.core.lang.util.sqFile
import app.cash.sqldelight.core.psi.SqlDelightStmtClojureStmtList
import app.cash.sqldelight.dialect.api.IntermediateType
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlBinaryEqualityExpr
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
Expand Down Expand Up @@ -174,36 +175,39 @@ abstract class QueryGenerator(
precedingArrays.add(type.name)
argumentCounts.add("${type.name}.size")
} else {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
val bindParameter = bindArg?.bindParameter as? BindParameterMixin
if (bindParameter == null || bindParameter.text != "DEFAULT") {
nonArrayBindArgsCount += 1

if (!treatNullAsUnknownForEquality && type.javaType.isNullable) {
val parent = bindArg?.parent
if (parent is SqlBinaryEqualityExpr) {
needsFreshStatement = true

var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
val nullableEquality: String
if (symbol != null) {
nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}"
} else {
symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!!
nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}"
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}

val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
replacements.add(symbol.range to "\${ $block }")
}
}

// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))
// Binds each parameter to the statement:
// statement.bindLong(0, id)
bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))

// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindArg != null) {
replacements.add(bindArg.range to "?")
// Replace the named argument with a non named/indexed argument.
// This allows us to use the same algorithm for non Sqlite dialects
// :name becomes ?
if (bindParameter != null) {
replacements.add(bindArg.range to bindParameter.replaceWith(generateAsync, index = nonArrayBindArgsCount))
}
}
}
}
Expand Down Expand Up @@ -293,7 +297,7 @@ abstract class QueryGenerator(
"""
if (result%L == 0L) throw %T(%S)
""".trimIndent(),
if (generateAsync) ".await()" else ".value",
if (generateAsync) "" else ".value",
ClassName("app.cash.sqldelight.db", "OptimisticLockException"),
"UPDATE on ${query.tablesAffected.single().name} failed because optimistic lock ${optimisticLock.name} did not match",
)
Expand Down
Expand Up @@ -32,6 +32,7 @@ import app.cash.sqldelight.dialect.api.PrimitiveType.ARGUMENT
import app.cash.sqldelight.dialect.api.PrimitiveType.BOOLEAN
import app.cash.sqldelight.dialect.api.PrimitiveType.INTEGER
import app.cash.sqldelight.dialect.api.PrimitiveType.NULL
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlAnnotatedElement
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlBindParameter
Expand Down Expand Up @@ -92,29 +93,32 @@ abstract class BindableQuery(
val namesSeen = mutableSetOf<String>()
var maxIndexSeen = 0
statement.findChildrenOfType<SqlBindExpr>().forEach { bindArg ->
bindArg.bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
val bindParameter = bindArg.bindParameter
if (bindParameter is BindParameterMixin && bindParameter.text != "DEFAULT") {
bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index ->
if (!indexesSeen.add(index)) {
result.findAndReplace(bindArg, index) { it.index == index }
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
maxIndexSeen = maxOf(maxIndexSeen, index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
return@forEach
}
bindArg.bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
bindParameter.identifier?.let {
if (!namesSeen.add(it.text)) {
result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text }
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
}
val index = ++maxIndexSeen
indexesSeen.add(index)
manuallyNamedIndexes.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg)))
return@forEach
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}
val index = ++maxIndexSeen
indexesSeen.add(index)
result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg)))
}

// If there are still naming conflicts (edge case where the name we generate is the same as
Expand Down
2 changes: 2 additions & 0 deletions sqldelight-gradle-plugin/build.gradle
Expand Up @@ -124,6 +124,8 @@ tasks.named('dockerTest') {
":sqlite-migrations:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-compiler:publishAllPublicationsToInstallLocallyRepository",
":sqldelight-gradle-plugin:publishAllPublicationsToInstallLocallyRepository",
":drivers:r2dbc-driver:publishAllPublicationsToInstallLocallyRepository",
":extensions:async-extensions:publishAllPublicationsToInstallLocallyRepository",
)
}

Expand Down
Expand Up @@ -16,6 +16,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsMySqlSchemaDefinitions() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-mysql-schema"))
Expand All @@ -34,6 +43,15 @@ class DialectIntegrationTests {
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun integrationTestsPostgreSqlAsync() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-postgresql-async"))
.withArguments("clean", "check", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
}

@Test fun `dialect accepts version catalog dependency`() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/integration-catalog"))
Expand Down
Expand Up @@ -8,7 +8,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog (name, breed, is_good, id)
VALUES (?, ?, ?, ?);
VALUES (?, ?, DEFAULT, ?);

selectDogs:
SELECT *
Expand Down
Expand Up @@ -29,7 +29,7 @@ class HsqlTest {
}

@Test fun simpleSelect() {
database.dogQueries.insertDog("Tilda", "Pomeranian", true, 1)
database.dogQueries.insertDog("Tilda", "Pomeranian", 1)
assertThat(database.dogQueries.selectDogs().executeAsOne())
.isEqualTo(
Dog(
Expand Down
Expand Up @@ -7,7 +7,7 @@ apply plugin: 'app.cash.sqldelight'

sqldelight {
MyDatabase {
packageName = "app.cash.sqldelight.mysql.integration"
packageName = "app.cash.sqldelight.mysql.integration.async"
dialect("app.cash.sqldelight:mysql-dialect:${app.cash.sqldelight.VersionKt.VERSION}")
generateAsync = true
}
Expand All @@ -26,7 +26,7 @@ dependencies {
implementation "org.testcontainers:r2dbc:1.16.2"
implementation "dev.miku:r2dbc-mysql:0.8.2.RELEASE"
implementation "app.cash.sqldelight:r2dbc-driver:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:coroutines-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation "app.cash.sqldelight:async-extensions:${app.cash.sqldelight.VersionKt.VERSION}"
implementation libs.truth
implementation libs.kotlin.coroutines.core
implementation libs.kotlin.coroutines.test
Expand Down
@@ -1,3 +1,3 @@
apply from: "../settings.gradle"

rootProject.name = 'sqldelight-mysql-integration'
rootProject.name = 'sqldelight-mysql-integration-async'
Expand Up @@ -6,7 +6,7 @@ CREATE TABLE dog (

insertDog:
INSERT INTO dog
VALUES (?, ?, ?);
VALUES (?, ?, DEFAULT);

selectDogs:
SELECT *
Expand Down
@@ -1,5 +1,8 @@
package app.cash.sqldelight.mysql.integration
package app.cash.sqldelight.mysql.integration.async

import app.cash.sqldelight.async.coroutines.awaitAsList
import app.cash.sqldelight.async.coroutines.awaitAsOne
import app.cash.sqldelight.async.coroutines.awaitCreate
import app.cash.sqldelight.driver.r2dbc.R2dbcDriver
import com.google.common.truth.Truth.assertThat
import io.r2dbc.spi.ConnectionFactories
Expand All @@ -13,13 +16,13 @@ class MySqlTest {
val connection = factory.create().awaitSingle()
val driver = R2dbcDriver(connection)

val db = MyDatabase(driver).also { MyDatabase.Schema.create(driver) }
val db = MyDatabase(driver).also { MyDatabase.Schema.awaitCreate(driver) }
block(db)
}

@Test fun simpleSelect() = runTest { database ->
database.dogQueries.insertDog("Tilda", "Pomeranian", true)
assertThat(database.dogQueries.selectDogs().executeAsOne())
database.dogQueries.insertDog("Tilda", "Pomeranian")
assertThat(database.dogQueries.selectDogs().awaitAsOne())
.isEqualTo(
Dog(
name = "Tilda",
Expand All @@ -32,15 +35,15 @@ class MySqlTest {
@Test
fun simpleSelectWithIn() = runTest { database ->
with(database) {
dogQueries.insertDog("Tilda", "Pomeranian", true)
dogQueries.insertDog("Tucker", "Portuguese Water Dog", true)
dogQueries.insertDog("Cujo", "Pomeranian", false)
dogQueries.insertDog("Buddy", "Pomeranian", true)
dogQueries.insertDog("Tilda", "Pomeranian")
dogQueries.insertDog("Tucker", "Portuguese Water Dog")
dogQueries.insertDog("Cujo", "Pomeranian")
dogQueries.insertDog("Buddy", "Pomeranian")
assertThat(
dogQueries.selectDogsByBreedAndNames(
breed = "Pomeranian",
name = listOf("Tilda", "Buddy"),
).executeAsList(),
).awaitAsList(),
)
.containsExactly(
Dog(
Expand Down