diff --git a/CHANGELOG.md b/CHANGELOG.md index f55041b5..d0eacf0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed - Validate `increment_by` parameter of `seq` helper when `value` is an instance of `datetime` [PR #247](https://github.com/model-bakers/model_bakery/pull/247) - Fix a simple typo in `bulk_create` disclaimer in docs +- Allow relation `_id` fields to use sequences [PR #253](https://github.com/model-bakers/model_bakery/pull/253/) - Fix bulk_create not working with multi-database setup [PR #252](https://github.com/model-bakers/model_bakery/pull/252) - Conditionally support NullBooleanField, it's under deprecation and will be removed in Django 4.0 [PR #25](https://github.com/model-bakers/model_bakery/pull/250) - Fix Django max version pin in requirements file [PR #251](https://github.com/model-bakers/model_bakery/pull/251) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 242baa64..b0f7cef8 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -363,10 +363,20 @@ def _make( ] else: self.m2m_dict[field.name] = self.model_attrs.pop(field.name) + # is an _id relation that has a sequence defined + elif ( + (isinstance(field, OneToOneField) or isinstance(field, ForeignKey)) + and hasattr(field, "attname") + and field.attname in self.iterator_attrs + ): + self.model_attrs[field.attname] = next( + self.iterator_attrs[field.attname] + ) elif field.name not in self.model_attrs: if ( not isinstance(field, ForeignKey) - or "{0}_id".format(field.name) not in self.model_attrs + or hasattr(field, "attname") + and field.attname not in self.model_attrs ): self.model_attrs[field.name] = self.generate_value( field, commit_related @@ -476,6 +486,14 @@ def _skip_field(self, field: Field) -> bool: if isinstance(field, FileField) and not self.create_files: return True + # Don't Skip related _id fields defined in the iterator attributes + if ( + (isinstance(field, OneToOneField) or isinstance(field, ForeignKey)) + and hasattr(field, "attname") + and field.attname in self.iterator_attrs + ): + return False + # Skip links to parent so parent is not created twice. if isinstance(field, OneToOneField) and self._remote_field(field).parent_link: return True diff --git a/tests/test_recipes.py b/tests/test_recipes.py index 817c135d..74c8037a 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -17,7 +17,10 @@ Dog, DummyBlankFieldsModel, DummyNumbersModel, + LonelyPerson, Person, + Profile, + User, ) recipe_attrs = { @@ -32,6 +35,8 @@ "birth_time": now(), } person_recipe = Recipe(Person, **recipe_attrs) +user_recipe = Recipe(User) +lonely_person_recipe = Recipe(LonelyPerson) def test_import_seq_from_recipe(): @@ -450,6 +455,28 @@ def test_increment_for_strings_with_suffix(self): person = fred_person.make() assert person.email == "fred3@example.com" + def test_increment_for_fks(self): + from model_bakery.recipe import seq # NoQA + + profiles = baker.make(Profile, _quantity=3) + start_id = profiles[0].id + seq_user = user_recipe.extend(username="name", profile_id=seq(start_id)) + user = seq_user.make() + assert user.profile_id == start_id + 1 + user = seq_user.make() + assert user.profile_id == start_id + 2 + + def test_increment_for_one_to_one(self): + from model_bakery.recipe import seq # NoQA + + people = baker.make(Person, _quantity=3) + start_id = people[0].id + seq_lonely_person = lonely_person_recipe.extend(only_friend_id=seq(start_id)) + person = seq_lonely_person.make() + assert person.only_friend_id == start_id + 1 + user = seq_lonely_person.make() + assert user.only_friend_id == start_id + 2 + def test_increment_for_strings_with_bad_suffix(self): from model_bakery.recipe import seq # NoQA