Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow executemany to return rows #1110

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 16 additions & 4 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
)
return status.decode()

async def executemany(self, command: str, args, *, timeout: float=None):
async def executemany(self, command: str, args, *, timeout: float=None,
return_rows: bool=False):
"""Execute an SQL *command* for each sequence of arguments in *args*.

Example:
Expand All @@ -373,7 +374,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
:param command: Command to execute.
:param args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
:param bool return_rows:
If ``True``, the resulting rows of each command will be
returned as a list of :class:`~asyncpg.Record`
(defaults to ``False``).
:return:
None, or a list of :class:`~asyncpg.Record` instances
if `return_rows` is true.

.. versionadded:: 0.7.0

Expand All @@ -386,9 +393,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
to prior versions, where the effect of already-processed iterations
would remain in place when an error has occurred, unless
``executemany()`` was called in a transaction.

.. versionchanged:: 0.30.0
Added `return_rows` keyword-only parameter.
"""
self._check_open()
return await self._executemany(command, args, timeout)
return await self._executemany(
command, args, timeout, return_rows=return_rows)

async def _get_statement(
self,
Expand Down Expand Up @@ -1898,12 +1909,13 @@ async def __execute(
)
return result, stmt

async def _executemany(self, query, args, timeout):
async def _executemany(self, query, args, timeout, return_rows):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
state=stmt,
args=args,
portal_name='',
timeout=timeout,
return_rows=return_rows,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
Expand Down
6 changes: 4 additions & 2 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
async with self.acquire() as con:
return await con.execute(query, *args, timeout=timeout)

async def executemany(self, command: str, args, *, timeout: float=None):
async def executemany(self, command: str, args, *, timeout: float=None,
return_rows: bool=False):
"""Execute an SQL *command* for each sequence of arguments in *args*.

Pool performs this operation using one of its connections. Other than
Expand All @@ -549,7 +550,8 @@ async def executemany(self, command: str, args, *, timeout: float=None):
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
return await con.executemany(command, args, timeout=timeout)
return await con.executemany(
command, args, timeout=timeout, return_rows=return_rows)

async def fetch(
self,
Expand Down
16 changes: 13 additions & 3 deletions asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,28 @@ async def fetchrow(self, *args, timeout=None):
return data[0]

@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
async def executemany(self, args, *, timeout: float=None,
return_rows: bool=False):
"""Execute the statement for each sequence of arguments in *args*.

:param args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
:param bool return_rows:
If ``True``, the resulting rows of each command will be
returned as a list of :class:`~asyncpg.Record`
(defaults to ``False``).
:return:
None, or a list of :class:`~asyncpg.Record` instances
if `return_rows` is true.

.. versionadded:: 0.22.0

.. versionchanged:: 0.30.0
Added `return_rows` keyword-only parameter.
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state, args, '', timeout))
self._state, args, '', timeout, return_rows=return_rows))

async def __do_execute(self, executor):
protocol = self._connection._protocol
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ cdef class CoreProtocol:
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
object bind_data, bint return_rows)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
Expand Down
6 changes: 3 additions & 3 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -940,12 +940,12 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)

cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data):
object bind_data, bint return_rows):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

self.result = None
self._discard_data = True
self.result = [] if return_rows else None
self._discard_data = not return_rows
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name
Expand Down
4 changes: 3 additions & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ cdef class BaseProtocol(CoreProtocol):
args,
portal_name: str,
timeout,
return_rows: bool,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
Expand All @@ -238,7 +239,8 @@ cdef class BaseProtocol(CoreProtocol):
more = self._bind_execute_many(
portal_name,
state.name,
arg_bufs) # network op
arg_bufs,
return_rows) # network op

self.last_query = state.query
self.statement = state
Expand Down
53 changes: 52 additions & 1 deletion tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,45 @@ async def test_executemany_basic(self):
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

async def test_executemany_returning(self):
result = await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
], return_rows=True)
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Empty set
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''', (), return_rows=True)
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])

# Without "RETURNING"
result = await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [('e', 5), ('f', 6)], return_rows=True)
self.assertEqual(result, [])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
])

async def test_executemany_bad_input(self):
with self.assertRaisesRegex(
exceptions.DataError,
Expand Down Expand Up @@ -288,11 +327,13 @@ async def test_executemany_client_server_failure_conflict(self):

async def test_executemany_prepare(self):
stmt = await self.con.prepare('''
INSERT INTO exmany VALUES($1, $2)
INSERT INTO exmany VALUES($1, $2) RETURNING a, b
''')
result = await stmt.executemany([
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
# While the query contains a "RETURNING" clause, by default
# `executemany` does not return anything
self.assertIsNone(result)
result = await self.con.fetch('''
SELECT * FROM exmany
Expand All @@ -308,3 +349,13 @@ async def test_executemany_prepare(self):
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4)
])
# Now with `return_rows=True`, we should retrieve the tuples
# from the "RETURNING" clause.
result = await stmt.executemany([('e', 5), ('f', 6)], return_rows=True)
self.assertEqual(result, [('e', 5), ('f', 6)])
result = await self.con.fetch('''
SELECT * FROM exmany
''')
self.assertEqual(result, [
('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6)
])