Skip to content

Commit

Permalink
Support inferring kotlin type from other parameters (#3431)
Browse files Browse the repository at this point in the history
Co-authored-by: hfhbd <hfhbd@users.noreply.github.com>
  • Loading branch information
hfhbd and hfhbd committed Aug 11, 2022
1 parent 00a2a77 commit 7c32c0e
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 19 deletions.
Expand Up @@ -57,10 +57,18 @@ fun TypeResolver.encapsulatingType(
val types = exprList.map { resolvedType(it) }
val sqlTypes = types.map { it.dialectType }

val type = typeOrder.lastOrNull { it in sqlTypes }
?: if (PrimitiveType.ARGUMENT in sqlTypes && typeOrder.size == 1) {
if (PrimitiveType.ARGUMENT in sqlTypes) {
if (typeOrder.size == 1) {
return IntermediateType(typeOrder.single())
} else error("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}
val otherFunctionParameters = sqlTypes.distinct() - PrimitiveType.ARGUMENT
if (otherFunctionParameters.size == 1) {
return IntermediateType(otherFunctionParameters.single())
}
error("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}

val type = typeOrder.last { it in sqlTypes }

if (!nullableIfAny && types.all { it.javaType.isNullable } ||
nullableIfAny && types.any { it.javaType.isNullable }
Expand Down
Expand Up @@ -307,6 +307,49 @@ class BindArgsTest {
.isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}

@Test fun `bind arg kotlin type can be inferred with other types`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE dummy(
| foo INTEGER
|);
|
|inferredNullableLong:
|SELECT 1
|FROM dummy
|WHERE MAX(1, :input) > 1;
""".trimMargin(),
tempFolder,
)

file.findChildrenOfType<SqlBindExpr>().map { it.argumentType() }.let { args ->
assertThat(args[0].dialectType).isEqualTo(PrimitiveType.INTEGER)
assertThat(args[0].javaType).isEqualTo(Long::class.asClassName().copy(nullable = true))
}
}

@Test fun `bind arg kotlin type cannot be inferred with other different types`() {
val file = FixtureCompiler.parseSql(
"""
|CREATE TABLE dummy(
| foo INTEGER
|);
|
|differentSqlTypes:
|SELECT 1
|FROM dummy
|WHERE MAX(1, 'FOO', :input) > 1;
""".trimMargin(),
tempFolder,
)

val errorMessage = assertFailsWith<IllegalStateException> {
file.findChildrenOfType<SqlBindExpr>().single().argumentType()
}
assertThat(errorMessage.message)
.isEqualTo("The Kotlin type of the argument cannot be inferred, use CAST instead.")
}

@Test fun `bind args use proper binary operator precedence`() {
val file = FixtureCompiler.parseSql(
"""
Expand Down
Expand Up @@ -7,3 +7,6 @@ SELECT foo(id) FROM foo;

inferredType:
SELECT 1 FROM foo WHERE foo(:inferred) NOT NULL;

inferredTypeFromMax:
SELECT 1 FROM foo WHERE max(1, :inferred);
Expand Up @@ -7,25 +7,23 @@ import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

class Testing {
@Test fun inferredCompiles() {
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}
FooQueries(fakeDriver).inferredType(1.seconds)
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}

@Test fun customFunctionReturnsDuration() {
val fakeDriver = object : JdbcDriver() {
override fun getConnection() = TODO()
override fun closeConnection(connection: Connection) = Unit
override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit
override fun notifyListeners(queryKeys: Array<String>) = Unit
}
val unused: Duration = FooQueries(fakeDriver).selectFooWithId().executeAsOne()
}

@Test fun inferredCompiles() {
FooQueries(fakeDriver).inferredType(1.seconds)
}

@Test fun inferredTypeFromMaxIsLong() {
FooQueries(fakeDriver).inferredTypeFromMax(1L).executeAsOne()
}
}

0 comments on commit 7c32c0e

Please sign in to comment.