Skip to content

Commit

Permalink
Rules states should not include the initial values of slots. (RasaHQ#…
Browse files Browse the repository at this point in the history
…8161)

* Rules states should not include the initial values of slots.

* changelog entry

* Apply suggestions from code review

docstring changes

Co-authored-by: Sam Sucik <s.sucik@rasa.com>

* Apply suggestions from code review

docstring updates in unchanged code

Co-authored-by: Sam Sucik <s.sucik@rasa.com>

* update tests

* remove unnecessary listen action

* Apply suggestions from code review

Co-authored-by: Tobias Wochinger <t.wochinger@rasa.com>

* black

Co-authored-by: Sam Sucik <s.sucik@rasa.com>
Co-authored-by: Tobias Wochinger <t.wochinger@rasa.com>
  • Loading branch information
3 people committed Mar 12, 2021
1 parent 8692ea3 commit 2aade32
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 33 deletions.
2 changes: 2 additions & 0 deletions changelog/7450.bugfix.md
@@ -0,0 +1,2 @@
Rule tracker states no longer include the initial value of slots.
Rules now only require slot values when explicitly stated in the rule.
43 changes: 34 additions & 9 deletions rasa/core/featurizers/tracker_featurizers.py
Expand Up @@ -51,17 +51,20 @@ def __init__(
self.state_featurizer = state_featurizer

@staticmethod
def _create_states(tracker: DialogueStateTracker, domain: Domain) -> List[State]:
def _create_states(
tracker: DialogueStateTracker, domain: Domain, omit_unset_slots: bool = False,
) -> List[State]:
"""Create states for the given tracker.
Args:
tracker: a :class:`rasa.core.trackers.DialogueStateTracker`
domain: a :class:`rasa.shared.core.domain.Domain`
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
a list of states
"""
return tracker.past_states(domain)
return tracker.past_states(domain, omit_unset_slots=omit_unset_slots)

def _featurize_states(
self,
Expand Down Expand Up @@ -127,13 +130,17 @@ def _remove_user_text_if_intent(trackers_as_states: List[List[State]]) -> None:
del state[USER][TEXT]

def training_states_actions_and_entities(
self, trackers: List[DialogueStateTracker], domain: Domain
self,
trackers: List[DialogueStateTracker],
domain: Domain,
omit_unset_slots: bool = False,
) -> Tuple[List[List[State]], List[List[Text]], List[List[Dict[Text, Any]]]]:
"""Transforms list of trackers to lists of states, actions and entity data.
Args:
trackers: The trackers to transform
domain: The domain
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
A tuple of list of states, list of actions and list of entity data.
Expand All @@ -143,13 +150,17 @@ def training_states_actions_and_entities(
)

def training_states_and_actions(
self, trackers: List[DialogueStateTracker], domain: Domain
self,
trackers: List[DialogueStateTracker],
domain: Domain,
omit_unset_slots: bool = False,
) -> Tuple[List[List[State]], List[List[Text]]]:
"""Transforms list of trackers to lists of states and actions.
Args:
trackers: The trackers to transform
domain: The domain
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
A tuple of list of states and list of actions.
Expand All @@ -158,7 +169,9 @@ def training_states_and_actions(
trackers_as_states,
trackers_as_actions,
_,
) = self.training_states_actions_and_entities(trackers, domain)
) = self.training_states_actions_and_entities(
trackers, domain, omit_unset_slots=omit_unset_slots
)
return trackers_as_states, trackers_as_actions

def featurize_trackers(
Expand Down Expand Up @@ -332,13 +345,17 @@ class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
"""

def training_states_actions_and_entities(
self, trackers: List[DialogueStateTracker], domain: Domain
self,
trackers: List[DialogueStateTracker],
domain: Domain,
omit_unset_slots: bool = False,
) -> Tuple[List[List[State]], List[List[Text]], List[List[Dict[Text, Any]]]]:
"""Transforms list of trackers to lists of states, actions and entity data.
Args:
trackers: The trackers to transform
domain: The domain
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
A tuple of list of states, list of actions and list of entity data.
Expand All @@ -358,7 +375,9 @@ def training_states_actions_and_entities(
disable=rasa.shared.utils.io.is_logging_disabled(),
)
for tracker in pbar:
states = self._create_states(tracker, domain)
states = self._create_states(
tracker, domain, omit_unset_slots=omit_unset_slots
)

delete_first_state = False
actions = []
Expand Down Expand Up @@ -476,13 +495,17 @@ def _hash_example(
return hash((frozen_states, frozen_actions))

def training_states_actions_and_entities(
self, trackers: List[DialogueStateTracker], domain: Domain
self,
trackers: List[DialogueStateTracker],
domain: Domain,
omit_unset_slots: bool = False,
) -> Tuple[List[List[State]], List[List[Text]], List[List[Dict[Text, Any]]]]:
"""Transforms list of trackers to lists of states, actions and entity data.
Args:
trackers: The trackers to transform
domain: The domain
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
A tuple of list of states, list of actions and list of entity data.
Expand All @@ -506,7 +529,9 @@ def training_states_actions_and_entities(
disable=rasa.shared.utils.io.is_logging_disabled(),
)
for tracker in pbar:
states = self._create_states(tracker, domain)
states = self._create_states(
tracker, domain, omit_unset_slots=omit_unset_slots
)

states_length_for_action = 0
entity_data = {}
Expand Down
4 changes: 3 additions & 1 deletion rasa/core/policies/rule_policy.py
Expand Up @@ -597,7 +597,9 @@ def train(
(
rule_trackers_as_states,
rule_trackers_as_actions,
) = self.featurizer.training_states_and_actions(rule_trackers, domain)
) = self.featurizer.training_states_and_actions(
rule_trackers, domain, omit_unset_slots=True
)

rules_lookup = self._create_lookup_from_states(
rule_trackers_as_states, rule_trackers_as_actions
Expand Down
41 changes: 33 additions & 8 deletions rasa/shared/core/domain.py
Expand Up @@ -1043,17 +1043,22 @@ def _get_user_sub_state(

@staticmethod
def _get_slots_sub_state(
tracker: "DialogueStateTracker",
tracker: "DialogueStateTracker", omit_unset_slots: bool = False,
) -> Dict[Text, Union[Text, Tuple[float]]]:
"""Set all set slots with the featurization of the stored value
"""Sets all set slots with the featurization of the stored value.
Args:
tracker: dialog state tracker containing the dialog so far
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
a dictionary mapping slot names to their featurization
"""
slots = {}
for slot_name, slot in tracker.slots.items():
if slot is not None and slot.as_feature():
if omit_unset_slots and not slot.has_been_set:
continue
if slot.value == rasa.shared.core.constants.SHOULD_NOT_BE_SET:
slots[slot_name] = rasa.shared.core.constants.SHOULD_NOT_BE_SET
elif any(slot.as_feature()):
Expand Down Expand Up @@ -1101,11 +1106,22 @@ def _clean_state(state: State) -> State:
if sub_state
}

def get_active_states(self, tracker: "DialogueStateTracker") -> State:
"""Return a bag of active states from the tracker state."""
def get_active_states(
self, tracker: "DialogueStateTracker", omit_unset_slots: bool = False,
) -> State:
"""Returns a bag of active states from the tracker state.
Args:
tracker: dialog state tracker containing the dialog so far
omit_unset_slots: If `True` do not include the initial values of slots.
Returns `State` containing all active states.
"""
state = {
rasa.shared.core.constants.USER: self._get_user_sub_state(tracker),
rasa.shared.core.constants.SLOTS: self._get_slots_sub_state(tracker),
rasa.shared.core.constants.SLOTS: self._get_slots_sub_state(
tracker, omit_unset_slots=omit_unset_slots
),
rasa.shared.core.constants.PREVIOUS_ACTION: self._get_prev_action_sub_state(
tracker
),
Expand All @@ -1116,14 +1132,23 @@ def get_active_states(self, tracker: "DialogueStateTracker") -> State:
return self._clean_state(state)

def states_for_tracker_history(
self, tracker: "DialogueStateTracker"
self, tracker: "DialogueStateTracker", omit_unset_slots: bool = False,
) -> List[State]:
"""Array of states for each state of the trackers history."""
"""List of states for each state of the trackers history.
Args:
tracker: dialog state tracker containing the dialog so far
omit_unset_slots: If `True` do not include the initial values of slots.
Returns: A `State` for each prior tracker.
"""
return [
self.get_active_states(tr) for tr in tracker.generate_all_prior_trackers()
self.get_active_states(tr, omit_unset_slots=omit_unset_slots)
for tr in tracker.generate_all_prior_trackers()
]

def slots_for_entities(self, entities: List[Dict[Text, Any]]) -> List[SlotSet]:
"""Returns `SlotSet` events for extracted entities."""
if self.store_entities_as_slots:
slot_events = []
for s in self.slots:
Expand Down
34 changes: 29 additions & 5 deletions rasa/shared/core/generator.py
Expand Up @@ -84,16 +84,27 @@ def from_events(
tracker.update(e)
return tracker

def past_states_for_hashing(self, domain: Domain) -> Deque[FrozenState]:
def past_states_for_hashing(
self, domain: Domain, omit_unset_slots: bool = False,
) -> Deque[FrozenState]:
"""Generates and caches the past states of this tracker based on the history.
Args:
domain: a :class:`rasa.shared.core.domain.Domain`
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
A list of states
"""
# we need to make sure this is the same domain, otherwise things will
# go south. but really, the same tracker shouldn't be used across
# go wrong. but really, the same tracker shouldn't be used across
# domains
assert domain == self.domain

# if don't have it cached, we use the domain to calculate the states
# from the events
if self._states_for_hashing is None:
states = super().past_states(domain)
states = super().past_states(domain, omit_unset_slots=omit_unset_slots)
self._states_for_hashing = deque(
self.freeze_current_state(s) for s in states
)
Expand All @@ -107,8 +118,21 @@ def _unfreeze_states(frozen_states: Deque[FrozenState]) -> List[State]:
for frozen_state in frozen_states
]

def past_states(self, domain: Domain) -> List[State]:
states_for_hashing = self.past_states_for_hashing(domain)
def past_states(
self, domain: Domain, omit_unset_slots: bool = False,
) -> List[State]:
"""Generates the past states of this tracker based on the history.
Args:
domain: a :class:`rasa.shared.core.domain.Domain`
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
A list of states
"""
states_for_hashing = self.past_states_for_hashing(
domain, omit_unset_slots=omit_unset_slots
)
return self._unfreeze_states(states_for_hashing)

def clear_states(self) -> None:
Expand Down
21 changes: 20 additions & 1 deletion rasa/shared/core/slots.py
Expand Up @@ -45,11 +45,12 @@ def __init__(
influence the predictions of the dialogue polices.
"""
self.name = name
self.value = initial_value
self._value = initial_value
self.initial_value = initial_value
self._value_reset_delay = value_reset_delay
self.auto_fill = auto_fill
self.influence_conversation = influence_conversation
self._has_been_set = False

def feature_dimensionality(self) -> int:
"""How many features this single slot creates.
Expand Down Expand Up @@ -98,7 +99,25 @@ def _as_feature(self) -> List[float]:
)

def reset(self) -> None:
"""Resets the slot's value to the initial value."""
self.value = self.initial_value
self._has_been_set = False

@property
def value(self) -> Any:
"""Gets the slot's value."""
return self._value

@value.setter
def value(self, value: Any) -> None:
"""Sets the slot's value."""
self._value = value
self._has_been_set = True

@property
def has_been_set(self) -> bool:
"""Indicates if the slot's value has been set."""
return self._has_been_set

def __str__(self) -> Text:
return f"{self.__class__.__name__}({self.name}: {self.value})"
Expand Down
19 changes: 12 additions & 7 deletions rasa/shared/core/trackers.py
Expand Up @@ -273,16 +273,21 @@ def freeze_current_state(state: State) -> FrozenState:
}.items()
)

def past_states(self, domain: Domain) -> List[State]:
"""Generate the past states of this tracker based on the history.
def past_states(
self, domain: Domain, omit_unset_slots: bool = False,
) -> List[State]:
"""Generates the past states of this tracker based on the history.
Args:
domain: a :class:`rasa.shared.core.domain.Domain`
omit_unset_slots: If `True` do not include the initial values of slots.
Returns:
a list of states
A list of states
"""
return domain.states_for_tracker_history(self)
return domain.states_for_tracker_history(
self, omit_unset_slots=omit_unset_slots
)

def change_loop_to(self, loop_name: Optional[Text]) -> None:
"""Set the currently active loop.
Expand Down Expand Up @@ -782,10 +787,10 @@ def _reset_slots(self) -> None:
slot.reset()

def _set_slot(self, key: Text, value: Any) -> None:
"""Set the value of a slot if that slot exists."""

"""Sets the value of a slot if that slot exists."""
if key in self.slots:
self.slots[key].value = value
slot = self.slots[key]
slot.value = value
else:
logger.error(
f"Tried to set non existent slot '{key}'. Make sure you "
Expand Down

0 comments on commit 2aade32

Please sign in to comment.