-
Notifications
You must be signed in to change notification settings - Fork 496
/
QueryGenerator.kt
322 lines (282 loc) · 12.1 KB
/
QueryGenerator.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
package app.cash.sqldelight.core.compiler
import app.cash.sqldelight.core.compiler.integration.javadocText
import app.cash.sqldelight.core.compiler.model.BindableQuery
import app.cash.sqldelight.core.compiler.model.NamedMutator
import app.cash.sqldelight.core.compiler.model.NamedQuery
import app.cash.sqldelight.core.lang.ASYNC_RESULT_TYPE
import app.cash.sqldelight.core.lang.DRIVER_NAME
import app.cash.sqldelight.core.lang.MAPPER_NAME
import app.cash.sqldelight.core.lang.PREPARED_STATEMENT_TYPE
import app.cash.sqldelight.core.lang.encodedJavaType
import app.cash.sqldelight.core.lang.preparedStatementBinder
import app.cash.sqldelight.core.lang.util.childOfType
import app.cash.sqldelight.core.lang.util.columnDefSource
import app.cash.sqldelight.core.lang.util.findChildrenOfType
import app.cash.sqldelight.core.lang.util.isArrayParameter
import app.cash.sqldelight.core.lang.util.range
import app.cash.sqldelight.core.lang.util.rawSqlText
import app.cash.sqldelight.core.lang.util.sqFile
import app.cash.sqldelight.core.psi.SqlDelightStmtClojureStmtList
import app.cash.sqldelight.dialect.api.IntermediateType
import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin
import com.alecstrong.sql.psi.core.psi.SqlBinaryEqualityExpr
import com.alecstrong.sql.psi.core.psi.SqlBindExpr
import com.alecstrong.sql.psi.core.psi.SqlStmt
import com.alecstrong.sql.psi.core.psi.SqlTypes
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiWhiteSpace
import com.intellij.psi.util.PsiTreeUtil
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.NameAllocator
abstract class QueryGenerator(
  private val query: BindableQuery,
) {
  protected val dialect = query.statement.sqFile().dialect
  protected val treatNullAsUnknownForEquality = query.statement.sqFile().treatNullAsUnknownForEquality
  protected val generateAsync = query.statement.sqFile().generateAsync

  /**
   * Creates the code that executes [query] through the driver and binds its arguments.
   *
   * When the statement is a [SqlDelightStmtClojureStmtList], every child [SqlStmt] is emitted
   * inside a single `transaction { }` block. For a [NamedQuery] that block is
   * `transactionWithResult { }`: it is returned directly, or, when generating async code, wrapped
   * in a returned [ASYNC_RESULT_TYPE] block. Otherwise a single execute call is emitted.
   *
   * For a query with an array parameter the output has roughly this shape. Identifier names and
   * the exact binder call are illustrative. A `check(this is ...)` line is added at the top of the
   * binder lambda when the dialect's prepared statement type differs from [PREPARED_STATEMENT_TYPE].
   *
   *   val numberIndexes = createArguments(count = number.size)
   *   return driver.executeQuery(null, """SELECT * FROM player WHERE number IN $numberIndexes""", mapper, number.size) {
   *     number.forEachIndexed { index, number_ ->
   *       bindLong(index, number_)
   *     }
   *   }
   */
  protected fun executeBlock(): CodeBlock {
    val result = CodeBlock.builder()
    if (query.statement !is SqlDelightStmtClojureStmtList) {
      result.add(executeBlock(query.statement, query.id))
      return result.build()
    }
    if (query is NamedQuery) {
      if (generateAsync) result.beginControlFlow("return %T", ASYNC_RESULT_TYPE)
      result.beginControlFlow(if (generateAsync) "transactionWithResult" else "return transactionWithResult")
    } else {
      result.beginControlFlow("transaction")
    }
    query.statement.findChildrenOfType<SqlStmt>().forEachIndexed { index, statement ->
      result.add(executeBlock(statement, query.idForIndex(index)))
    }
    result.endControlFlow()
    if (generateAsync && query is NamedQuery) result.endControlFlow()
    return result.build()
  }

  /**
   * Emits the execute call for one [statement] of [query].
   *
   * The call is `executeQuery` for the result-producing statement of a [NamedQuery], otherwise
   * `execute`. Named bind parameters in the SQL are rewritten to indexed placeholders, and array
   * parameters are expanded into runtime-generated placeholder lists.
   *
   * @param statement the SQL statement to execute; either [query]'s statement or one child of it.
   * @param id the statement identifier passed to the driver. It is replaced by `null` when the
   *   SQL text depends on runtime values.
   */
  private fun executeBlock(
    statement: PsiElement,
    id: Int,
  ): CodeBlock {
    val dialectPreparedStatementType = if (generateAsync) {
      dialect.asyncRuntimeTypes.preparedStatementType
    } else {
      dialect.runtimeTypes.preparedStatementType
    }
    val result = CodeBlock.builder()
    // Bind arguments in order of appearance in the SQL, plus the types bound more than once.
    val (orderedBindArgs, duplicateTypes) = collectBindArguments(statement)

    val bindStatements = CodeBlock.builder()
    val replacements = mutableListOf<Pair<IntRange, String>>()
    val argumentCounts = mutableListOf<String>()
    var needsFreshStatement = false
    val seenArrayArguments = mutableSetOf<BindableQuery.Argument>()
    // Reserve every argument name first so generated locals never clash with the parameters.
    val argumentNameAllocator = NameAllocator().apply {
      query.arguments.forEach { newName(it.type.name) }
    }
    // The number of non-array bind args encountered so far.
    var nonArrayBindArgsCount = 0
    // Names of the array arguments encountered so far.
    val precedingArrays = mutableListOf<String>()

    // Encode each type that is bound more than once into a local up front, so it is encoded once.
    val extractedVariables = mutableMapOf<IntermediateType, String>()
    for (type in duplicateTypes) {
      if (type.bindArg?.isArrayParameter() == true) continue
      val encodedJavaType = type.encodedJavaType() ?: continue
      val variableName = argumentNameAllocator.newName(type.name)
      extractedVariables[type] = variableName
      bindStatements.add("val %N = $encodedJavaType\n", variableName)
    }

    orderedBindArgs.forEach { (_, argument, bindArg) ->
      val type = argument.type
      // Index expression for this argument: sizes of all preceding arrays plus the number of
      // preceding non-array arguments (e.g. "id.size + 2"). A trailing " + 0" is dropped, so the
      // first bound argument gets index "0".
      val offset = (precedingArrays.map { "$it.size" } + "$nonArrayBindArgsCount")
        .joinToString(separator = " + ").replace(" + 0", "")
      if (bindArg?.isArrayParameter() == true) {
        needsFreshStatement = true
        if (seenArrayArguments.add(argument)) {
          // The placeholder list depends on the runtime array size:
          // val idIndexes = createArguments(count = id.size)
          result.addStatement(
            """
            |val ${type.name}Indexes = createArguments(count = ${type.name}.size)
            """.trimMargin(),
          )
        }
        // Replace the single bind argument with the generated placeholders:
        // WHERE id IN $idIndexes
        replacements.add(bindArg.range to "\$${type.name}Indexes")
        // Bind every element at its shifted index:
        // id.forEachIndexed { index, id_ -> <bind>(index + offset, id_) }
        val indexCalculator = "index + $offset".replace(" + 0", "")
        val elementName = argumentNameAllocator.newName(type.name)
        bindStatements.add(
          """
          |${type.name}.forEachIndexed { index, $elementName ->
          | %L}
          |
          """.trimMargin(),
          type.copy(name = elementName).preparedStatementBinder(indexCalculator),
        )
        precedingArrays.add(type.name)
        argumentCounts.add("${type.name}.size")
      } else {
        val bindParameter = bindArg?.bindParameter as? BindParameterMixin
        // A bind parameter whose text is DEFAULT is neither bound nor replaced.
        if (bindParameter == null || bindParameter.text != "DEFAULT") {
          nonArrayBindArgsCount += 1
          val nullableEquality = nullableEqualityReplacement(bindArg, type)
          if (nullableEquality != null) {
            // The SQL text now varies at runtime, so no fixed statement id is passed (see statementId).
            needsFreshStatement = true
            replacements.add(nullableEquality)
          }
          // Binds the parameter to the statement, e.g. bindLong(0, id).
          bindStatements.add(type.preparedStatementBinder(offset, extractedVariables[type]))
          // Replace the named argument with a non-named/indexed argument so that the same
          // algorithm works for non-SQLite dialects (:name becomes ?).
          if (bindParameter != null) {
            replacements.add(bindArg.range to bindParameter.replaceWith(generateAsync, index = nonArrayBindArgsCount))
          }
        }
      }
    }

    val optimisticLock = optimisticLockColumn()

    // A NamedQuery's rows come from its only statement, or from the last child of a compound list.
    val isNamedQuery = query is NamedQuery &&
      (statement == query.statement || statement == query.statement.children.filterIsInstance<SqlStmt>().last())
    if (nonArrayBindArgsCount != 0) {
      argumentCounts.add(0, nonArrayBindArgsCount.toString())
    }
    val arguments = mutableListOf<Any>(
      statement.rawSqlText(replacements),
      argumentCounts.ifEmpty { listOf(0) }.joinToString(" + "),
    )
    // The binder lambda is only emitted when there is something to bind.
    var binder = if (argumentCounts.isEmpty()) {
      ""
    } else {
      val binderLambda = CodeBlock.builder()
        .add(" {\n")
        .indent()
      if (PREPARED_STATEMENT_TYPE != dialectPreparedStatementType) {
        binderLambda.add("check(this is %T)\n", dialectPreparedStatementType)
      }
      binderLambda.add(bindStatements.build())
        .unindent()
        .add("}")
      arguments.add(binderLambda.build())
      "%L"
    }
    if (generateAsync) {
      val awaiter = awaiting()
      if (isNamedQuery) {
        awaiter?.let { (bind, arg) ->
          binder += bind
          arguments.add(arg)
        }
      } else {
        binder += "%L"
        arguments.add(".await()")
      }
    }

    val statementId = if (needsFreshStatement) "null" else "$id"
    when {
      isNamedQuery -> {
        val execute = if (query.statement is SqlDelightStmtClojureStmtList) {
          "$DRIVER_NAME.executeQuery"
        } else {
          "return $DRIVER_NAME.executeQuery"
        }
        result.addStatement(
          "$execute($statementId, %P, $MAPPER_NAME, %L)$binder",
          *arguments.toTypedArray(),
        )
      }
      optimisticLock != null -> {
        result.addStatement(
          "val result = $DRIVER_NAME.execute($statementId, %P, %L)$binder",
          *arguments.toTypedArray(),
        )
      }
      else -> {
        result.addStatement(
          "$DRIVER_NAME.execute($statementId, %P, %L)$binder",
          *arguments.toTypedArray(),
        )
      }
    }

    // An UPDATE guarded by an optimistic lock column must affect at least one row.
    if (query is NamedMutator.Update && optimisticLock != null) {
      result.addStatement(
        """
        if (result%L == 0L) throw %T(%S)
        """.trimIndent(),
        if (generateAsync) "" else ".value",
        ClassName("app.cash.sqldelight.db", "OptimisticLockException"),
        "UPDATE on ${query.tablesAffected.single().name} failed because optimistic lock ${optimisticLock.name} did not match",
      )
    }
    return result.build()
  }

  /**
   * Collects [query]'s arguments that are bound inside [statement].
   *
   * Each entry is (offset of the bind expression in the file, argument, bind expression). An
   * argument with no bind expressions at all is kept once with offset 0 and a null expression.
   *
   * @return the entries sorted by offset, i.e. in order of appearance in the SQL, paired with the
   *   types of arguments that are bound more than once within [statement]. Both keep their
   *   insertion order.
   */
  private fun collectBindArguments(
    statement: PsiElement,
  ): Pair<List<Triple<Int, BindableQuery.Argument, SqlBindExpr?>>, Set<IntermediateType>> {
    val positionToArgument = mutableListOf<Triple<Int, BindableQuery.Argument, SqlBindExpr?>>()
    val seenArgs = mutableSetOf<BindableQuery.Argument>()
    val duplicateTypes = mutableSetOf<IntermediateType>()
    query.arguments.forEach { argument ->
      if (argument.bindArgs.isNotEmpty()) {
        argument.bindArgs
          .filter { PsiTreeUtil.isAncestor(statement, it, true) }
          .forEach { bindArg ->
            if (!seenArgs.add(argument)) {
              duplicateTypes.add(argument.type)
            }
            positionToArgument.add(Triple(bindArg.node.textRange.startOffset, argument, bindArg))
          }
      } else {
        positionToArgument.add(Triple(0, argument, null))
      }
    }
    return positionToArgument.sortedBy { it.first } to duplicateTypes
  }

  /**
   * Rewrites an equality comparison against a nullable bind argument so that a null value still
   * matches. This applies only when [treatNullAsUnknownForEquality] is off.
   *
   * The equality (EQ/EQ2) or inequality (NEQ/NEQ2) operator of the enclosing
   * [SqlBinaryEqualityExpr] is replaced with a runtime choice. When the Kotlin value is null it
   * becomes `IS` / `IS NOT`; otherwise the original operator is kept.
   *
   * @return the (operator range, replacement template) pair, or null when no rewrite applies.
   */
  private fun nullableEqualityReplacement(
    bindArg: SqlBindExpr?,
    type: IntermediateType,
  ): Pair<IntRange, String>? {
    if (treatNullAsUnknownForEquality || !type.javaType.isNullable) return null
    val parent = bindArg?.parent as? SqlBinaryEqualityExpr ?: return null
    val equalitySymbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2)
    val symbol: PsiElement
    val nullableEquality: String
    if (equalitySymbol != null) {
      symbol = equalitySymbol
      nullableEquality = "${symbol.leftWhitespace()}IS${symbol.rightWhitespace()}"
    } else {
      symbol = checkNotNull(parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)) {
        "Expected an equality or inequality operator in '${parent.text}'"
      }
      nullableEquality = "${symbol.leftWhitespace()}IS NOT${symbol.rightWhitespace()}"
    }
    val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"")
    return symbol.range to "\${ $block }"
  }

  /**
   * Returns the single column assigned by an UPDATE whose declared column type contains a `LOCK`
   * token. Returns null when [query] is not a [NamedMutator.Update], or when zero or several such
   * columns are assigned.
   */
  private fun optimisticLockColumn() = if (query is NamedMutator.Update) {
    val columnsUpdated =
      query.update.updateStmtSubsequentSetterList.mapNotNull { it.columnName } +
        query.update.columnNameList
    columnsUpdated.singleOrNull { column ->
      checkNotNull(column.columnDefSource()) { "No column definition found for '${column.text}'" }
        .columnType.node.getChildren(null).any { child -> child.text == "LOCK" }
    }
  } else {
    null
  }

  /**
   * Padding to put before a replacement for this element: "" when this element is already
   * preceded by whitespace, otherwise a single space.
   */
  private fun PsiElement.leftWhitespace(): String {
    return if (prevSibling is PsiWhiteSpace) "" else " "
  }

  /**
   * Padding to put after a replacement for this element: "" when this element is already
   * followed by whitespace, otherwise a single space.
   */
  private fun PsiElement.rightWhitespace(): String {
    return if (nextSibling is PsiWhiteSpace) "" else " "
  }

  /** Adds [query]'s javadoc, if it has one, as KDoc on [builder]. */
  protected fun addJavadoc(builder: FunSpec.Builder) {
    query.javadoc?.let { builder.addKdoc(javadocText(it)) }
  }

  /**
   * Format/argument pair appended to a named query's execute call when generating async code.
   * The default appends `.await()`; return null to append nothing.
   */
  protected open fun awaiting(): Pair<String, String>? = "%L" to ".await()"
}