Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inferring kotlin type from other parameters #3431

Merged
merged 1 commit into from Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -57,10 +57,18 @@ fun TypeResolver.encapsulatingType(
val types = exprList.map { resolvedType(it) }
val sqlTypes = types.map { it.dialectType }

val type = typeOrder.lastOrNull { it in sqlTypes }
?: if (PrimitiveType.ARGUMENT in sqlTypes && typeOrder.size == 1) {
if (PrimitiveType.ARGUMENT in sqlTypes) {
if (typeOrder.size == 1) {
return IntermediateType(typeOrder.single())
} else error("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}
val otherFunctionParameters = sqlTypes.distinct() - PrimitiveType.ARGUMENT
if (otherFunctionParameters.size == 1) {
return IntermediateType(otherFunctionParameters.single())
}
error("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}

val type = typeOrder.last { it in sqlTypes }

if (!nullableIfAny && types.all { it.javaType.isNullable } ||
nullableIfAny && types.any { it.javaType.isNullable }
Expand Down
Expand Up @@ -307,6 +307,49 @@ class BindArgsTest {
.isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}

@Test fun `bind arg kotlin type can be inferred with other types`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE dummy(
| foo INTEGER
|);
|
|inferredNullableLong:
|SELECT 1
|FROM dummy
|WHERE MAX(1, :input) > 1;
""".trimMargin(),
tempFolder,
)

file.findChildrenOfType<SqlBindExpr>().map { it.argumentType() }.let { args ->
assertThat(args[0].dialectType).isEqualTo(PrimitiveType.INTEGER)
assertThat(args[0].javaType).isEqualTo(Long::class.asClassName().copy(nullable = true))
}
}

@Test fun `bind arg kotlin type cannot be inferred with other different types`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE dummy(
| foo INTEGER
|);
|
|differentSqlTypes:
|SELECT 1
|FROM dummy
|WHERE MAX(1, 'FOO', :input) > 1;
""".trimMargin(),
tempFolder,
)

val errorMessage = assertFailsWith<IllegalStateException> {
file.findChildrenOfType<SqlBindExpr>().single().argumentType()
}
assertThat(errorMessage.message)
.isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}

@Test fun `bind args use proper binary operator precedence`() {
val file = FixtureCompiler.parseSql(
"""
Expand Down
Expand Up @@ -7,3 +7,6 @@ SELECT foo(id) FROM foo;

inferredType:
SELECT 1 FROM foo WHERE foo(:inferred) NOT NULL;

inferredTypeFromMax:
SELECT 1 FROM foo WHERE max(1, :inferred);
Expand Up @@ -7,25 +7,23 @@ import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

class Testing {
@Test fun inferredCompiles() {
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}
FooQueries(fakeDriver).inferredType(1.seconds)
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}

@Test fun customFunctionReturnsDuration() {
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}
val unused: Duration = FooQueries(fakeDriver).selectFooWithId().executeAsOne()
}

@Test fun inferredCompiles() {
FooQueries(fakeDriver).inferredType(1.seconds)
}

@Test fun inferredTypeFromMaxIsLong() {
FooQueries(fakeDriver).inferredTypeFromMax(1L).executeAsOne()
}
}