-
-
Notifications
You must be signed in to change notification settings - Fork 76
/
test_ner.py
433 lines (375 loc) · 11.8 KB
/
test_ner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
# mypy: ignore-errors
from pathlib import Path
import pytest
import spacy
from confection import Config
from spacy.util import make_tempdir
from spacy_llm.registry import strip_normalizer, lowercase_normalizer, fewshot_reader
from spacy_llm.tasks.ner import find_substrings, NERTask
EXAMPLES_DIR = Path(__file__).parent / "examples"
@pytest.fixture
def zeroshot_cfg_string():
    """Return a config string for a zero-shot NER pipeline (OpenAI REST backend)."""
    cfg = """
[nlp]
lang = "en"
pipeline = ["llm"]
batch_size = 128
[components]
[components.llm]
factory = "llm"
[components.llm.task]
@llm_tasks: "spacy.NER.v1"
labels: PER,ORG,LOC
[components.llm.task.normalizer]
@misc: "spacy.LowercaseNormalizer.v1"
[components.llm.backend]
@llm_backends: "spacy.REST.v1"
api: "OpenAI"
config: {}
"""
    return cfg
@pytest.fixture
def fewshot_cfg_string():
    """Return a config string for a few-shot NER pipeline (OpenAI REST backend).

    The examples path points at the YAML fixture under the shared EXAMPLES_DIR
    (kept consistent with the other tests in this module instead of rebuilding
    the path from __file__ inline).
    """
    return f"""
[nlp]
lang = "en"
pipeline = ["llm"]
batch_size = 128
[components]
[components.llm]
factory = "llm"
[components.llm.task]
@llm_tasks: "spacy.NER.v1"
labels: PER,ORG,LOC
[components.llm.task.examples]
@misc: "spacy.FewShotReader.v1"
path: {str(EXAMPLES_DIR / "ner_examples.yml")}
[components.llm.task.normalizer]
@misc: "spacy.LowercaseNormalizer.v1"
[components.llm.backend]
@llm_backends: "spacy.REST.v1"
api: "OpenAI"
config: {{}}
"""
# NOTE: the zeroshot case is deliberately disabled in this parametrization.
@pytest.mark.parametrize("cfg_string", ["fewshot_cfg_string"])  # "zeroshot_cfg_string",
def test_ner_config(cfg_string, request):
    """The NER config loads into a pipeline containing a single llm component."""
    config = Config().from_str(request.getfixturevalue(cfg_string))
    nlp = spacy.util.load_model_from_config(config, auto_fill=True)
    assert nlp.pipe_names == ["llm"]
@pytest.mark.external
@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"])
def test_ner_predict(cfg_string, request):
    """Use OpenAI to get zero-shot NER results.
    Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable
    """
    config = Config().from_str(request.getfixturevalue(cfg_string))
    nlp = spacy.util.load_model_from_config(config, auto_fill=True)
    doc = nlp("Marc and Bob both live in Ireland.")
    # The LLM should find at least one entity, and only with the requested labels.
    assert len(doc.ents) > 0
    assert all(ent.label_ in ["PER", "ORG", "LOC"] for ent in doc.ents)
@pytest.mark.external
@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"])
def test_ner_io(cfg_string, request):
    """An llm NER pipeline round-trips through to_disk / spacy.load and still predicts."""
    config = Config().from_str(request.getfixturevalue(cfg_string))
    nlp = spacy.util.load_model_from_config(config, auto_fill=True)
    assert nlp.pipe_names == ["llm"]
    # Ensure you can save a pipeline to disk and run it after loading.
    with make_tempdir() as tmpdir:
        nlp.to_disk(tmpdir)
        reloaded = spacy.load(tmpdir)
        assert reloaded.pipe_names == ["llm"]
        doc = reloaded("Marc and Bob both live in Ireland.")
    assert len(doc.ents) > 0
    assert all(ent.label_ in ["PER", "ORG", "LOC"] for ent in doc.ents)
@pytest.mark.parametrize(
    "text,input_strings,result_strings,result_offsets",
    [
        (
            "Felipe and Jaime went to the library.",
            ["Felipe", "Jaime", "library"],
            ["Felipe", "Jaime", "library"],
            [(0, 6), (11, 16), (29, 36)],
        ),  # simple
        (
            "The Manila Observatory was founded in 1865 in Manila.",
            ["Manila", "The Manila Observatory"],
            ["Manila", "Manila", "The Manila Observatory"],
            [(4, 10), (46, 52), (0, 22)],
        ),  # overlapping and duplicated
        (
            "Take the road from downtown and turn left at the public market.",
            ["public market", "downtown"],
            ["public market", "downtown"],
            [(49, 62), (19, 27)],
        ),  # flipped
    ],
)
def test_ensure_offsets_correspond_to_substrings(
    text, input_strings, result_strings, result_offsets
):
    """find_substrings returns offsets that slice back to the expected substrings."""
    offsets = find_substrings(text, input_strings)
    # Check the raw offsets, then additionally verify that slicing the text
    # with those offsets recovers the expected substrings.
    assert result_offsets == offsets
    found_substrings = [text[start:end] for start, end in offsets]
    assert result_strings == found_substrings
@pytest.mark.parametrize(
    "text,response,gold_ents",
    [
        # simple
        (
            "Jean Jacques and Jaime went to the library.",
            "PER: Jean Jacques, Jaime\nLOC: library",
            [("Jean Jacques", "PER"), ("Jaime", "PER"), ("library", "LOC")],
        ),
        # overlapping: should only return the longest span
        (
            "The Manila Observatory was founded in 1865.",
            "LOC: The Manila Observatory, Manila, Manila Observatory",
            [("The Manila Observatory", "LOC")],
        ),
        # flipped: order shouldn't matter
        (
            "Take the road from Downtown and turn left at the public market.",
            "LOC: public market, Downtown",
            [("Downtown", "LOC"), ("public market", "LOC")],
        ),
    ],
)
def test_ner_zero_shot_task(text, response, gold_ents):
    """Parsing an LLM response yields the expected entities on the doc."""
    llm_ner = NERTask(labels="PER,ORG,LOC")
    doc_in = spacy.blank("xx").make_doc(text)
    # parse_responses yields one doc per input; take the single result.
    doc_out = next(iter(llm_ner.parse_responses([doc_in], [response])))
    assert [(ent.text, ent.label_) for ent in doc_out.ents] == gold_ents
@pytest.mark.parametrize(
    "response,normalizer,gold_ents",
    [
        (
            "PER: Jean Jacques, Jaime",
            None,
            [("Jean Jacques", "PER"), ("Jaime", "PER")],
        ),
        (
            "PER: Jean Jacques, Jaime",
            strip_normalizer(),
            [("Jean Jacques", "PER"), ("Jaime", "PER")],
        ),
        (
            "PER: Jean Jacques, Jaime",
            lowercase_normalizer(),
            [("Jean Jacques", "PER"), ("Jaime", "PER")],
        ),
        (
            "per: Jean Jacques, Jaime",
            None,
            [],
        ),
        (
            "per: Jean Jacques\nPER: Jaime",
            lowercase_normalizer(),
            [("Jean Jacques", "PER"), ("Jaime", "PER")],
        ),
        (
            "per: Jean Jacques, Jaime\nOrg: library",
            lowercase_normalizer(),
            [("Jean Jacques", "PER"), ("Jaime", "PER"), ("library", "ORG")],
        ),
        (
            "per: Jean Jacques, Jaime\nRANDOM: library",
            lowercase_normalizer(),
            [("Jean Jacques", "PER"), ("Jaime", "PER")],
        ),
    ],
)
def test_ner_labels(response, normalizer, gold_ents):
    """Label normalization controls which response labels map onto the task labels."""
    llm_ner = NERTask(labels="PER,ORG,LOC", normalizer=normalizer)
    doc_in = spacy.blank("xx").make_doc("Jean Jacques and Jaime went to the library.")
    # parse_responses yields one doc per input; take the single result.
    doc_out = next(iter(llm_ner.parse_responses([doc_in], [response])))
    assert [(ent.text, ent.label_) for ent in doc_out.ents] == gold_ents
@pytest.mark.parametrize(
    "response,alignment_mode,gold_ents",
    [
        (
            "PER: Jacq",
            "strict",
            [],
        ),
        (
            "PER: Jacq",
            "contract",
            [],
        ),
        (
            "PER: Jacq",
            "expand",
            [("Jacques", "PER")],
        ),
        (
            "PER: Jean J",
            "contract",
            [("Jean", "PER")],
        ),
        (
            "PER: Jean Jacques, aim",
            "strict",
            [("Jean Jacques", "PER")],
        ),
        (
            "PER: random",
            "expand",
            [],
        ),
    ],
)
def test_ner_alignment(response, alignment_mode, gold_ents):
    """Character spans snap to token boundaries according to the alignment mode."""
    llm_ner = NERTask(labels="PER,ORG,LOC", alignment_mode=alignment_mode)
    doc_in = spacy.blank("xx").make_doc("Jean Jacques and Jaime went to the library.")
    # parse_responses yields one doc per input; take the single result.
    doc_out = next(iter(llm_ner.parse_responses([doc_in], [response])))
    assert [(ent.text, ent.label_) for ent in doc_out.ents] == gold_ents
def test_invalid_alignment_mode():
    """NERTask rejects an unknown alignment mode with a ValueError."""
    with pytest.raises(ValueError, match="Unsupported alignment mode 'invalid"):
        NERTask(labels="PER,ORG,LOC", alignment_mode="invalid")
@pytest.mark.parametrize(
    "response,case_sensitive,single_match,gold_ents",
    [
        (
            "PER: Jean",
            False,
            False,
            [("jean", "PER"), ("Jean", "PER"), ("Jean", "PER")],
        ),
        (
            "PER: Jean",
            False,
            True,
            [("jean", "PER")],
        ),
        (
            "PER: Jean",
            True,
            False,
            [("Jean", "PER"), ("Jean", "PER")],
        ),
        (
            "PER: Jean",
            True,
            True,
            [("Jean", "PER")],
        ),
    ],
)
def test_ner_matching(response, case_sensitive, single_match, gold_ents):
    """case_sensitive_matching and single_match control how many text occurrences are tagged."""
    llm_ner = NERTask(
        labels="PER,ORG,LOC",
        case_sensitive_matching=case_sensitive,
        single_match=single_match,
    )
    doc_in = spacy.blank("xx").make_doc(
        "This guy jean (or Jean) is the president of the Jean Foundation."
    )
    # parse_responses yields one doc per input; take the single result.
    doc_out = next(iter(llm_ner.parse_responses([doc_in], [response])))
    assert [(ent.text, ent.label_) for ent in doc_out.ents] == gold_ents
def test_jinja_template_rendering_without_examples():
    """Test if jinja template renders as we expected
    We apply the .strip() method for each prompt so that we don't have to deal
    with annoying newlines and spaces at the edge of the text.
    """
    nlp = spacy.blank("xx")
    doc = nlp.make_doc("Alice and Bob went to the supermarket")
    llm_ner = NERTask(labels="PER,ORG,LOC", examples=None)
    prompt = list(llm_ner.generate_prompts([doc]))[0]
    expected = """
From the text below, extract the following entities in the following format:
PER: <comma delimited list of strings>
ORG: <comma delimited list of strings>
LOC: <comma delimited list of strings>
Here is the text that needs labeling:
Text:
'''
Alice and Bob went to the supermarket
'''
""".strip()
    assert prompt.strip() == expected
@pytest.mark.parametrize(
    "examples_path",
    [
        str(EXAMPLES_DIR / "ner_examples.json"),
        str(EXAMPLES_DIR / "ner_examples.yml"),
        str(EXAMPLES_DIR / "ner_examples.jsonl"),
    ],
)
def test_jinja_template_rendering_with_examples(examples_path):
    """Test if jinja2 template renders as expected
    We apply the .strip() method for each prompt so that we don't have to deal
    with annoying newlines and spaces at the edge of the text.
    """
    nlp = spacy.blank("xx")
    doc = nlp.make_doc("Alice and Bob went to the supermarket")
    llm_ner = NERTask(labels="PER,ORG,LOC", examples=fewshot_reader(examples_path))
    prompt = list(llm_ner.generate_prompts([doc]))[0]
    expected = """
From the text below, extract the following entities in the following format:
PER: <comma delimited list of strings>
ORG: <comma delimited list of strings>
LOC: <comma delimited list of strings>
Below are some examples (only use these as a guide):
Text:
'''
Jack and Jill went up the hill.
'''
PER: Jack, Jill
LOC: hill
Text:
'''
Jack fell down and broke his crown.
'''
PER: Jack
Text:
'''
Jill came tumbling after.
'''
PER: Jill
Here is the text that needs labeling:
Text:
'''
Alice and Bob went to the supermarket
'''""".strip()
    assert prompt.strip() == expected