diff --git a/CHANGES.rst b/CHANGES.rst index 06e005b9..89ffe1a5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,21 +7,31 @@ Unreleased - Drop support for `Django 3.0` - Added urlsafe token field. +- Introduce context manager for FieldTracker state reset (GH-#491) 4.1.1 (2020-12-01) ------------------ + - Applied `isort` to codebase (Refs GH-#402) - Fix `TypeError` in save when model inherits from both TimeStampModel and StatusModel. (Fixes GH-465) 4.1.0 (2020-11-29) ------------------ + +**Breaking changes:** +- `FieldTracker` now marks fields as not changed after `refresh_from_db` + respecting `fields` argument (GH-#404) +- `FieldTracker` now respects `update_fields` changed in overridden `save()` + method (GH-#404) +- `FieldTracker` now resets states after `pre_save()` and not anymore `save()` + signals, possibly altering the behaviour of overridden `save()` + methods (GH-#404) + +**Other changes:** - Update InheritanceQuerySetMixin to avoid querying too much tables - TimeStampedModel now automatically adds 'modified' field as an update_fields parameter even if it is forgotten while using save() -- `FieldTracker` now marks fields as not changed after `refresh_from_db` -- `FieldTracker` now respects `update_fields` changed in overridden `save()` - method - Replace ugettext_lazy with gettext_lazy to satisfy Django deprecation warning - Add available_objects manager to SoftDeletableModel and add deprecation warning to objects manager. diff --git a/docs/utilities.rst b/docs/utilities.rst index cc154beb..b1a618d3 100644 --- a/docs/utilities.rst +++ b/docs/utilities.rst @@ -346,3 +346,75 @@ This is how ``FieldTracker`` tracks field changes on ``instance.save`` call. 8. ``instance.refresh_from_db()`` call causes initial state reset like for ``save_base()``. +When FieldTracker resets fields state +------------------------------------- + +By the definition: + +.. NOTE:: + * Field value *is changed* if it differs from current database value. + * Field value *was changed* if value has changed in database and field state didn't reset. + +.. code-block:: python + + instance = Tracked.objects.get(pk=1) + # name not changed + instance.name += '_changed' + # name is changed + instance.save() + # name is not changed again + +Current implementation resets fields state after ``post_save`` signals emitting. This is convenient for "outer" code +like in example above, but does not help when model ``save`` method is overridden. + +.. code-block:: python + + class MyModel(models.Model) + name = models.CharField(max_length=64) + tracker = FieldsTracker() + + def save(self): # erroneous implementation + self.name = self.name.replace(' ', '_') + name_changed = self.tracker.has_changed('name') + super().save() + # changed state has been reset here, so we need to store previous state somewhere else + if name_changed: + do_something_about_it() + +``FieldTracker`` provides a context manager interface to postpone fields state reset in complicate situations. + +* Fields state resets after exiting from outer-most context +* By default, all fields are reset, but field list can be provided +* Fields are counted separately depending on field list passed to context managers +* Tracker can be used as decorator +* Different instances have their own context state +* Different trackers in same instance have separate context state + +.. code-block:: python + + class MyModel(models.Model) + name = models.CharField(max_length=64) + tracker = FieldTracker() + + def save(self): # correct implementation + self.name = self.name.replace(' ', '_') + + with self.tracker: + super().save() + # changed state reset is postponed + if self.tracker.has_changed('name'): + do_something_about_it() + + # Decorator example + @tracker + def save(self): ... + + # Restrict a set of fields to reset here + @tracker(fields=('name')) + def save(self): ... + + # Context manager with field list + def save(self): + with self.tracker('name'): + ... + diff --git a/model_utils/tracker.py b/model_utils/tracker.py index dd1315f0..d19c8c42 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -86,11 +86,86 @@ def __delete__(self, obj): self.descriptor.__delete__(obj) +class FieldsContext: + """ + A context manager for tracking nested reset fields contexts. + + If tracked fields is mentioned in more than one FieldsContext, it's state + is being reset on exiting last context that mentions that field. + + >>> with fields_context(obj.tracker, 'f1', state=state): + ... with fields_context(obj.tracker, 'f1', 'f2', state=state): + ... obj.do_something_useful() + ... # f2 is reset after inner context exit + ... obj.do_something_else() + ... # f1 is reset after outer context exit + >>> + + * Note that fields are counted by passing same state dict + * FieldsContext is instantiated using FieldInstanceTracker (`obj.tracker`) + * Different objects has own state stack + + """ + + def __init__(self, tracker, *fields, state=None): + """ + :param tracker: FieldInstanceTracker instance to be reset after + context exit + :param fields: a list of field names to be tracked in current context + :param state: shared state dict used to count number of field + occurrences in context stack. + + On context enter each field mentioned in `fields` has +1 in shared + state, and on exit it receives -1. Fields that have zero after context + exit are reset in tracker instance. + """ + if state is None: + state = {} + self.tracker = tracker + self.fields = fields + self.state = state + + def __enter__(self): + """ + Increments tracked fields occurrences count in shared state. + """ + for f in self.fields: + self.state.setdefault(f, 0) + self.state[f] += 1 + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Decrements tracked fields occurrences count in shared state. + + If any field has no more occurrences in shared state, this field is + being reset by tracker. + """ + reset_fields = [] + for f in self.fields: + self.state[f] -= 1 + if self.state[f] == 0: + reset_fields.append(f) + del self.state[f] + if reset_fields: + self.tracker.set_saved_fields(fields=reset_fields) + + class FieldInstanceTracker: def __init__(self, instance, fields, field_map): self.instance = instance self.fields = fields self.field_map = field_map + self.context = FieldsContext(self, *self.fields) + + def __enter__(self): + return self.context.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self.context.__exit__(exc_type, exc_val, exc_tb) + + def __call__(self, *fields): + return FieldsContext(self, *fields, state=self.context.state) @property def deferred_fields(self): @@ -195,6 +270,20 @@ class FieldTracker: def __init__(self, fields=None): self.fields = fields + def __call__(self, func=None, fields=None): + def decorator(f): + @wraps(f) + def inner(obj, *args, **kwargs): + tracker = getattr(obj, self.attname) + field_list = tracker.fields if fields is None else fields + with tracker(*field_list): + return f(obj, *args, **kwargs) + + return inner + if func is None: + return decorator + return decorator(func) + def get_field_map(self, cls): """Returns dict mapping fields names to model attribute names""" field_map = {field: field for field in self.fields} @@ -240,21 +329,17 @@ def _patch(self, model, method, fields_kwarg): @wraps(original) def inner(instance, *args, **kwargs): - ret = original(instance, *args, **kwargs) update_fields = kwargs.get(fields_kwarg) - if not update_fields and update_fields is not None: # () or [] - fields = update_fields - elif update_fields is None: - fields = None + if update_fields is None: + fields = self.fields else: fields = ( field for field in update_fields if field in self.fields ) - getattr(instance, self.attname).set_saved_fields( - fields=fields - ) - return ret + tracker = getattr(instance, self.attname) + with tracker(*fields): + return original(instance, *args, **kwargs) setattr(model, method, inner) diff --git a/tests/settings.py b/tests/settings.py index 6fc2bdb2..ac1a80f7 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -21,3 +21,5 @@ 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', } } + +DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index 164b9b7d..cd9b4858 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -831,3 +831,76 @@ def test_child_fields_not_tracked(self): class AbstractModelTrackerTests(ModelTrackerTests): tracked_class = TrackedAbstract + + +class TrackerContextDecoratorTests(TestCase): + + def setUp(self): + self.instance = Tracked.objects.create(number=1) + self.tracker = self.instance.tracker + + def assertChanged(self, *fields): + for f in fields: + self.assertTrue(self.tracker.has_changed(f)) + + def assertNotChanged(self, *fields): + for f in fields: + self.assertFalse(self.tracker.has_changed(f)) + + def test_context_manager(self): + with self.tracker: + with self.tracker: + self.instance.name = 'new' + + self.assertChanged('name') + + self.assertChanged('name') + + self.assertNotChanged('name') + + def test_context_manager_fields(self): + with self.tracker('number'): + with self.tracker('number', 'name'): + self.instance.name = 'new' + self.instance.number += 1 + + self.assertChanged('name', 'number') + + self.assertChanged('number') + self.assertNotChanged('name') + + self.assertNotChanged('number', 'name') + + def test_tracker_decorator(self): + + @Tracked.tracker + def tracked_method(obj): + obj.name = 'new' + self.assertChanged('name') + + tracked_method(self.instance) + + self.assertNotChanged('name') + + def test_tracker_decorator_fields(self): + + @Tracked.tracker(fields=['name']) + def tracked_method(obj): + obj.name = 'new' + obj.number += 1 + self.assertChanged('name', 'number') + + tracked_method(self.instance) + + self.assertChanged('number') + self.assertNotChanged('name') + + def test_tracker_context_with_save(self): + + with self.tracker: + self.instance.name = 'new' + self.instance.save() + + self.assertChanged('name') + + self.assertNotChanged('name')