Add text classification task (#32)
* Initial implementation for few-shot learning

* Add initial implementation of the reader

* Make ExampleReader more generalized

* Add examples for NER few-shot learning

* Update tests for jinja2 templating

* Update prompt for NER to incorporate few-shot learning

* Make input to ExampleReader a string

* Add initial set of tests for jinja2 templating

* Use JSON instead of JSONL to be consistent with Prodigy

* Update prompt and typing and fix tests

* Move reader from registry.util to its own file

* Add fewshot config for pipe tests

* Fix type errors

* Update codebase with black

* Fix linting errors

* Refactor FewShotReader for reading task examples

Yes, it's now FewShotReader. We changed it to this name because
ExampleReader kinda overloads spaCy's Example primitive--we don't want
to cause that confusion.

There are also several updates such as:
- Accepting a union of path, str: to cover our bases
- Using registry instead of spacy.registry: better coupling with spaCy
- Support JSONL: why not? I also updated the tests
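As a rough illustration of the reader behaviour listed above (a path given as either `str` or `Path`, with both JSON and JSONL supported), here is a minimal standard-library sketch. `read_examples` is a hypothetical helper written for this note; it is not the `FewShotReader` / `fewshot_reader` implementation from this commit, and dispatching on the file suffix is an assumption.

```python
import json
from pathlib import Path
from typing import Any, Dict, List, Union


def read_examples(path: Union[str, Path]) -> List[Dict[str, Any]]:
    """Hypothetical helper: load few-shot examples from a .json or .jsonl file."""
    path = Path(path)  # accept both str and Path
    text = path.read_text(encoding="utf-8")
    if path.suffix == ".jsonl":
        # JSONL: one JSON object per non-empty line
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    if path.suffix == ".json":
        # JSON: a single array of example objects
        data = json.loads(text)
        if not isinstance(data, list):
            raise ValueError(f"Expected a list of examples in {path}")
        return data
    raise ValueError(f"Unsupported file type: {path.suffix}")
```

Each returned item is a plain dictionary, for example one with `text` and `answer` keys as in the test fixtures added by this commit.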

* Remove spurious imports

* Rename example reader function itself

* Fix errors from merge

* Replace MiniChain.v1 with REST.v1

* Update some docstrings

* Initial implementation of textcat task

Not super confident with this implementation as it departed quite a bit
from Prodigy. Instead of getting accept and reject responses, I just ask
the LLM to return POS and NEG.

I also made a few decisions on what the categories would be when using
binary text classification. For now, I'm thinking about using POS and
NEG by default. Not sure if this should be parametrized.

* Binary and exclusive should raise an error instead

* Add initial test suite

Still a work in progress (see TODOs)

* Remove jinja2 partials

* Fix unbound variables

* Fix incorrect variable calls

* Strip labels to ensure clean output

* Initial refactor of textcat task

* Initial refactor of textcat tests

* Fix some minor bugs in imports

* Make task examples more reusable

* Ensure that lowercase and uppercase POS is considered by parser

* Add condition if exclusive but LLM returned multiple labels

* Test multiple iterations for multilabel responses
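
The handling of multilabel responses mentioned in the bullets above can be sketched roughly as follows. `parse_labels` is a hypothetical stand-alone function for illustration; it assumes comma-delimited answers and exact label matches, and leaves out the normalizer and the logging used by the task in this commit.

```python
from typing import Dict, List


def parse_labels(response: str, labels: List[str], exclusive: bool) -> Dict[str, float]:
    """Hypothetical sketch: turn a comma-delimited LLM answer into label scores."""
    scores = {label: 0.0 for label in labels}
    predicted = [p.strip() for p in response.split(",") if p.strip()]
    if exclusive and len(predicted) > 1:
        # An exclusive task got several labels: keep none instead of raising.
        predicted = []
    for pred in predicted:
        if pred in scores:
            scores[pred] = 1.0
    return scores


labels = ["Recipe", "Feedback", "Comment"]
print(parse_labels("Recipe, Feedback", labels, exclusive=False))
print(parse_labels("Recipe, Feedback", labels, exclusive=True))
```

In the non-exclusive call both returned labels are set to 1.0; in the exclusive call the ambiguous answer is dropped and every label stays at 0.0.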

* Add tests for jinja rendering in textcat

I also added some fixes in the template whenever necessary.

* Fix tests and their corresponding fixtures

* small fixes

* Update spacy_llm/tasks/textcat.py

Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>

* Update spacy_llm/tasks/ner.py

Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>

* Update how scores are returned

For negative cases, we get something like

{LABEL: 0.0}

For positive cases, we get something like

{LABEL: 1.0}
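
A minimal sketch of that mapping for the binary case, assuming the model answers `POS` or `NEG`. `binary_scores` is a hypothetical name, and stripping whitespace before the comparison is an assumption of this sketch.

```python
from typing import Dict


def binary_scores(label: str, response: str) -> Dict[str, float]:
    """Hypothetical sketch: map a raw POS/NEG answer to a one-entry score dict."""
    # Assumption: surrounding whitespace and letter case are ignored.
    is_positive = response.strip().upper() == "POS"
    return {label: 1.0 if is_positive else 0.0}


print(binary_scores("Recipe", "POS"))  # {'Recipe': 1.0}
print(binary_scores("Recipe", "neg"))  # {'Recipe': 0.0}
```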

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Add verbose flag and hide the debug message

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Apply black formatting

* Create strip normalizer instead of noop normalizer

The lowercase normalizer also contains this.
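
For reference, a stand-alone sketch of the two normalizers after this change. It mirrors the functions in the diff but omits the `@registry.misc` decorators so that it runs without spaCy installed.

```python
from typing import Callable


def strip_normalizer() -> Callable[[str], str]:
    """Return a function that only strips surrounding whitespace."""

    def strip(s: str) -> str:
        return s.strip()

    return strip


def lowercase_normalizer() -> Callable[[str], str]:
    """Return a function that strips whitespace and lowercases."""

    def lowercase(s: str) -> str:
        return s.strip().lower()

    return lowercase


print(repr(strip_normalizer()("  Recipe \n")))      # 'Recipe'
print(repr(lowercase_normalizer()("  Recipe \n")))  # 'recipe'
```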

* Apply suggestions from code review

Docstrings

Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>

---------

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>
3 people committed May 10, 2023
1 parent d70a626 commit 2c4100c
Showing 16 changed files with 682 additions and 16 deletions.
4 changes: 2 additions & 2 deletions spacy_llm/registry/__init__.py
@@ -1,10 +1,10 @@
-from .normalizer import lowercase_normalizer, noop_normalizer
+from .normalizer import lowercase_normalizer, strip_normalizer
 from .reader import fewshot_reader
 from .util import registry

 __all__ = [
     "lowercase_normalizer",
-    "noop_normalizer",
+    "strip_normalizer",
     "fewshot_reader",
     "registry",
 ]
10 changes: 5 additions & 5 deletions spacy_llm/registry/normalizer.py
@@ -3,15 +3,15 @@
 from .util import registry


-@registry.misc("spacy.NoopNormalizer.v1")
-def noop_normalizer() -> Callable[[str], str]:
-    """Return the labels as-is
+@registry.misc("spacy.StripNormalizer.v1")
+def strip_normalizer() -> Callable[[str], str]:
+    """Return the labels as-is with stripped whitespaces
     RETURNS (Callable[[str], str])
     """

     def noop(s: str) -> str:
-        return s
+        return s.strip()

     return noop

@@ -24,6 +24,6 @@ def lowercase_normalizer() -> Callable[[str], str]:
     """

     def lowercase(s: str) -> str:
-        return s.lower()
+        return s.strip().lower()

     return lowercase
3 changes: 2 additions & 1 deletion spacy_llm/tasks/__init__.py
@@ -1,4 +1,5 @@
 from .ner import NERTask
+from .textcat import TextCatTask
 from .noop import NoopTask

-__all__ = ["NoopTask", "NERTask"]
+__all__ = ["NoopTask", "NERTask", "TextCatTask"]
12 changes: 6 additions & 6 deletions spacy_llm/tasks/ner.py
@@ -4,7 +4,7 @@
 from spacy.tokens import Doc
 from spacy.util import filter_spans

-from ..registry import noop_normalizer, registry
+from ..registry import strip_normalizer, registry


 def find_substrings(
@@ -97,19 +97,19 @@ def __init__(
         case_sensitive_matching: bool = False,
         single_match: bool = False,
     ):
-        """Default NER template for LLM annotation
+        """Default NER task.
-        labels (str): comma-separated list of labels to pass to the template.
-        examples (Optional[Callable[[], Iterable[Any]]]): optional callable that
+        labels (str): Comma-separated list of labels to pass to the template.
+        examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that
             reads a file containing task examples for few-shot learning. If None is
             passed, then zero-shot learning will be used.
         normalizer (Optional[Callable[[str], str]]): optional normalizer function.
         alignment_mode (str): "strict", "contract" or "expand".
         case_sensitive: Whether to search without case sensitivity.
-        single_match: If False, allow one substring to match multiple times in
+        single_match (bool): If False, allow one substring to match multiple times in
             the text. If True, returns the first hit.
         """
-        self._normalizer = normalizer if normalizer else noop_normalizer()
+        self._normalizer = normalizer if normalizer else strip_normalizer()
         self._label_dict = {
             self._normalizer(label): label for label in labels.split(",")
         }
157 changes: 157 additions & 0 deletions spacy_llm/tasks/textcat.py
@@ -0,0 +1,157 @@
from typing import Any, Callable, Dict, Iterable, Optional

import jinja2
from spacy.tokens import Doc
from wasabi import msg

from ..registry import strip_normalizer, registry


@registry.llm_tasks("spacy.TextCat.v1")
class TextCatTask:
    _TEMPLATE_STR = """
{%- if labels|length == 1 -%}
{%- set label = labels[0] -%}
Classify whether the text below belongs to the {{ label }} category or not.
If it is a {{ label }}, answer `POS`. If it is not a {{ label }}, answer `NEG`.
{%- else -%}
Classify the text below to any of the following labels: {{ labels|join(", ") }}
{# whitespace #}
{%- if exclusive_classes -%}
The task is exclusive, so only choose one label from what I provided.
{%- else -%}
The task is non-exclusive, so you can provide more than one label as long as
they're comma-delimited. For example: Label1, Label2, Label3.
{%- endif -%}
{# whitespace #}
{%- endif -%}
{# whitespace #}
{%- if examples -%}
{# whitespace #}
Below are some examples (only use these as a guide):
{# whitespace #}
{# whitespace #}
{%- for example in examples -%}
{# whitespace #}
Text:
'''
{{ example['text'] }}
'''
{# whitespace #}
{{ example['answer']}}
{# whitespace #}
{%- endfor -%}
{%- endif -%}
{# whitespace #}
{# whitespace #}
Here is the text that needs classification
{# whitespace #}
{# whitespace #}
Text:
'''
{{ text }}
'''
"""

    def __init__(
        self,
        labels: str,
        examples: Optional[Callable[[], Iterable[Any]]] = None,
        normalizer: Optional[Callable[[str], str]] = None,
        exclusive_classes: bool = False,
        verbose: bool = False,
    ):
        """Default TextCat task.
        You can use either binary or multilabel text classification based on the
        labels you provide.
        If a single label is provided, binary classification
        will be used. The label will get a score of `0` or `1` in `doc.cats`.
        If a comma-separated list of labels is provided, multilabel
        classification will be used. The document labels in `doc.cats` will be a
        dictionary of strings and their score.
        Lastly, you can toggle between exclusive or non-exclusive text
        categorization by passing a flag to the `exclusive_classes` parameter.
        labels (str): Comma-separated list of labels to pass to the template. This task
            assumes binary classification if a single label is provided.
        examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that
            reads a file containing task examples for few-shot learning. If None is
            passed, then zero-shot learning will be used.
        normalizer (Optional[Callable[[str], str]]): Optional normalizer function.
        exclusive_classes (bool): If True, require the language model to suggest only one
            label per class. This is automatically set when using binary classification.
        verbose (bool): If True, show extra information.
        """
        self._normalizer = normalizer if normalizer else strip_normalizer()
        self._label_dict = {
            self._normalizer(label): label for label in labels.split(",")
        }
        self._examples = examples() if examples else None
        # Textcat configuration
        self._use_binary = True if len(self._label_dict) == 1 else False
        self._exclusive_classes = exclusive_classes
        self._verbose = verbose

        if self._use_binary and not self._exclusive_classes:
            msg.warn(
                "Binary classification should always be exclusive. Setting "
                "the `exclusive_classes` parameter to True."
            )
            self._exclusive_classes = True

    def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]:
        environment = jinja2.Environment()
        _template = environment.from_string(self._TEMPLATE_STR)
        for doc in docs:
            prompt = _template.render(
                text=doc.text,
                labels=list(self._label_dict.values()),
                examples=self._examples,
                exclusive_classes=self._exclusive_classes,
            )
            yield prompt

    def _format_response(self, response: str) -> Dict[str, float]:
        """Parse raw string response into a structured format
        The returned dictionary contains the labels mapped to their score.
        """
        categories = {}

        if self._use_binary:
            # Binary classification: We only have one label
            label: str = list(self._label_dict.values())[0]
            score = 1.0 if response.upper() == "POS" else 0.0
            categories = {label: score}
        else:
            # Multilabel classification: start every label at 0.0
            categories = {label: 0.0 for label in self._label_dict.values()}

            pred_labels = response.split(",")
            if self._exclusive_classes and len(pred_labels) > 1:
                # Don't use anything but raise a debug message
                # Don't raise an error. Let user abort if they want to.
                msg.text(
                    f"LLM returned multiple labels for this exclusive task: {pred_labels}.",
                    " Will store an empty label instead.",
                    show=self._verbose,
                )
                pred_labels = []

            for pred in pred_labels:
                if self._normalizer(pred.strip()) in self._label_dict:
                    category = self._label_dict[self._normalizer(pred.strip())]
                    categories[category] = 1.0
        return categories

    def parse_responses(
        self, docs: Iterable[Doc], responses: Iterable[str]
    ) -> Iterable[Doc]:
        for doc, prompt_response in zip(docs, responses):
            cats = self._format_response(prompt_response)
            doc.cats = cats
            yield doc
14 changes: 14 additions & 0 deletions spacy_llm/tests/tasks/examples/textcat_binary_examples.json
@@ -0,0 +1,14 @@
[
{
"text": "Macaroni and cheese is the best budget meal for students, unhealthy tho",
"answer": "NEG"
},
{
"text": "2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo",
"answer": "POS"
},
{
"text": "You can still add more layers to that croissant, get extra butter and add a few cups of flour",
"answer": "POS"
}
]
3 changes: 3 additions & 0 deletions spacy_llm/tests/tasks/examples/textcat_binary_examples.jsonl
@@ -0,0 +1,3 @@
{"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho","answer":"NEG"}
{"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo","answer":"POS"}
{"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour","answer":"POS"}
10 changes: 10 additions & 0 deletions spacy_llm/tests/tasks/examples/textcat_binary_examples.yml
@@ -0,0 +1,10 @@
- text: Macaroni and cheese is the best budget meal for students, unhealthy tho
  answer: NEG
- text:
    2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper,
    mix then well and you get an adobo
  answer: POS
- text:
    You can still add more layers to that croissant, get extra butter and add
    a few cups of flour
  answer: POS
14 changes: 14 additions & 0 deletions spacy_llm/tests/tasks/examples/textcat_multi_excl_examples.json
@@ -0,0 +1,14 @@
[
{
"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho",
"answer":"Comment"
},
{
"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo",
"answer":"Recipe"
},
{
"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour",
"answer":"Feedback"
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho","answer":"Comment"}
{"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo","answer":"Recipe"}
{"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour","answer":"Feedback"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- text: Macaroni and cheese is the best budget meal for students, unhealthy tho
  answer: Comment
- text: 2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper,
    mix then well and you get an adobo
  answer: Recipe
- text: You can still add more layers to that croissant, get extra butter and add
    a few cups of flour
  answer: Feedback
14 changes: 14 additions & 0 deletions spacy_llm/tests/tasks/examples/textcat_multi_nonexcl_examples.json
@@ -0,0 +1,14 @@
[
{
"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho",
"answer":"Comment,Feedback"
},
{
"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo",
"answer":"Recipe"
},
{
"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour",
"answer":"Feedback,Recipe"
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho","answer":"Comment,Feedback"}
{"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo","answer":"Recipe"}
{"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour","answer":"Feedback,Recipe"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
- text: Macaroni and cheese is the best budget meal for students, unhealthy tho
  answer: Comment,Feedback
- text: 2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper,
    mix then well and you get an adobo
  answer: Recipe
- text: You can still add more layers to that croissant, get extra butter and add
    a few cups of flour
  answer: Feedback,Recipe
4 changes: 2 additions & 2 deletions spacy_llm/tests/tasks/test_ner.py
@@ -6,7 +6,7 @@
 from confection import Config
 from spacy.util import make_tempdir

-from spacy_llm.registry import noop_normalizer, lowercase_normalizer, fewshot_reader
+from spacy_llm.registry import strip_normalizer, lowercase_normalizer, fewshot_reader
 from spacy_llm.tasks.ner import find_substrings, NERTask

 EXAMPLES_DIR = Path(__file__).parent / "examples"
@@ -194,7 +194,7 @@ def test_ner_zero_shot_task(text, response, gold_ents):
         ),
         (
             "PER: Jean Jacques, Jaime",
-            noop_normalizer(),
+            strip_normalizer(),
             [("Jean Jacques", "PER"), ("Jaime", "PER")],
         ),
         (