Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #298 -- create m2m when using _bulk_create=True #354

Merged
merged 6 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/changelog.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on: pull_request

jobs:
remind:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
if: |
!contains(github.event.pull_request.body, '[skip changelog]') &&
(github.actor != 'dependabot[bot]')
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on:
jobs:
tests:
name: Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
runs-on: ubuntu-22.04

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on: [release]
jobs:
package:
name: Build & verify package
runs-on: ubuntu-latest
runs-on: ubuntu-22.04

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on:
jobs:
tests:
name: Python ${{ matrix.python-version }}
runs-on: ubuntu-latest
runs-on: ubuntu-22.04

services:
postgis:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Changed
- Fixed a bug with `seq` being passed a tz-aware start value [PR #353](https://github.com/model-bakers/model_bakery/pull/353)
- [dev] Use official postgis docker image in CI [PR #355](https://github.com/model-bakers/model_bakery/pull/355)
- Create m2m when using `_bulk_create=True` [PR #354](https://github.com/model-bakers/model_bakery/pull/354)

### Removed

Expand Down
26 changes: 25 additions & 1 deletion model_bakery/baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
overload,
)

from django import VERSION as DJANGO_VERSION
from django.apps import apps
from django.conf import settings
from django.contrib import contenttypes
Expand Down Expand Up @@ -787,4 +788,27 @@ def _save_related_objs(model, objects) -> None:
else:
manager = baker.model._base_manager

return manager.bulk_create(entries)
existing_entries = list(manager.values_list("pk", flat=True))
created_entries = manager.bulk_create(entries)
# bulk_create in Django < 4.0 does not return ids of created objects.
# drop this after 01 Apr 2024 (Django 3.2 LTS end of life)
if DJANGO_VERSION < (4, 0):
created_entries = manager.exclude(pk__in=existing_entries)

# set many-to-many relations from kwargs
for entry in created_entries:
for field in baker.model._meta.many_to_many:
if field.name in kwargs:
through_model = getattr(entry, field.name).through
through_model.objects.bulk_create(
[
through_model(
**{
field.remote_field.name: entry,
field.related_model._meta.model_name: obj,
}
)
for obj in kwargs[field.name]
]
)
return created_entries
35 changes: 26 additions & 9 deletions tests/test_baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import patch

import pytest
from django import VERSION as DJANGO_VERSION
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.db.models import Manager
Expand Down Expand Up @@ -151,17 +152,18 @@ def test_make_should_create_objects_respecting_quantity_parameter(self):
assert all(p.name == "George Washington" for p in people)

def test_make_quantity_respecting_bulk_create_parameter(self):
with self.assertNumQueries(1):
query_count = 2 if DJANGO_VERSION >= (4, 0) else 3
with self.assertNumQueries(query_count):
baker.make(models.Person, _quantity=5, _bulk_create=True)
assert models.Person.objects.count() == 5

with self.assertNumQueries(1):
with self.assertNumQueries(query_count):
people = baker.make(
models.Person, name="George Washington", _quantity=5, _bulk_create=True
)
assert all(p.name == "George Washington" for p in people)

with self.assertNumQueries(1):
with self.assertNumQueries(query_count):
baker.make(models.NonStandardManager, _quantity=3, _bulk_create=True)
assert getattr(models.NonStandardManager, "objects", None) is None
assert (
Expand Down Expand Up @@ -362,16 +364,17 @@ def test_create_multiple_one_to_one(self):
assert models.Person.objects.all().count() == 5

def test_bulk_create_multiple_one_to_one(self):
with self.assertNumQueries(6):
query_count = 7 if DJANGO_VERSION >= (4, 0) else 8
with self.assertNumQueries(query_count):
baker.make(models.LonelyPerson, _quantity=5, _bulk_create=True)

assert models.LonelyPerson.objects.all().count() == 5
assert models.Person.objects.all().count() == 5

def test_chaining_bulk_create_reduces_query_count(self):
qtd = 5
with self.assertNumQueries(3):
baker.make(models.Person, _quantity=qtd, _bulk_create=True)
query_count = 5 if DJANGO_VERSION >= (4, 0) else 7
with self.assertNumQueries(query_count):
baker.make(models.Person, _quantity=5, _bulk_create=True)
person_iter = models.Person.objects.all().iterator()
baker.make(
models.LonelyPerson,
Expand All @@ -385,7 +388,8 @@ def test_chaining_bulk_create_reduces_query_count(self):
assert models.Person.objects.all().count() == 5

def test_bulk_create_multiple_fk(self):
with self.assertNumQueries(6):
query_count = 7 if DJANGO_VERSION >= (4, 0) else 8
with self.assertNumQueries(query_count):
baker.make(models.PaymentBill, _quantity=5, _bulk_create=True)

assert models.PaymentBill.objects.all().count() == 5
Expand All @@ -396,7 +400,7 @@ def test_create_many_to_many_if_flagged(self):
assert store.employees.count() == 5
assert store.customers.count() == 5

def test_regresstion_many_to_many_field_is_accepted_as_kwargs(self):
def test_regression_many_to_many_field_is_accepted_as_kwargs(self):
employees = baker.make(models.Person, _quantity=3)
customers = baker.make(models.Person, _quantity=3)

Expand Down Expand Up @@ -1032,3 +1036,16 @@ def test_annotation_within_manager_get_queryset_are_run_on_make(self):
_from_manager="objects",
)
assert movie.title == movie.name


class TestCreateM2MWhenBulkCreate(TestCase):
@pytest.mark.django_db
def test_create(self):
query_count = 13 if DJANGO_VERSION >= (4, 0) else 14
with self.assertNumQueries(query_count):
person = baker.make(models.Person)
baker.make(
models.Classroom, students=[person], _quantity=10, _bulk_create=True
)
c1, c2 = models.Classroom.objects.all()[:2]
assert list(c1.students.all()) == list(c2.students.all()) == [person]