Skip to content

Commit

Permalink
Fix #298 -- create m2m when using _bulk_create=True
Browse files Browse the repository at this point in the history
  • Loading branch information
amureki committed Oct 15, 2022
1 parent b90253b commit 9006355
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Changed
- Fixed a bug with `seq` being passed a tz-aware start value [PR #353](https://github.com/model-bakers/model_bakery/pull/353)
- Create m2m when using `_bulk_create=True` [PR #354](https://github.com/model-bakers/model_bakery/pull/354)

### Removed

Expand Down
19 changes: 18 additions & 1 deletion model_bakery/baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,4 +787,21 @@ def _save_related_objs(model, objects) -> None:
else:
manager = baker.model._base_manager

return manager.bulk_create(entries)
created_entries = manager.bulk_create(entries)
# set many-to-many relations from kwargs
for entry in created_entries:
for field in baker.model._meta.many_to_many:
if field.name in kwargs:
through_model = getattr(entry, field.name).through
through_model.objects.bulk_create(
[
through_model(
**{
field.remote_field.name: entry,
field.related_model._meta.model_name: obj,
}
)
for obj in kwargs[field.name]
]
)
return created_entries
12 changes: 12 additions & 0 deletions tests/test_baker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,3 +1032,15 @@ def test_annotation_within_manager_get_queryset_are_run_on_make(self):
_from_manager="objects",
)
assert movie.title == movie.name


class TestCreateM2MWhenBulkCreate(TestCase):
@pytest.mark.django_db
def test_create(self):
with self.assertNumQueries(22):
person = baker.make(models.Person)
baker.make(
models.Classroom, students=[person], _quantity=20, _bulk_create=True
)
c1, c2 = models.Classroom.objects.all()[:2]
assert list(c1.students.all()) == list(c2.students.all()) == [person]

0 comments on commit 9006355

Please sign in to comment.