Skip to content

Commit

Permalink
Matcher support for Span, as well as Doc explosion#5056
Browse files Browse the repository at this point in the history
  • Loading branch information
paoloq committed Mar 5, 2020
1 parent 2164e71 commit c52fdba
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
37 changes: 23 additions & 14 deletions spacy/matcher/matcher.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ from murmurhash.mrmr cimport hash64

import re
import srsly
import ctypes

from ..typedefs cimport attr_t
from ..structs cimport TokenC
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc, get_token_attr
from ..tokens.span cimport Span
from ..tokens.token cimport Token
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA

Expand Down Expand Up @@ -211,22 +213,31 @@ cdef class Matcher:
else:
yield doc

def __call__(self, Doc doc):
def __call__(self, object doc_or_span):
"""Find all token sequences matching the supplied pattern.
doc (Doc): The document to match over.
RETURNS (list): A list of `(key, start, end)` tuples,
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
if isinstance(doc_or_span, Doc):
doc = doc_or_span
elif isinstance(doc_or_span, Span):
doc = doc_or_span.doc
else:
raise ValueError()
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
and not doc.is_tagged:
raise ValueError(Errors.E155.format())
if DEP in self._seen_attrs and not doc.is_parsed:
raise ValueError(Errors.E156.format())
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
extensions=self._extensions,
predicates=self._extra_predicates)
length = (
(doc_or_span.end - doc_or_span.start)
if isinstance(doc_or_span, Span) else len(doc)
)
matches = find_matches(&self.patterns[0], self.patterns.size(), doc_or_span, length,
extensions=self._extensions, predicates=self._extra_predicates)
for i, (key, start, end) in enumerate(matches):
on_match = self._callbacks.get(key, None)
if on_match is not None:
Expand All @@ -248,9 +259,7 @@ def unpickle_matcher(vocab, patterns, callbacks):
return matcher



cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
predicates=tuple()):
cdef find_matches(TokenPatternC** patterns, int n, object doc_or_span, int length, extensions=None, predicates=tuple()):
"""Find matches in a doc, with a compiled array of patterns. Matches are
returned as a list of (id, start, end) tuples.
Expand All @@ -268,30 +277,30 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
cdef int i, j, nr_extra_attr
cdef Pool mem = Pool()
output = []
if doc.length == 0:
if length == 0:
# avoid any processing or mem alloc if the document is empty
return output
if len(predicates) > 0:
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
predicate_cache = <char*>mem.alloc(length * len(predicates), sizeof(char))
if extensions is not None and len(extensions) >= 1:
nr_extra_attr = max(extensions.values()) + 1
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
extra_attr_values = <attr_t*>mem.alloc(length * nr_extra_attr, sizeof(attr_t))
else:
nr_extra_attr = 0
extra_attr_values = <attr_t*>mem.alloc(doc.length, sizeof(attr_t))
for i, token in enumerate(doc):
extra_attr_values = <attr_t*>mem.alloc(length, sizeof(attr_t))
for i, token in enumerate(doc_or_span):
for name, index in extensions.items():
value = token._.get(name)
if isinstance(value, basestring):
value = token.vocab.strings[value]
extra_attr_values[i * nr_extra_attr + index] = value
# Main loop
cdef int nr_predicate = len(predicates)
for i in range(doc.length):
for i in range(length):
for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0))
transition_states(states, matches, predicate_cache,
doc[i], extra_attr_values, predicates)
doc_or_span[i], extra_attr_values, predicates)
extra_attr_values += nr_extra_attr
predicate_cache += len(predicates)
# Handle matches that end in 0-width patterns
Expand Down
10 changes: 9 additions & 1 deletion spacy/tests/matcher/test_matcher_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from mock import Mock
from spacy.matcher import Matcher, DependencyMatcher
from spacy.tokens import Doc, Token
from spacy.tokens import Doc, Span, Token


@pytest.fixture
Expand Down Expand Up @@ -456,3 +456,11 @@ def test_matcher_callback(en_vocab):
doc = Doc(en_vocab, words=["This", "is", "a", "test", "."])
matches = matcher(doc)
mock.assert_called_once_with(matcher, doc, 0, matches)


def test_matcher_span(matcher):
    """Calling the matcher on a Span (rather than a Doc) should work and
    return matches found within that span only."""
    words = "JavaScript is good but Python is better".split()
    doc = Doc(matcher.vocab, words=words)
    first_three_tokens = Span(doc, 0, 3)
    matches = matcher(first_three_tokens)
    assert len(matches) == 1

0 comments on commit c52fdba

Please sign in to comment.