Skip to content

Commit

Permalink
Infer the Kotlin type of bind parameter, if possible, or fail with a …
Browse files Browse the repository at this point in the history
…better error message (#3413)

Co-authored-by: hfhbd <hfhbd@users.noreply.github.com>
  • Loading branch information
hfhbd and hfhbd committed Aug 6, 2022
1 parent 500d5c1 commit cc708e4
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 2 deletions.
Expand Up @@ -57,7 +57,11 @@ fun TypeResolver.encapsulatingType(
val types = exprList.map { resolvedType(it) }
val sqlTypes = types.map { it.dialectType }

val type = typeOrder.last { it in sqlTypes }
val type = typeOrder.lastOrNull { it in sqlTypes }
?: if (PrimitiveType.ARGUMENT in sqlTypes && typeOrder.size == 1) {
return IntermediateType(typeOrder.single())
} else error("The Kotlin type of the argument cannot be inferred, use CAST instead.")

if (!nullableIfAny && types.all { it.javaType.isNullable } ||
nullableIfAny && types.any { it.javaType.isNullable }
) {
Expand Down
Expand Up @@ -14,6 +14,7 @@ import com.squareup.kotlinpoet.asClassName
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import kotlin.test.assertFailsWith

class BindArgsTest {
@get:Rule val tempFolder = TemporaryFolder()
Expand Down Expand Up @@ -284,6 +285,28 @@ class BindArgsTest {
}
}

@Test fun `bind arg kotlin type cannot be inferred with ambiguous sql parameter types`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE dummy(
| foo INTEGER
|);
|
|maxSupportsManySqlTypes:
|SELECT 1
|FROM dummy
|WHERE MAX(:input) > 1;
""".trimMargin(),
tempFolder,
)

val errorMessage = assertFailsWith<IllegalStateException> {
file.findChildrenOfType<SqlBindExpr>().single().argumentType()
}
assertThat(errorMessage.message)
.isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}

@Test fun `bind args use proper binary operator precedence`() {
val file = FixtureCompiler.parseSql(
"""
Expand Down
Expand Up @@ -4,3 +4,6 @@ id INTEGER

selectFooWithId:
SELECT foo(id) FROM foo;

inferredType:
SELECT 1 FROM foo WHERE foo(:inferred) NOT NULL;
@@ -0,0 +1,31 @@
import app.cash.sqldelight.Query
import app.cash.sqldelight.driver.jdbc.JdbcDriver
import org.junit.Test
import schema.FooQueries
import java.sql.Connection
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

class Testing {
@Test fun inferredCompiles() {
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}
FooQueries(fakeDriver).inferredType(1.seconds)
}

@Test fun customFunctionReturnsDuration() {
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}
val unused: Duration = FooQueries(fakeDriver).selectFooWithId().executeAsOne()
}
}
Expand Up @@ -53,7 +53,7 @@ class DialectIntegrationTests {
@Test fun customFunctionDialect() {
val runner = GradleRunner.create()
.withCommonConfiguration(File("src/test/custom-dialect"))
.withArguments("clean", "assemble", "--stacktrace")
.withArguments("clean", "compileTestKotlin", "--stacktrace")

val result = runner.build()
Truth.assertThat(result.output).contains("BUILD SUCCESSFUL")
Expand Down

0 comments on commit cc708e4

Please sign in to comment.