Skip to content

Commit

Permalink
Vocabulary typing, fix for SentencePieces, faster
Browse files: browse the repository at this point in the history
  • Loading branch information
albertz committed Apr 26, 2024
1 parent 3f055e6 commit 24ba9b6
Showing 1 changed file with 50 additions and 54 deletions.
104 changes: 50 additions & 54 deletions returnn/datasets/util/vocabulary.py
Expand Up @@ -13,9 +13,9 @@
"Utf8ByteTargets",
]

import sys
from typing import Optional, Union, Type, List
import typing

import sys
import numpy

from returnn.log import log
Expand Down Expand Up @@ -107,13 +107,13 @@ def __repr__(self):
parts.append("pad_label=%r" % self.id_to_label(self.pad_label_id))
return "%s(%s)" % (self.__class__.__name__, ", ".join(parts))

def set_random_seed(self, seed):
def set_random_seed(self, seed: int):
"""
This can be called for a new epoch or so.
Usually it has no effect, as there is no randomness.
However, some vocab class could introduce some sampling process.
:param int seed:
:param seed:
"""
pass # usually there is no randomness, so ignore

Expand Down Expand Up @@ -205,12 +205,16 @@ def init_vocab_var(session):

return init_vocab_var

def to_id(self, label, default=KeyError, allow_none=False):
def to_id(
self,
label: Union[str, int, None],
default: Union[str, Type[KeyError], None] = KeyError,
allow_none: bool = False,
) -> Optional[int]:
"""
:param str|int|None label:
:param str|type[KeyError]|None default:
:param bool allow_none: whether label can be None. in this case, None is returned
:rtype: int|None
:param label:
:param default:
:param allow_none: whether label can be None. in this case, None is returned
"""
if isinstance(label, str):
return self.label_to_id(label, default=default)
Expand All @@ -226,65 +230,63 @@ def to_id(self, label, default=KeyError, allow_none=False):
return None
raise TypeError("invalid label type %r" % type(label))

def label_to_id(self, label, default=KeyError):
def label_to_id(self, label: str, default: Union[int, Type[KeyError], None] = KeyError) -> Optional[int]:
"""
:param str label:
:param int|type[KeyError]|None default:
:rtype: int|None
:param label:
:param default:
"""
if default is KeyError:
return self._vocab[label]
return self._vocab.get(label, default)

def id_to_label(self, idx, default=KeyError):
def id_to_label(self, idx: int, default: Union[str, Type[KeyError], None] = KeyError) -> Optional[str]:
"""
:param int idx:
:param str|KeyError|None default:
:rtype: str|None
:param idx:
:param default:
"""
if self.is_id_valid(idx):
return self._labels[idx]
if default is KeyError:
raise KeyError("idx %i out of range" % idx)
return default

def is_id_valid(self, idx):
def is_id_valid(self, idx: int) -> bool:
"""
:param int idx:
:rtype: bool
:param idx:
"""
return 0 <= idx < len(self._labels)

@property
def labels(self):
"""
:rtype: list[str]
"""
def labels(self) -> List[str]:
"""list of labels"""
return self._labels

def get_seq(self, sentence):
def get_seq(self, sentence: str) -> List[int]:
"""
:param str sentence: assumed to be seq of vocab entries separated by whitespace
:rtype: list[int]
:param sentence: assumed to be seq of vocab entries separated by whitespace
:return: seq of label indices
"""
segments = sentence.split()
return self.get_seq_indices(segments) + self.seq_postfix

def get_seq_indices(self, seq):
def get_seq_indices(self, seq: List[str]) -> List[int]:
"""
:param list[str] seq:
:rtype: list[int]
:param seq: seq of labels (entries in vocab)
:return: seq of label indices, returns unknown_label_id if unknown_label is set
"""
if self.unknown_label is not None:
return [self._vocab.get(k, self.unknown_label_id) for k in seq]
return [self._vocab[k] for k in seq]

def get_seq_labels(self, seq):
def get_seq_labels(self, seq: Union[List[int], numpy.ndarray]) -> str:
"""
:param list[int]|numpy.ndarray seq: 1D sequence
:rtype: str
Inverse of :func:`get_seq`.
:param seq: 1D sequence of label indices
:return: serialized sequence string, such that ``get_seq(get_seq_labels(seq)) == seq``
"""
return " ".join(map(self._labels.__getitem__, seq))
labels = self.labels
return " ".join(map(labels.__getitem__, seq))


class BytePairEncoding(Vocabulary):
Expand Down Expand Up @@ -421,10 +423,8 @@ def _parse_vocab(self):
# Do not load labels/vocab here. This is not really needed.

@property
def labels(self):
"""
:rtype: list[str]
"""
def labels(self) -> List[str]:
"""list of labels"""
if self._cache_key and self._cache_key in self._cache:
self._vocab, self._labels = self._cache[self._cache_key]
assert self.num_labels == len(self._vocab) == len(self._labels)
Expand All @@ -435,28 +435,25 @@ def labels(self):
self._cache[self._cache_key] = (self._vocab, self._labels)
return self._labels

def is_id_valid(self, idx):
def is_id_valid(self, idx: int) -> bool:
"""
:param int idx:
:rtype: bool
:param idx:
"""
return not self.sp.IsUnused(idx)

def id_to_label(self, idx, default=KeyError):
def id_to_label(self, idx: int, default: Union[str, Type[KeyError], None] = KeyError) -> Optional[str]:
"""
:param int idx:
:param str|KeyError|None default:
:rtype: str|None
:param idx:
:param default:
"""
if default is not KeyError and not self.is_id_valid(idx):
return default
return self.sp.IdToPiece(idx)

def label_to_id(self, label, default=KeyError):
def label_to_id(self, label: str, default: Union[int, Type[KeyError], None] = KeyError) -> Optional[int]:
"""
:param str label:
:param int|type[KeyError]|None default:
:rtype: int|None
:param label:
:param default:
"""
res = self.sp.PieceToId(label)
if res == self.unknown_label_id or res < 0 or res is None:
Expand All @@ -468,9 +465,9 @@ def label_to_id(self, label, default=KeyError):
return default
return res

def set_random_seed(self, seed):
def set_random_seed(self, seed: int):
"""
:param int seed:
:param seed:
"""
# Unfortunately, there is only a global seed,
# and also, it will only be used for new threads
Expand All @@ -480,10 +477,9 @@ def set_random_seed(self, seed):

spm.set_random_generator_seed(seed)

def get_seq(self, sentence):
def get_seq(self, sentence: str) -> List[int]:
"""
:param str sentence: assumed to be seq of vocab entries separated by whitespace
:rtype: list[int]
:param sentence: assumed to be seq of vocab entries separated by whitespace
"""
return self.sp.encode(sentence, out_type=int) # noqa

Expand Down

0 comments on commit 24ba9b6

Please sign in to comment.