Skip to content

Commit

Permalink
Allow executemany to return rows
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonardBesson committed Dec 23, 2023
1 parent c2c8d20 commit ebcf3af
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 15 deletions.
20 changes: 16 additions & 4 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
)
return status.decode()

async def executemany(self, command: str, args, *, timeout: float=None):
async def executemany(self, command: str, args, *, timeout: float=None,
return_rows: bool=False):
"""Execute an SQL *command* for each sequence of arguments in *args*.
Example:
Expand All @@ -373,7 +374,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
:param command: Command to execute.
:param args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
:param bool return_rows:
If ``True``, the resulting rows of each command will be
returned as a list of :class:`~asyncpg.Record`
(defaults to ``False``).
:return:
None, or a list of :class:`~asyncpg.Record` instances
if `return_rows` is true.
.. versionadded:: 0.7.0
Expand All @@ -386,9 +393,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
to prior versions, where the effect of already-processed iterations
would remain in place when an error has occurred, unless
``executemany()`` was called in a transaction.
.. versionchanged:: 0.30.0
Added `return_rows` keyword-only parameter.
"""
self._check_open()
return await self._executemany(command, args, timeout)
return await self._executemany(
command, args, timeout, return_rows=return_rows)

async def _get_statement(
self,
Expand Down Expand Up @@ -1898,12 +1909,13 @@ async def __execute(
)
return result, stmt

async def _executemany(self, query, args, timeout):
async def _executemany(self, query, args, timeout, return_rows):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
state=stmt,
args=args,
portal_name='',
timeout=timeout,
return_rows=return_rows,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
Expand Down
6 changes: 4 additions & 2 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
async with self.acquire() as con:
return await con.execute(query, *args, timeout=timeout)

async def executemany(self, command: str, args, *, timeout: float=None):
async def executemany(self, command: str, args, *, timeout: float=None,
return_rows: bool=False):
"""Execute an SQL *command* for each sequence of arguments in *args*.
Pool performs this operation using one of its connections. Other than
Expand All @@ -549,7 +550,8 @@ async def executemany(self, command: str, args, *, timeout: float=None):
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
return await con.executemany(command, args, timeout=timeout)
return await con.executemany(
command, args, timeout=timeout, return_rows=return_rows)

async def fetch(
self,
Expand Down
16 changes: 13 additions & 3 deletions asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,28 @@ async def fetchrow(self, *args, timeout=None):
return data[0]

@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
async def executemany(self, args, *, timeout: float=None,
return_rows: bool=False):
"""Execute the statement for each sequence of arguments in *args*.
:param args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
:param bool return_rows:
If ``True``, the resulting rows of each command will be
returned as a list of :class:`~asyncpg.Record`
(defaults to ``False``).
:return:
None, or a list of :class:`~asyncpg.Record` instances
if `return_rows` is true.
.. versionadded:: 0.22.0
.. versionchanged:: 0.30.0
Added `return_rows` keyword-only parameter.
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state, args, '', timeout))
self._state, args, '', timeout, return_rows=return_rows))

async def __do_execute(self, executor):
protocol = self._connection._protocol
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ cdef class CoreProtocol:
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
object bind_data, bint return_rows)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
Expand Down
6 changes: 3 additions & 3 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -940,12 +940,12 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)

cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):
object bind_data, bint return_rows):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

self.result = None
self._discard_data = True
self.result = [] if return_rows else None
self._discard_data = not return_rows
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name
Expand Down
4 changes: 3 additions & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ cdef class BaseProtocol(CoreProtocol):
args,
portal_name: str,
timeout,
return_rows: bool,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
Expand All @@ -238,7 +239,8 @@ cdef class BaseProtocol(CoreProtocol):
more = self._bind_execute_many(
portal_name,
state.name,
arg_bufs) # network op
arg_bufs,
return_rows) # network op

self.last_query = state.query
self.statement = state
Expand Down
53 changes: 52 additions & 1 deletion tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,45 @@ async def test_executemany_basic(self):
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

async def test_executemany_returning(self):
result = await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
], return_rows=True)
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Empty set
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', (), return_rows=True)
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Without "RETURNING"
result = await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [('e', 5), ('f', 6)], return_rows=True)
self.assertEqual(result, [])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
])

async def test_executemany_bad_input(self):
with self.assertRaisesRegex(
exceptions.DataError,
Expand Down Expand Up @@ -288,11 +327,13 @@ async def test_executemany_client_server_failure_conflict(self):

async def test_executemany_prepare(self):
stmt = await self.con.prepare('''
INSERT INTO exmany VALUES($1, $2)
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''')
result = await stmt.executemany([
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
# While the query contains a "RETURNING" clause, by default
# `executemany` does not return anything
self.assertIsNone(result)
result = await self.con.fetch('''
SELECT * FROM exmany
Expand All @@ -308,3 +349,13 @@ async def test_executemany_prepare(self):
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
# Now with `return_rows=True`, we should retrieve the tuples
# from the "RETURNING" clause.
result = await stmt.executemany([('e', 5), ('f', 6)], return_rows=True)
self.assertEqual(result, [('e', 5), ('f', 6)])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
])

0 comments on commit ebcf3af

Please sign in to comment.