-
Notifications
You must be signed in to change notification settings - Fork 496
/
R2dbcDriver.kt
175 lines (145 loc) · 5.19 KB
/
R2dbcDriver.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
package app.cash.sqldelight.driver.r2dbc
import app.cash.sqldelight.Query
import app.cash.sqldelight.Transacter
import app.cash.sqldelight.db.QueryResult
import app.cash.sqldelight.db.SqlCursor
import app.cash.sqldelight.db.SqlDriver
import app.cash.sqldelight.db.SqlPreparedStatement
import io.r2dbc.spi.Connection
import io.r2dbc.spi.Statement
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.reactive.asFlow
import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactive.awaitSingle
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
/**
 * A [SqlDriver] that runs every statement on a single R2DBC [Connection].
 *
 * Queries and updates return [QueryResult.AsyncValue]s. The statement is created and bound eagerly
 * but only executed when the value is awaited. Query results are read fully into memory and exposed
 * through an [R2dbcCursor]. Query listeners are not supported; the listener methods are no-ops.
 *
 * @param connection the connection that all statements and transactions run on.
 * @param closed called once the asynchronous close started by [close] terminates. It receives
 *   `null` on success or the failure otherwise. Defaults to a no-op.
 */
class R2dbcDriver @JvmOverloads constructor(
    private val connection: Connection,
    private val closed: (Throwable?) -> Unit = {},
) : SqlDriver {
    override fun <R> executeQuery(
        identifier: Int?,
        sql: String,
        mapper: (SqlCursor) -> R,
        parameters: Int,
        binders: (SqlPreparedStatement.() -> Unit)?,
    ): QueryResult<R> {
        val prepared = createStatement(sql, binders)
        return QueryResult.AsyncValue {
            val result = prepared.execute().awaitSingle()
            // Read every row into memory up front: R2dbcCursor is a synchronous cursor over a list.
            val rowSet = result.map { row, rowMetadata ->
                List(rowMetadata.columnMetadatas.size) { index -> row.get(index) }
            }.asFlow().toList()
            mapper(R2dbcCursor(rowSet))
        }
    }

    override fun execute(
        identifier: Int?,
        sql: String,
        parameters: Int,
        binders: (SqlPreparedStatement.() -> Unit)?,
    ): QueryResult<Long> {
        val prepared = createStatement(sql, binders)
        return QueryResult.AsyncValue {
            val result = prepared.execute().awaitSingle()
            // A result that emits no update count is reported as 0 affected rows.
            result.rowsUpdated.awaitFirstOrNull()?.toLong() ?: 0L
        }
    }

    /** Creates a statement for [sql] on [connection] and applies [binders], if any, to it. */
    private fun createStatement(sql: String, binders: (SqlPreparedStatement.() -> Unit)?): Statement =
        connection.createStatement(sql).also { statement ->
            if (binders != null) R2dbcPreparedStatement(statement).binders()
        }

    // The innermost open transaction on this driver's connection.
    // This is a plain field rather than a ThreadLocal: the AsyncValue blocks run in coroutines
    // that may resume on a different thread after each suspension. A ThreadLocal would hide the
    // transaction from later statements and leave a stale entry on the thread that set it.
    private var transaction: Transacter.Transaction? = null

    override fun newTransaction(): QueryResult<Transacter.Transaction> = QueryResult.AsyncValue {
        val enclosing = transaction
        val started = Transaction(enclosing)
        if (enclosing == null) {
            // Only the outermost transaction opens a database transaction; nested ones join it.
            // Begin before publishing `started`, so a failed begin leaves no stale state behind.
            connection.beginTransaction().awaitFirstOrNull()
        }
        transaction = started
        started
    }

    override fun currentTransaction(): Transacter.Transaction? = transaction

    override fun addListener(listener: Query.Listener, queryKeys: Array<String>) = Unit

    override fun removeListener(listener: Query.Listener, queryKeys: Array<String>) = Unit

    override fun notifyListeners(queryKeys: Array<String>) = Unit

    /**
     * Starts closing [connection] and returns immediately.
     *
     * Reactive Streams publishers do nothing until they are subscribed to, so the publisher
     * returned by [Connection.close] is subscribed here. This method cannot suspend, so its
     * outcome is reported to the `closed` constructor callback instead.
     */
    override fun close() {
        connection.close().subscribe(object : Subscriber<Void> {
            override fun onSubscribe(s: Subscription) {
                s.request(Long.MAX_VALUE)
            }

            override fun onNext(t: Void) = Unit

            override fun onError(t: Throwable) {
                closed(t)
            }

            override fun onComplete() {
                closed(null)
            }
        })
    }

    /**
     * A (possibly nested) transaction on the driver's connection.
     *
     * Only the outermost transaction (no [enclosingTransaction]) commits or rolls back the
     * database transaction. Ending any transaction makes its enclosing one current again.
     */
    private inner class Transaction(
        override val enclosingTransaction: Transacter.Transaction?,
    ) : Transacter.Transaction() {
        override fun endTransaction(successful: Boolean): QueryResult<Unit> = QueryResult.AsyncValue {
            try {
                if (enclosingTransaction == null) {
                    if (successful) {
                        connection.commitTransaction().awaitFirstOrNull()
                    } else {
                        connection.rollbackTransaction().awaitFirstOrNull()
                    }
                }
            } finally {
                // Pop this transaction even if commit/rollback failed, so the driver never keeps
                // pointing at a transaction that has already ended.
                transaction = enclosingTransaction
            }
        }
    }
}
/**
 * Adapts an R2DBC [Statement] to SQLDelight's [SqlPreparedStatement] binding API.
 *
 * Indices are forwarded unchanged to [Statement.bind] and [Statement.bindNull]. A `null` value
 * is bound with [Statement.bindNull] using the *boxed* Java class of the parameter type (for
 * example `java.lang.Long` rather than the primitive `long`). Drivers typically resolve the codec
 * for a null from this class, and primitive classes may not match any codec.
 */
class R2dbcPreparedStatement(private val statement: Statement) : SqlPreparedStatement {
    override fun bindBytes(index: Int, bytes: ByteArray?) =
        bindNullable(index, bytes, ByteArray::class.java)

    override fun bindLong(index: Int, long: Long?) =
        bindNullable(index, long, Long::class.javaObjectType)

    override fun bindDouble(index: Int, double: Double?) =
        bindNullable(index, double, Double::class.javaObjectType)

    override fun bindString(index: Int, string: String?) =
        bindNullable(index, string, String::class.java)

    override fun bindBoolean(index: Int, boolean: Boolean?) =
        bindNullable(index, boolean, Boolean::class.javaObjectType)

    /** Binds an arbitrary value; a `null` is bound as a null of type `Object`. */
    fun bindObject(index: Int, any: Any?) =
        bindNullable(index, any, Any::class.java)

    /** Binds [value] at [index], or a null of class [nullType] when [value] is `null`. */
    private fun bindNullable(index: Int, value: Any?, nullType: Class<*>) {
        if (value == null) {
            statement.bindNull(index, nullType)
        } else {
            statement.bind(index, value)
        }
    }
}
/**
 * A [SqlCursor] over rows that have already been read fully into memory.
 *
 * Each element of [rowSet] is one row, and each row holds its column values in column order.
 * The cursor starts before the first row, so [next] must be called before reading any column.
 *
 * TODO: Write a better async cursor API
 */
class R2dbcCursor(val rowSet: List<List<Any?>>) : SqlCursor {
    /** Index of the current row in [rowSet]; `-1` until [next] is first called. */
    var row = -1
        private set

    override fun next(): Boolean = ++row < rowSet.size

    override fun getString(index: Int): String? = value(index) as String?

    // The numeric getters accept any Number, so a column the driver returns with a different
    // numeric type (e.g. an Int or a Float) can still be read.
    override fun getLong(index: Int): Long? = (value(index) as Number?)?.toLong()

    override fun getBytes(index: Int): ByteArray? = value(index) as ByteArray?

    override fun getDouble(index: Int): Double? = (value(index) as Number?)?.toDouble()

    override fun getBoolean(index: Int): Boolean? = value(index) as Boolean?

    /** Returns the raw value at [index] in the current row, cast to [T]. */
    inline fun <reified T : Any> getObject(index: Int): T? = rowSet[row][index] as T?

    /** Returns the value at [index] in the current row as an array; the element type is unchecked. */
    @Suppress("UNCHECKED_CAST")
    fun <T> getArray(index: Int): Array<T>? = rowSet[row][index] as Array<T>?

    /** The raw value of column [index] in the current row. */
    private fun value(index: Int): Any? = rowSet[row][index]
}