Skip to content

Commit

Permalink
Retrieve agent and user ids from JSON in json_to_dialogues (#225)
Browse files Browse the repository at this point in the history
* Retrieve agent and user ids from JSON in `json_to_dialogues`
Fixes #223

* Use list of ids for filtering

* Fix version of websockets to solve error in CI
Related to this issue: sanic-org/sanic#2733
  • Loading branch information
NoB0 committed Apr 14, 2023
1 parent fc7d519 commit d7028da
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 11 deletions.
20 changes: 16 additions & 4 deletions dialoguekit/utils/dialogue_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
_FIELD_CONVERSATION = "conversation"
_FIELD_CONVERSATION_ID = "conversation ID"
_FIELD_PARTICIPANT = "participant"
_FIELD_AGENT = "agent"
_FIELD_USER = "user"


def json_to_annotated_utterance(
Expand Down Expand Up @@ -71,15 +73,17 @@ def json_to_annotated_utterance(

def json_to_dialogues(
filepath: str,
agent_id: str,
user_id: str,
agent_ids: List[str] = None,
user_ids: List[str] = None,
) -> List[Dialogue]:
"""Parses a JSON file containing dialogues.
Args:
filepath: Path to JSON file containing the dialogues.
agent_id: Agent ID in the dialogues.
user_id: User ID in the dialogues.
agent_ids: List of agents' id to filter loaded dialogues. Defaults to
None.
user_ids: List of users' id to filter loaded dialogues. Defaults to
None.
Returns:
A list of Dialogue objects.
Expand All @@ -90,6 +94,14 @@ def json_to_dialogues(
dialogues = []
for dialogue_data in data:
conversation_id = dialogue_data.get(_FIELD_CONVERSATION_ID, None)
agent_id = dialogue_data.get(_FIELD_AGENT, {}).get("id", "Agent")
user_id = dialogue_data.get(_FIELD_USER, {}).get("id", "User")
if (agent_ids and agent_id not in agent_ids) or (
user_ids and user_id not in user_ids
):
# Filter loaded dialogues based on agent_ids and/or user_ids if
# provided
continue
dialogue = Dialogue(agent_id, user_id, conversation_id)
for utterance_data in dialogue_data.get(_FIELD_CONVERSATION):
annotated_utterance = json_to_annotated_utterance(utterance_data)
Expand Down
3 changes: 2 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r pre_commit.txt
scikit-learn >= 0.24
rasa >= 3.0.8
rasa >= 3.0.8
websockets<11.0
10 changes: 9 additions & 1 deletion tests/data/annotated_dialogues.json
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,15 @@
"utterance": "You are exiting. I hope you found a movie. Bye.\",\n",
"intent": "END"
}
]
],
"agent": {
"id": "MovieBotTester",
"type": "AGENT"
},
"user": {
"id": "TEST03",
"type": "USER"
}
},
{
"conversation ID": "\"3WOKGM4L721JEJM00BS2Y3IMR910O8\"",
Expand Down
34 changes: 31 additions & 3 deletions tests/utils/test_dialogue_reader.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,46 @@
"""Tests for the dialogue reader."""

from typing import List

import pytest

from dialoguekit.utils.dialogue_reader import json_to_dialogues


def test_json_to_dialogues():
def test_json_to_dialogues() -> None:
"""Tests reading of json dialogues."""
dialogues = json_to_dialogues(
filepath="tests/data/annotated_dialogues.json",
agent_id="TestAGENT",
user_id="TestUSER",
)
assert len(dialogues) == 3
assert len(dialogues[0].utterances) > 0
assert dialogues[0].agent_id == "MovieBotTester"
assert dialogues[0].user_id == "TEST03"
assert dialogues[-1].agent_id == "Agent"
assert dialogues[-1].user_id == "User"
assert dialogues[0].utterances[0].participant == "USER"
assert dialogues[0].utterances[1].participant == "AGENT"
assert dialogues[0].utterances[0].intent.label == "DISCLOSE.NON-DISCLOSE"
assert dialogues[0].utterances[1].intent.label == "INQUIRE.ELICIT"


@pytest.mark.parametrize(
    "agent_ids, user_ids, expected_dialogue_count",
    [
        (["MovieBotTester"], None, 1),
        (None, None, 3),
        (["TestAgent"], ["TestUser"], 0),
        (None, ["TEST03"], 1),
        (["MovieBotTester"], ["TEST03"], 1),
    ],
)
def test_json_to_dialogues_filtered(
    agent_ids: List[str], user_ids: List[str], expected_dialogue_count: int
) -> None:
    """Checks how many dialogues are loaded for given id filters.

    Each case passes the agent and/or user id filters to `json_to_dialogues`
    together with the annotated dialogues test file, and compares the number
    of returned dialogues against the expected count.

    Args:
        agent_ids: Agent ids passed as filter (None in some cases).
        user_ids: User ids passed as filter (None in some cases).
        expected_dialogue_count: Number of dialogues expected to be returned.
    """
    # Forward both filters unchanged; a None filter is passed through as-is.
    filter_kwargs = {"agent_ids": agent_ids, "user_ids": user_ids}
    loaded = json_to_dialogues(
        filepath="tests/data/annotated_dialogues.json",
        **filter_kwargs,
    )
    assert len(loaded) == expected_dialogue_count
2 changes: 0 additions & 2 deletions tests/utils/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ def annotated_dialogues() -> List[Dialogue]:
"""Test dialogue fixture."""
export_dialogues = json_to_dialogues(
filepath="tests/data/annotated_dialogues.json",
agent_id=DialogueParticipant.AGENT,
user_id=DialogueParticipant.USER,
)
return export_dialogues

Expand Down

0 comments on commit d7028da

Please sign in to comment.