From 3e065e92ddae792490bdb827f802995bd22d3a1b Mon Sep 17 00:00:00 2001 From: hfhbd Date: Fri, 5 Aug 2022 22:59:21 +0200 Subject: [PATCH] Infer the Kotlin type of bind parameter, if possible, or fail with a better error message --- .../sqldelight/dialect/api/TypeResolver.kt | 6 +++- .../app/cash/sqldelight/core/BindArgsTest.kt | 23 ++++++++++++++ .../src/main/sqldelight/schema/Foo.sq | 3 ++ .../custom-dialect/src/test/kotlin/Testing.kt | 31 +++++++++++++++++++ .../dialect/DialectIntegrationTests.kt | 2 +- 5 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt diff --git a/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt b/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt index 65906757a59..4fbc778c46f 100644 --- a/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt +++ b/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt @@ -57,7 +57,11 @@ fun TypeResolver.encapsulatingType( val types = exprList.map { resolvedType(it) } val sqlTypes = types.map { it.dialectType } - val type = typeOrder.last { it in sqlTypes } + val type = typeOrder.lastOrNull { it in sqlTypes } + ?: if (PrimitiveType.ARGUMENT in sqlTypes && typeOrder.size == 1) { + return IntermediateType(typeOrder.single()) + } else error("The Kotlin type of the argument cannot be inferred, use CAST instead.") + if (!nullableIfAny && types.all { it.javaType.isNullable } || nullableIfAny && types.any { it.javaType.isNullable } ) { diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt index ea83424366e..b9d9296e066 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt @@ -14,6 +14,7 @@ import com.squareup.kotlinpoet.asClassName import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder +import kotlin.test.assertFailsWith class BindArgsTest { @get:Rule val tempFolder = TemporaryFolder() @@ -284,6 +285,28 @@ class BindArgsTest { } } + @Test fun `bind arg kotlin type cannot be inferred with ambiguous sql parameter types`() { + val file = FixtureCompiler.parseSql( + """ + |CREATE TABLE dummy( + | foo INTEGER + |); + | + |maxSupportsManySqlTypes: + |SELECT 1 + |FROM dummy + |WHERE MAX(:input) > 1; + """.trimMargin(), + tempFolder, + ) + + val errorMessage = assertFailsWith { + file.findChildrenOfType().single().argumentType() + } + assertThat(errorMessage.message) + .isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.") + } + @Test fun `bind args use proper binary operator precedence`() { val file = FixtureCompiler.parseSql( """ diff --git a/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq b/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq index b8dc4f3fee2..bbe3af719ba 100644 --- a/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq +++ b/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq @@ -4,3 +4,6 @@ id INTEGER selectFooWithId: SELECT foo(id) FROM foo; + +inferredType: +SELECT 1 FROM foo WHERE foo(:inferred) NOT NULL; diff --git a/sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt b/sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt new file mode 100644 index 00000000000..caa1f55cd1e --- /dev/null +++ b/sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt @@ -0,0 +1,31 @@ +import app.cash.sqldelight.Query +import app.cash.sqldelight.driver.jdbc.JdbcDriver +import org.junit.Test +import schema.FooQueries +import java.sql.Connection +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +class Testing { + @Test fun inferredCompiles() { + val fakeDriver = object : JdbcDriver() { + override fun getConnection() = TODO() + override fun closeConnection(connection: Connection) = Unit + override fun addListener(listener: Query.Listener, queryKeys: Array) = Unit + override fun removeListener(listener: Query.Listener, queryKeys: Array) = Unit + override fun notifyListeners(queryKeys: Array) = Unit + } + FooQueries(fakeDriver).inferredType(1.seconds) + } + + @Test fun customFunctionReturnsDuration() { + val fakeDriver = object : JdbcDriver() { + override fun getConnection() = TODO() + override fun closeConnection(connection: Connection) = Unit + override fun addListener(listener: Query.Listener, queryKeys: Array) = Unit + override fun removeListener(listener: Query.Listener, queryKeys: Array) = Unit + override fun notifyListeners(queryKeys: Array) = Unit + } + val unused: Duration = FooQueries(fakeDriver).selectFooWithId().executeAsOne() + } +} diff --git a/sqldelight-gradle-plugin/src/test/kotlin/app/cash/sqldelight/dialect/DialectIntegrationTests.kt b/sqldelight-gradle-plugin/src/test/kotlin/app/cash/sqldelight/dialect/DialectIntegrationTests.kt index cd6b0bbcf7d..222e812531f 100644 --- a/sqldelight-gradle-plugin/src/test/kotlin/app/cash/sqldelight/dialect/DialectIntegrationTests.kt +++ b/sqldelight-gradle-plugin/src/test/kotlin/app/cash/sqldelight/dialect/DialectIntegrationTests.kt @@ -53,7 +53,7 @@ class DialectIntegrationTests { @Test fun customFunctionDialect() { val runner = GradleRunner.create() .withCommonConfiguration(File("src/test/custom-dialect")) - .withArguments("clean", "assemble", "--stacktrace") + .withArguments("clean", "compileTestKotlin", "--stacktrace") val result = runner.build() Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")