-
-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial implementation for few-shot learning * Add initial implementation of the reader * Make ExampleReader more generalized * Add examples for NER few-shot learning * Update tests for jinja2 templating * Update prompt for NER to incorporate few-shot learning * Make input to ExampleReader a string * Add initial set of tests for jinja2 templating * Use JSON instead of JSONL to be consistent with Prodigy * Update prompt and typing and fix tests * Move reader from registry.util to its own file * Add fewshot config for pipe tests * Fix type errors * Update codebase with black * Fix linting errors * Refactor FewShotReader for reading task examples Yes, it's now FewShotReader. We changed it to this name because ExampleReader kinda overloads spaCy's Example primitive--we don't want to cause that confusion. There are also several updates such as: - Accepting a union of path, str: to cover our bases - Using registry instead of spacy.registry: better coupling with spaCy - Support JSONL: why not? I also updated the tests * Remove spurious imports * Rename example reader function itself * Fix errors from merge * Replace MiniChain.v1 with REST.v1 * Update some docstrings * Initial implementation of textcat task Not super confident with this implementation as it departed quite a bit from Prodigy. Instead of getting accept and reject responses, I just ask the LLM to return POS and NEG. I also made a few decisions on what the categories would be when using binary text classification. For now, I'm thinking about using POS and NEG by default. Not sure if this should be parametrized. * Binary and exclusive should raise an error instead * Add initial test suite Still a work in progress (see TODOs) * Remove jinja2 partials * Fix unbound variables * Fix incorrect variable calls * Strip labels to ensure clean output * Initial refactor of textcat task * Initial refactor of textcat tests * Fix some minor bugs in imports * Make task examples more reusable * Ensure that lowercase and uppercase POS is considered by parser * Add condition if exclusive but llm returned multiple' * Test multiple iterations for multilabel responses * Add tests for jinja rendering in textcat I also added some fixes in the template whenever necessary. * Fix tests and their corresponding fixtures * small fixes * Update spacy_llm/tasks/textcat.py Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com> * Update spacy_llm/tasks/ner.py Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com> * Update how scores are returned For negative cases, we get something like {LABEL: 0.0} For positive cases, we get something like {LABEL: 1.0} * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Add verbose flag and hide the debug message * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Apply black formatting * Create strip normalizer instead of noop normalizer The lowercase normalizer also contains this. * Apply suggestions from code review Docstrings Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com> --------- Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: Raphael Mitsch <r.mitsch@outlook.com>
- Loading branch information
1 parent
d70a626
commit 2c4100c
Showing
16 changed files
with
682 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
from .normalizer import lowercase_normalizer, noop_normalizer | ||
from .normalizer import lowercase_normalizer, strip_normalizer | ||
from .reader import fewshot_reader | ||
from .util import registry | ||
|
||
__all__ = [ | ||
"lowercase_normalizer", | ||
"noop_normalizer", | ||
"strip_normalizer", | ||
"fewshot_reader", | ||
"registry", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .ner import NERTask | ||
from .textcat import TextCatTask | ||
from .noop import NoopTask | ||
|
||
__all__ = ["NoopTask", "NERTask"] | ||
__all__ = ["NoopTask", "NERTask", "TextCatTask"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
from typing import Any, Callable, Dict, Iterable, Optional | ||
|
||
import jinja2 | ||
from spacy.tokens import Doc | ||
from wasabi import msg | ||
|
||
from ..registry import strip_normalizer, registry | ||
|
||
|
||
@registry.llm_tasks("spacy.TextCat.v1") | ||
class TextCatTask: | ||
_TEMPLATE_STR = """ | ||
{%- if labels|length == 1 -%} | ||
{%- set label = labels[0] -%} | ||
Classify whether the text below belongs to the {{ label }} category or not. | ||
If it is a {{ label }}, answer `POS`. If it is not a {{ label }}, answer `NEG`. | ||
{%- else -%} | ||
Classify the text below to any of the following labels: {{ labels|join(", ") }} | ||
{# whitespace #} | ||
{%- if exclusive_classes -%} | ||
The task is exclusive, so only choose one label from what I provided. | ||
{%- else -%} | ||
The task is non-exclusive, so you can provide more than one label as long as | ||
they're comma-delimited. For example: Label1, Label2, Label3. | ||
{%- endif -%} | ||
{# whitespace #} | ||
{%- endif -%} | ||
{# whitespace #} | ||
{%- if examples -%} | ||
{# whitespace #} | ||
Below are some examples (only use these as a guide): | ||
{# whitespace #} | ||
{# whitespace #} | ||
{%- for example in examples -%} | ||
{# whitespace #} | ||
Text: | ||
''' | ||
{{ example['text'] }} | ||
''' | ||
{# whitespace #} | ||
{{ example['answer']}} | ||
{# whitespace #} | ||
{%- endfor -%} | ||
{%- endif -%} | ||
{# whitespace #} | ||
{# whitespace #} | ||
Here is the text that needs classification | ||
{# whitespace #} | ||
{# whitespace #} | ||
Text: | ||
''' | ||
{{ text }} | ||
''' | ||
""" | ||
|
||
def __init__( | ||
self, | ||
labels: str, | ||
examples: Optional[Callable[[], Iterable[Any]]] = None, | ||
normalizer: Optional[Callable[[str], str]] = None, | ||
exclusive_classes: bool = False, | ||
verbose: bool = False, | ||
): | ||
"""Default TextCat task. | ||
You can use either binary or multilabel text classification based on the | ||
labels you provide. | ||
If a single label is provided, binary classification | ||
will be used. The label will get a score of `0` or `1` in `doc.cats`. | ||
If a comma-separated list of labels is provided, multilabel | ||
classification will be used. The document labels in `doc.cats` will be a | ||
dictionary of strings and their score. | ||
Lastly, you can toggle between exclusive or no-exclusive text | ||
categorization by passing a flag to the `exclusive_classes` parameter. | ||
labels (str): Comma-separated list of labels to pass to the template. This task | ||
assumes binary classification if a single label is provided. | ||
examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that | ||
reads a file containing task examples for few-shot learning. If None is | ||
passed, then zero-shot learning will be used. | ||
normalizer (Optional[Callable[[str], str]]): Optional normalizer function. | ||
exclusive_classes (bool): If True, require the language model to suggest only one | ||
label per class. This is automatically set when using binary classification. | ||
verbose (bool): If True, show extra information. | ||
""" | ||
self._normalizer = normalizer if normalizer else strip_normalizer() | ||
self._label_dict = { | ||
self._normalizer(label): label for label in labels.split(",") | ||
} | ||
self._examples = examples() if examples else None | ||
# Textcat configuration | ||
self._use_binary = True if len(self._label_dict) == 1 else False | ||
self._exclusive_classes = exclusive_classes | ||
self._verbose = verbose | ||
|
||
if self._use_binary and not self._exclusive_classes: | ||
msg.warn( | ||
"Binary classification should always be exclusive. Setting " | ||
"the `exclusive_classes` parameter to True." | ||
) | ||
self._exclusive_classes = True | ||
|
||
def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: | ||
environment = jinja2.Environment() | ||
_template = environment.from_string(self._TEMPLATE_STR) | ||
for doc in docs: | ||
prompt = _template.render( | ||
text=doc.text, | ||
labels=list(self._label_dict.values()), | ||
examples=self._examples, | ||
exclusive_classes=self._exclusive_classes, | ||
) | ||
yield prompt | ||
|
||
def _format_response(self, response: str) -> Dict[str, float]: | ||
"""Parse raw string response into a structured format | ||
The returned dictionary contains the labels mapped to their score. | ||
""" | ||
categories = {} | ||
|
||
if self._use_binary: | ||
# Binary classification: We only have one label | ||
label: str = list(self._label_dict.values())[0] | ||
score = 1.0 if response.upper() == "POS" else 0.0 | ||
categories = {label: score} | ||
else: | ||
# Multilabel classification | ||
categories = {label: 0 for label in self._label_dict.values()} | ||
|
||
pred_labels = response.split(",") | ||
if self._exclusive_classes and len(pred_labels) > 1: | ||
# Don't use anything but raise a debug message | ||
# Don't raise an error. Let user abort if they want to. | ||
msg.text( | ||
f"LLM returned multiple labels for this exclusive task: {pred_labels}.", | ||
" Will store an empty label instead.", | ||
show=self._verbose, | ||
) | ||
pred_labels = [] | ||
|
||
for pred in pred_labels: | ||
if self._normalizer(pred.strip()) in self._label_dict: | ||
category = self._label_dict[self._normalizer(pred.strip())] | ||
categories[category] = 1.0 | ||
return categories | ||
|
||
def parse_responses( | ||
self, docs: Iterable[Doc], responses: Iterable[str] | ||
) -> Iterable[Doc]: | ||
for doc, prompt_response in zip(docs, responses): | ||
cats = self._format_response(prompt_response) | ||
doc.cats = cats | ||
yield doc |
14 changes: 14 additions & 0 deletions
14
spacy_llm/tests/tasks/examples/textcat_binary_examples.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[ | ||
{ | ||
"text": "Macaroni and cheese is the best budget meal for students, unhealthy tho", | ||
"answer": "NEG" | ||
}, | ||
{ | ||
"text": "2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo", | ||
"answer": "POS" | ||
}, | ||
{ | ||
"text": "You can still add more layers to that croissant, get extra butter and add a few cups of flour", | ||
"answer": "POS" | ||
} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho","answer":"NEG"} | ||
{"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo","answer":"POS"} | ||
{"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour","answer":"POS"} |
10 changes: 10 additions & 0 deletions
10
spacy_llm/tests/tasks/examples/textcat_binary_examples.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
- text: Macaroni and cheese is the best budget meal for students, unhealthy tho | ||
answer: NEG | ||
- text: | ||
2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, | ||
mix then well and you get an adobo | ||
answer: POS | ||
- text: | ||
You can still add more layers to that croissant, get extra butter and add | ||
a few cups of flour | ||
answer: POS |
14 changes: 14 additions & 0 deletions
14
spacy_llm/tests/tasks/examples/textcat_multi_excl_examples.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[ | ||
{ | ||
"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho", | ||
"answer":"Comment" | ||
}, | ||
{ | ||
"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo", | ||
"answer":"Recipe" | ||
}, | ||
{ | ||
"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour", | ||
"answer":"Feedback" | ||
} | ||
] |
3 changes: 3 additions & 0 deletions
3
spacy_llm/tests/tasks/examples/textcat_multi_excl_examples.jsonl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho","answer":"Comment"} | ||
{"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo","answer":"Recipe"} | ||
{"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour","answer":"Feedback"} |
8 changes: 8 additions & 0 deletions
8
spacy_llm/tests/tasks/examples/textcat_multi_excl_examples.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
- text: Macaroni and cheese is the best budget meal for students, unhealthy tho | ||
answer: Comment | ||
- text: 2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, | ||
mix then well and you get an adobo | ||
answer: Recipe | ||
- text: You can still add more layers to that croissant, get extra butter and add | ||
a few cups of flour | ||
answer: Feedback |
14 changes: 14 additions & 0 deletions
14
spacy_llm/tests/tasks/examples/textcat_multi_nonexcl_examples.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[ | ||
{ | ||
"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho", | ||
"answer":"Comment,Feedback" | ||
}, | ||
{ | ||
"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo", | ||
"answer":"Recipe" | ||
}, | ||
{ | ||
"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour", | ||
"answer":"Feedback,Recipe" | ||
} | ||
] |
3 changes: 3 additions & 0 deletions
3
spacy_llm/tests/tasks/examples/textcat_multi_nonexcl_examples.jsonl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{"text":"Macaroni and cheese is the best budget meal for students, unhealthy tho","answer":"Comment,Feedback"} | ||
{"text":"2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, mix then well and you get an adobo","answer":"Recipe"} | ||
{"text":"You can still add more layers to that croissant, get extra butter and add a few cups of flour","answer":"Feedback,Recipe"} |
8 changes: 8 additions & 0 deletions
8
spacy_llm/tests/tasks/examples/textcat_multi_nonexcl_examples.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
- text: Macaroni and cheese is the best budget meal for students, unhealthy tho | ||
answer: Comment,Feedback | ||
- text: 2 cups soy sauce, 1/2 lb. of chicken, 1/2 cup vinegar, then salt and paper, | ||
mix then well and you get an adobo | ||
answer: Recipe | ||
- text: You can still add more layers to that croissant, get extra butter and add | ||
a few cups of flour | ||
answer: Feedback,Recipe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.