Skip to content

Commit

Permalink
Add Span support in tokenizer's Model abstraction (#7035)
Browse files Browse the repository at this point in the history
* Add Span support in tokenizer's Model abstraction

* Address the feedback

* Use stackalloc instead of the ArrayPool
  • Loading branch information
tarekgh committed Mar 1, 2024
1 parent c6f5397 commit 99c620a
Show file tree
Hide file tree
Showing 12 changed files with 427 additions and 311 deletions.
122 changes: 71 additions & 51 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
Expand Down Expand Up @@ -34,20 +36,21 @@ private set
{
_unknownToken = value;

if (value is null)
if (VocabReverse.TryGetValue(0, out string? v))
{
if (VocabReverse.TryGetValue(0, out string? v))
if (v == value)
{
VocabReverse.Remove(0);
if (_vocab.TryGetValue(v, out int id))
{
_vocab.Remove(v);
}
return;
}

VocabReverse.Remove(0);
_vocab.Remove(new StringSpanOrdinalKey(v));
}
else


if (value is not null)
{
_vocab[value] = 0;
_vocab[new StringSpanOrdinalKey(value)] = 0;
VocabReverse[0] = value;
}
}
Expand All @@ -68,7 +71,6 @@ private set
/// </summary>
public bool FuseUnknownTokens { get; }


/// <summary>
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
Expand Down Expand Up @@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;

(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<string, int>();
Cache = new Cache<string, Word>();
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
Cache = new StringSpanOrdinalKeyCache<Word>();

VocabReverse = new();

foreach (KeyValuePair<string, int> kvp in Vocab)
foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
{
VocabReverse.Add(kvp.Value, kvp.Key);
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
}

if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
{
unknownToken = unkToken;
}

UnknownToken = unknownToken;
UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);

int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;

Expand Down Expand Up @@ -197,31 +195,23 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
/// <param name="text">The text to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);

/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);

/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
{
if (_vocab.TryGetValue(token, out int value))
{
return value;
}

return null;
}
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;

/// <summary>
/// Map the encoded Id to the token.
Expand All @@ -242,24 +232,27 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocab => _vocab;
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);

/// Read the given files to extract the vocab and merges
internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
{
Dictionary<string, int>? dic = JsonSerializer.Deserialize<Dictionary<string, int>>(vocab) as Dictionary<string, int>;
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
Dictionary<StringSpanOrdinalKey, int>? dic = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;

return (dic, ConvertMergesToHashmap(merges));
}

/// The vocabulary assigns a number to each token.
private readonly Dictionary<string, int> _vocab;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;

private Dictionary<string, int>? _vocabOriginal;

/// Contains the mapping between Pairs and their (rank, newId).
internal Dictionary<Pair<int>, (int, int)> Merges { get; }

/// Contains the cache for optimizing the encoding step.
internal Cache<string, Word>? Cache { get; }
internal StringSpanOrdinalKeyCache<Word>? Cache { get; }

internal static readonly int DefaultCacheCapacity = 10_000;

Expand Down Expand Up @@ -309,9 +302,6 @@ internal static Vec<(string, string)> ConvertMergesToHashmap(Stream? mergesStrea
return merges;
}

/// Reset the cache.
internal void ClearCache() => Cache?.Clear();

private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -327,38 +317,68 @@ internal string CharToString(char c)
return s;
}

internal Word MergeWord(string w)
internal Word MergeWord(ReadOnlySpan<char> w)
{
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
int i = 0;

Span<char> buffer = stackalloc char[256];
scoped ReadOnlySpan<char> s;

while (i < w.Length)
{
int length;
string s;

if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
{
length = 2;
s = w.Substring(i, length);
s = w.Slice(i, 2);
}
else
{
length = 1;
s = CharToString(w[i]);
s = w.Slice(i, 1);
}

// Add the `continuing_subword_prefix` if relevant
if (i > 0 && ContinuingSubwordPrefix is not null)
{
s = $"{ContinuingSubwordPrefix}{s}";
if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
{
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
}
else
{
#if NETCOREAPP
s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
#endif
}
}

// Add the `end_of_word_suffix` if relevant
if (i + length >= w.Length && EndOfWordSuffix is not null)
{
s = $"{s}{EndOfWordSuffix}";
if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
{
s.CopyTo(buffer);
EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
}
else
{
#if NETCOREAPP
s = $"{s}{EndOfWordSuffix}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{s1}{EndOfWordSuffix}".AsSpan();
#endif
}
}

if (_vocab.TryGetValue(s, out int id))
Expand Down Expand Up @@ -419,17 +439,17 @@ internal List<Token> EncodeWithCache(string text)
Word word;
if (Cache is not null)
{
if (Cache.TryGet(text, out word))
if (Cache.TryGetValue(text, out word))
{
return WordToTokens(ref word);
}

word = MergeWord(text);
word = MergeWord(text.AsSpan());
Cache.Set(text, word);
}
else
{
word = MergeWord(text);
word = MergeWord(text.AsSpan());
}

return WordToTokens(ref word);
Expand All @@ -445,19 +465,19 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
return word.SymbolsCount;
}

internal int EncodeToIdsWithCache(string text, IList<int>? accumulatedIds)
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
{
Word word;

if (Cache is not null)
{
if (Cache.TryGet(text, out Word hit))
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}

word = MergeWord(text);
Cache.Set(text, word);
Cache.Set(text.ToString(), word);
}
else
{
Expand Down

0 comments on commit 99c620a

Please sign in to comment.