Skip to content

Commit

Permalink
Allow text-dict style vocab on LM Dataset (#1235)
Browse files Browse the repository at this point in the history
* allow text-dict in LMDataset

* fix lm vocab loading, allow reuse of index

* replace SyntaxError by regex-test

* updated regex, use literal_eval

* replace * by + in regex
  • Loading branch information
JackTemaki committed Nov 30, 2022
1 parent 20f06f4 commit 2c0bf36
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions returnn/datasets/lm.py
Expand Up @@ -14,6 +14,7 @@
import gzip
import xml.etree.ElementTree as ElementTree
from returnn.util.basic import parse_orthography, parse_orthography_into_symbols, load_json, BackendEngine, unicode
from returnn.util.literal_py_to_pickle import literal_eval
from returnn.log import log
import numpy
import time
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(self,
:param bool skip_empty_lines: for line-based txt
:param str|()->str|None orth_symbols_file: a text file containing a list of orthography symbols
:param str|()->str|None orth_symbols_map_file: either a list of orth symbols, each line: "<symbol> <index>",
a python dict with {"<symbol>": <index>, ...}
or a pickled dictionary
:param str|()->str|None orth_replace_map_file: JSON file with replacement dict for orth symbols.
:param bool word_based: whether to parse single words, or otherwise will be character based.
Expand Down Expand Up @@ -134,14 +136,22 @@ def __init__(self,
self.seq_gen = None
elif orth_symbols_map_file:
assert not phone_info
orth_symbols_imap_list = [
(int(b), a)
for (a, b) in [
line.split(None, 1)
for line in open(orth_symbols_map_file).read().splitlines()]]
orth_symbols_imap_list.sort()
with open(orth_symbols_map_file, "r") as f:
test_string = f.read(1024).replace(" ", "").replace("\n", "")
match = re.search("^{[\"'].+[\"']:[0-9]+,", test_string)
f.seek(0)
if match is not None:
d = literal_eval(f.read())
orth_symbols_imap_list = [(int(v), k) for k, v in d.items()]
orth_symbols_imap_list.sort()
else:
orth_symbols_imap_list = [
(int(b), a)
for (a, b) in [
line.split(None, 1)
for line in f.read().splitlines()]]
orth_symbols_imap_list.sort()
assert orth_symbols_imap_list[0][0] == 0
assert orth_symbols_imap_list[-1][0] == len(orth_symbols_imap_list) - 1
self.orth_symbols_map = {sym: i for (i, sym) in orth_symbols_imap_list}
self.orth_symbols = [sym for (i, sym) in orth_symbols_imap_list]
self.labels["data"] = self.orth_symbols
Expand Down

0 comments on commit 2c0bf36

Please sign in to comment.