-
Notifications
You must be signed in to change notification settings - Fork 24.3k
/
XLMRobertaTokenizer.java
299 lines (258 loc) · 10.2 KB
/
XLMRobertaTokenizer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.XLMRobertaTokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import java.io.IOException;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * NLP tokenizer for XLM-RoBERTa style models.
 *
 * <p>Tokenization is delegated to a SentencePiece-style unigram analyzer
 * ({@link XLMAnalyzer}) built from the model vocabulary and per-token scores.
 * Special tokens ({@code <s>}, {@code </s>}, {@code <pad>}, {@code <unk>},
 * mask) follow the RoBERTa conventions.
 *
 * <p>Instances hold a Lucene {@link Analyzer} and must be {@link #close()}d.
 */
public class XLMRobertaTokenizer extends NlpTokenizer {

    public static final String UNKNOWN_TOKEN = "<unk>";
    public static final String SEPARATOR_TOKEN = "</s>";
    public static final String PAD_TOKEN = "<pad>";
    public static final String CLASS_TOKEN = "<s>";
    public static final String MASK_TOKEN = XLMRobertaTokenization.MASK_TOKEN;

    // The mask token must survive tokenization intact so callers can locate it.
    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

    private final XLMAnalyzer xlmAnalyzer;
    protected final List<String> originalVocab;
    // TODO Not sure this needs to be a sorted map
    private final SortedMap<String, Integer> vocab;
    protected final boolean withSpecialTokens;
    protected final int sepTokenId;
    private final int clsTokenId;
    protected final int padTokenId;
    private final int maxSequenceLength;

    /**
     * @param originalVocab     vocabulary in model id order (index == token id)
     * @param vocab             token -&gt; id lookup built from {@code originalVocab}
     * @param scores            unigram log-probability score per vocab entry
     * @param withSpecialTokens whether to wrap sequences with CLS/SEP tokens
     * @param maxSequenceLength maximum number of tokens per sequence
     * @param neverSplit        extra tokens to keep whole, merged with {@link #NEVER_SPLIT}
     * @throws IOException if the analyzer's precompiled normalizer cannot be loaded
     */
    protected XLMRobertaTokenizer(
        List<String> originalVocab,
        SortedMap<String, Integer> vocab,
        List<Double> scores,
        boolean withSpecialTokens,
        int maxSequenceLength,
        Set<String> neverSplit
    ) throws IOException {
        this.originalVocab = originalVocab;
        this.xlmAnalyzer = new XLMAnalyzer(originalVocab, scores, new ArrayList<>(Sets.union(NEVER_SPLIT, neverSplit)), UNKNOWN_TOKEN);
        this.vocab = vocab;
        this.withSpecialTokens = withSpecialTokens;
        this.maxSequenceLength = maxSequenceLength;
        // <unk> and <pad> are required unconditionally; a missing token means the
        // stored vocabulary does not match this tokenizer type.
        if (vocab.containsKey(UNKNOWN_TOKEN) == false) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", UNKNOWN_TOKEN);
        }
        if (vocab.containsKey(PAD_TOKEN) == false) {
            throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required [{}] token", PAD_TOKEN);
        }
        this.padTokenId = vocab.get(PAD_TOKEN);
        if (withSpecialTokens) {
            // CLS/SEP are only required when special tokens are requested.
            Set<String> missingSpecialTokens = Sets.difference(Set.of(SEPARATOR_TOKEN, CLASS_TOKEN), vocab.keySet());
            if (missingSpecialTokens.isEmpty() == false) {
                throw ExceptionsHelper.conflictStatusException("stored vocabulary is missing required {} token(s)", missingSpecialTokens);
            }
            this.sepTokenId = vocab.get(SEPARATOR_TOKEN);
            this.clsTokenId = vocab.get(CLASS_TOKEN);
        } else {
            // Sentinel ids; callers must not use them when withSpecialTokens is false.
            this.sepTokenId = -1;
            this.clsTokenId = -1;
        }
    }

    @Override
    int sepTokenId() {
        return sepTokenId;
    }

    @Override
    int maxSequenceLength() {
        return maxSequenceLength;
    }

    @Override
    boolean isWithSpecialTokens() {
        return withSpecialTokens;
    }

    @Override
    int getNumExtraTokensForSeqPair() {
        // Presumably the RoBERTa pair layout <s> A </s></s> B </s> — four special
        // tokens. NOTE(review): confirm against the request builder's packing.
        return 4;
    }

    @Override
    int defaultSpanForChunking(int maxWindowSize) {
        // Default chunk span is half the usable window (window minus the special
        // tokens added to a single sequence).
        return (maxWindowSize - numExtraTokensForSingleSequence()) / 2;
    }

    @Override
    int numExtraTokensForSingleSequence() {
        // <s> ... </s> around a single sequence.
        return 2;
    }

    @Override
    int clsTokenId() {
        return clsTokenId;
    }

    public String getPadToken() {
        return PAD_TOKEN;
    }

    public String getUnknownToken() {
        return UNKNOWN_TOKEN;
    }

    @Override
    public void close() {
        this.xlmAnalyzer.close();
    }

    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> tokenizations) {
        return new XLMRobertaTokenizationResult(originalVocab, tokenizations, padTokenId);
    }

    @Override
    public NlpTask.RequestBuilder requestBuilder() {
        // Tokenize every input (each may expand to multiple chunked Tokens) and
        // flatten into a single result before building the inference request.
        return (inputs, requestId, truncate, span, windowSize) -> buildTokenizationResult(
            IntStream.range(0, inputs.size())
                .boxed()
                .flatMap(seqId -> tokenize(inputs.get(seqId), truncate, span, seqId, windowSize).stream())
                .collect(Collectors.toList())
        ).buildRequest(requestId, truncate);
    }

    @Override
    public OptionalInt getPadTokenId() {
        return OptionalInt.of(padTokenId);
    }

    @Override
    public OptionalInt getMaskTokenId() {
        // The mask token is optional in the vocabulary (e.g. non fill-mask models).
        Integer maskId = vocab.get(MASK_TOKEN);
        if (maskId == null) {
            return OptionalInt.empty();
        }
        return OptionalInt.of(maskId);
    }

    @Override
    public String getMaskToken() {
        return MASK_TOKEN;
    }

    @Override
    public List<String> getVocabulary() {
        return originalVocab;
    }

    @Override
    TokenizationResult.TokensBuilder createTokensBuilder(int clsTokenId, int sepTokenId, boolean withSpecialTokens) {
        return new XLMRobertaTokenizationResult.XLMRobertaTokensBuilder(withSpecialTokens, clsTokenId, sepTokenId);
    }

    /**
     * Runs the analyzer over {@code seq}, recording each token's position
     * (accumulated from the Lucene position increments) so sub-tokens can be
     * mapped back to their source words.
     *
     * @param seq cannot be null
     * @return InnerTokenization
     */
    @Override
    public InnerTokenization innerTokenize(String seq) {
        List<Integer> tokenPositionMap = new ArrayList<>();
        try (TokenStream ts = xlmAnalyzer.tokenStream("input", seq)) {
            ts.reset();
            PositionIncrementAttribute tokenPos = ts.addAttribute(PositionIncrementAttribute.class);
            int currPos = -1;
            while (ts.incrementToken()) {
                currPos += tokenPos.getPositionIncrement();
                tokenPositionMap.add(currPos);
            }
        } catch (IOException ex) {
            // tokenStream reads from an in-memory string; an IOException here is unexpected.
            throw new UncheckedIOException(ex);
        }
        return new InnerTokenization(new ArrayList<>(xlmAnalyzer.getTokens()), tokenPositionMap);
    }

    public static Builder builder(List<String> vocab, List<Double> scores, XLMRobertaTokenization tokenization) {
        return new Builder(vocab, scores, tokenization);
    }

    /**
     * Builder for {@link XLMRobertaTokenizer}; seeded from the stored
     * {@link XLMRobertaTokenization} config, with optional overrides.
     */
    public static class Builder {

        protected final List<String> originalVocab;
        protected final List<Double> scores;
        protected final SortedMap<String, Integer> vocab;
        protected boolean withSpecialTokens;
        protected int maxSequenceLength;
        protected Set<String> neverSplit;

        protected Builder(List<String> vocab, List<Double> scores, XLMRobertaTokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = buildSortedVocab(vocab);
            this.scores = scores;
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
        }

        /** Maps each token to its index in the vocabulary list. */
        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
            for (int i = 0; i < vocab.size(); i++) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        /**
         * Include CLS and SEP tokens
         * @param withSpecialTokens if true include CLS and SEP tokens
         * @return this
         */
        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        public XLMRobertaTokenizer build() throws IOException {
            if (neverSplit == null) {
                neverSplit = Collections.emptySet();
            }
            return new XLMRobertaTokenizer(originalVocab, vocab, scores, withSpecialTokens, maxSequenceLength, neverSplit);
        }
    }

    /**
     * Lucene {@link Analyzer} wrapping the SentencePiece unigram tokenizer,
     * with the precompiled character-map normalizer applied to the input first.
     */
    static class XLMAnalyzer extends Analyzer {

        private final List<String> vocabulary;
        private final List<String> neverSplit;
        private final double[] scores;
        private UnigramTokenizer innerTokenizer;
        private final String unknownToken;
        private final PrecompiledCharMapNormalizer.Config normalizer;

        XLMAnalyzer(List<String> vocabulary, List<Double> scores, List<String> neverSplit, String unknownToken) throws IOException {
            this.vocabulary = vocabulary;
            this.neverSplit = neverSplit;
            this.unknownToken = unknownToken;
            // Unbox once up front; the tokenizer consumes a primitive array.
            this.scores = new double[scores.size()];
            int i = 0;
            for (Double s : scores) {
                this.scores[i++] = s;
            }
            // BUG FIX: classpath resource names use '/' separators throughout. The
            // previous path had a dotted "inference.nlp.tokenizers" segment which
            // does not match the package directory and would never resolve.
            normalizer = PrecompiledCharMapNormalizer.fromBase64EncodedResource(
                "/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/spm_precompiled_normalizer.txt"
            );
        }

        @Override
        protected Reader initReader(String fieldName, Reader reader) {
            // Only wrap the input when the normalizer actually has mappings.
            if (normalizer.offsets().length > 0) {
                return new PrecompiledCharMapNormalizer(normalizer.offsets(), normalizer.utf8str(), reader);
            }
            return reader;
        }

        @Override
        protected TokenStreamComponents createComponents(String fieldName) {
            // Keep a handle on the tokenizer so getTokens() can expose its output.
            this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken);
            return new TokenStreamComponents(this.innerTokenizer);
        }

        /** Tokens produced by the most recent analysis, or empty if none has run. */
        public List<DelimitedToken.Encoded> getTokens() {
            if (innerTokenizer != null) {
                return innerTokenizer.getTokenizedValues();
            } else {
                return List.of();
            }
        }
    }
}