diff --git a/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt b/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt index 4fbc778c46f..1d99ce9a9f1 100644 --- a/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt +++ b/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/TypeResolver.kt @@ -57,10 +57,18 @@ fun TypeResolver.encapsulatingType( val types = exprList.map { resolvedType(it) } val sqlTypes = types.map { it.dialectType } - val type = typeOrder.lastOrNull { it in sqlTypes } - ?: if (PrimitiveType.ARGUMENT in sqlTypes && typeOrder.size == 1) { + if (PrimitiveType.ARGUMENT in sqlTypes) { + if (typeOrder.size == 1) { return IntermediateType(typeOrder.single()) - } else error("The Kotlin type of the argument cannot be inferred, use CAST instead.") + } + val otherFunctionParameters = sqlTypes.distinct() - PrimitiveType.ARGUMENT + if (otherFunctionParameters.size == 1) { + return IntermediateType(otherFunctionParameters.single()) + } + error("The Kotlin type of the argument cannot be inferred, use CAST instead.") + } + + val type = typeOrder.last { it in sqlTypes } if (!nullableIfAny && types.all { it.javaType.isNullable } || nullableIfAny && types.any { it.javaType.isNullable } diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt index b9d9296e066..34977a2363b 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/BindArgsTest.kt @@ -307,6 +307,49 @@ class BindArgsTest { .isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.") } + @Test fun `bind arg kotlin type can be inferred with other types`() { + val file = FixtureCompiler.parseSql( + """ + |CREATE TABLE dummy( + | foo INTEGER + |); + | + |inferredNullableLong: + |SELECT 1 + |FROM dummy + |WHERE MAX(1, :input) > 1; + """.trimMargin(), + tempFolder, + ) + + file.findChildrenOfType().map { it.argumentType() }.let { args -> + assertThat(args[0].dialectType).isEqualTo(PrimitiveType.INTEGER) + assertThat(args[0].javaType).isEqualTo(Long::class.asClassName().copy(nullable = true)) + } + } + + @Test fun `bind arg kotlin type cannot be inferred with other different types`() { + val file = FixtureCompiler.parseSql( + """ + |CREATE TABLE dummy( + | foo INTEGER + |); + | + |differentSqlTypes: + |SELECT 1 + |FROM dummy + |WHERE MAX(1, 'FOO', :input) > 1; + """.trimMargin(), + tempFolder, + ) + + val errorMessage = assertFailsWith { + file.findChildrenOfType().single().argumentType() + } + assertThat(errorMessage.message) + .isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.") + } + @Test fun `bind args use proper binary operator precedence`() { val file = FixtureCompiler.parseSql( """ diff --git a/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq b/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq index bbe3af719ba..8d42d71a8a0 100644 --- a/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq +++ b/sqldelight-gradle-plugin/src/test/custom-dialect/src/main/sqldelight/schema/Foo.sq @@ -7,3 +7,6 @@ SELECT foo(id) FROM foo; inferredType: SELECT 1 FROM foo WHERE foo(:inferred) NOT NULL; + +inferredTypeFromMax: +SELECT 1 FROM foo WHERE max(1, :inferred); diff --git a/sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt b/sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt index caa1f55cd1e..e373264e9e7 100644 --- a/sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt +++ b/sqldelight-gradle-plugin/src/test/custom-dialect/src/test/kotlin/Testing.kt @@ -7,25 +7,23 @@ import kotlin.time.Duration import kotlin.time.Duration.Companion.seconds class Testing { - @Test fun inferredCompiles() { - val fakeDriver = object : JdbcDriver() { - override fun getConnection() = TODO() - override fun closeConnection(connection: Connection) = Unit - override fun addListener(listener: Query.Listener, queryKeys: Array) = Unit - override fun removeListener(listener: Query.Listener, queryKeys: Array) = Unit - override fun notifyListeners(queryKeys: Array) = Unit - } - FooQueries(fakeDriver).inferredType(1.seconds) + val fakeDriver = object : JdbcDriver() { + override fun getConnection() = TODO() + override fun closeConnection(connection: Connection) = Unit + override fun addListener(listener: Query.Listener, queryKeys: Array) = Unit + override fun removeListener(listener: Query.Listener, queryKeys: Array) = Unit + override fun notifyListeners(queryKeys: Array) = Unit } @Test fun customFunctionReturnsDuration() { - val fakeDriver = object : JdbcDriver() { - override fun getConnection() = TODO() - override fun closeConnection(connection: Connection) = Unit - override fun addListener(listener: Query.Listener, queryKeys: Array) = Unit - override fun removeListener(listener: Query.Listener, queryKeys: Array) = Unit - override fun notifyListeners(queryKeys: Array) = Unit - } val unused: Duration = FooQueries(fakeDriver).selectFooWithId().executeAsOne() } + + @Test fun inferredCompiles() { + FooQueries(fakeDriver).inferredType(1.seconds) + } + + @Test fun inferredTypeFromMaxIsLong() { + FooQueries(fakeDriver).inferredTypeFromMax(1L).executeAsOne() + } }