Skip to content

Commit

Permalink
Fixed get_or_create method (#1404)
Browse files Browse the repository at this point in the history
* fixed get_or_create

* added transaction

---------

Co-authored-by: Andrey Bondar <andrey@bondar.ru>
  • Loading branch information
ipakeev and abondar committed May 14, 2024
1 parent 9f1de42 commit 7cb7e72
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Fixed
- Fix `bulk_update` when using custom fields. (#1564)
- Fix `optional` parameter in `pydantic_model_creator` does not work for pydantic v2. (#1551)
- Fix `get_annotations` now evaluates annotations in the default scope instead of the app namespace. (#1552)
- Fix `get_or_create` method. (#1404)

Changed
^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Contributors
* Yuval Ben-Arie ``@yuvalbenarie``
* Stephan Klein ``@privatwolke``
* ``@WizzyGeek``
* Ivan Pakeev ``@ipakeev``
* Abdeldjalil Hezouat ``@Abdeldjalil-H``

Special Thanks
Expand Down
8 changes: 0 additions & 8 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,3 @@ async def test_nonconcurrent_get_or_create(self):
self.assertEqual(len(una_created), 1)
for una in unas:
self.assertEqual(una[0], unas[0][0])

@test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well")
async def test_concurrent_get_or_create(self):
unas = await asyncio.gather(*[UniqueName.get_or_create(name="b") for _ in range(10)])
una_created = [una[1] for una in unas if una[1] is True]
self.assertEqual(len(una_created), 1)
for una in unas:
self.assertEqual(una[0], unas[0][0])
21 changes: 10 additions & 11 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
IntegrityError,
OperationalError,
ParamsError,
TransactionManagementError,
)
from tortoise.fields.base import Field
from tortoise.fields.data import IntField
Expand Down Expand Up @@ -1068,22 +1067,22 @@ async def get_or_create(
:param using_db: Specific DB connection to use instead of default bound
:param kwargs: Query parameters.
:raises IntegrityError: If create failed
:raises TransactionManagementError: If transaction error
"""
if not defaults:
defaults = {}
db = using_db or cls._choose_db(True)
async with in_transaction(connection_name=db.connection_name) as connection:
try:
return await cls.filter(**kwargs).using_db(db).get(), False
except DoesNotExist:
try:
return (
await cls.select_for_update().filter(**kwargs).using_db(connection).get(),
False,
)
except DoesNotExist:
try:
async with in_transaction(connection_name=db.connection_name) as connection:
return await cls.create(using_db=connection, **defaults, **kwargs), True
except (IntegrityError, TransactionManagementError):
return await cls.filter(**kwargs).using_db(connection).get(), False
except IntegrityError as exc:
try:
return await cls.filter(**kwargs).using_db(db).get(), False
except DoesNotExist:
pass
raise exc

@classmethod
def select_for_update(
Expand Down

0 comments on commit 7cb7e72

Please sign in to comment.