More tokenizer's APIs cleanup (#7110)
* More tokenizer's APIs cleanup

* Address the feedback
tarekgh committed Apr 3, 2024
1 parent 79b5475 commit 0fd58cb
Showing 10 changed files with 291 additions and 140 deletions.
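In short, this commit deletes the TokenizerDecoder abstraction (and with it BpeDecoder), drops the decoder parameter from the Tokenizer constructor and from every Decode overload, and moves decoding responsibility into the models themselves; Bpe additionally validates the unknown token against the vocabulary and gains a considerSpecialTokens decode overload. A rough sketch of the resulting surface — a hedged illustration, not verbatim from the commit: CreateBpeModel is a hypothetical helper, and EncodeToIds is assumed from the package's public API at the time.

using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

// 'CreateBpeModel' is a hypothetical helper standing in for real vocab/merges loading.
Bpe bpe = CreateBpeModel();

// The Tokenizer constructor no longer takes a TokenizerDecoder argument.
Tokenizer tokenizer = new Tokenizer(bpe);

// EncodeToIds is assumed from the library's public surface at the time.
IReadOnlyList<int> ids = tokenizer.EncodeToIds("hello world");

// Tokenizer.Decode now forwards straight to Model.Decode(ids).
string? text = tokenizer.Decode(ids);

// Bpe's new overload can drop the unknown token from the output.
string? withoutUnknown = bpe.Decode(ids, considerSpecialTokens: false);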
98 changes: 77 additions & 21 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -8,6 +8,7 @@
 using System.IO;
 using System.Linq;
 using System.Runtime.CompilerServices;
+using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
 
@@ -22,6 +23,7 @@ public sealed class Bpe : Model
 
         private const int MaxWordLengthToCache = 15;
         private string? _unknownToken;
+        private int? _unknownTokenId;
 
         /// <summary>
         /// Gets or Sets unknown token. The unknown token to be used when we encounter an unknown char
@@ -35,25 +37,20 @@ public sealed class Bpe : Model
 
         private set
         {
-            _unknownToken = value;
-
-            if (VocabReverse.TryGetValue(0, out string? v))
+            if (value is null)
             {
-                if (v == value)
-                {
-                    return;
-                }
-
-                VocabReverse.Remove(0);
-                _vocab.Remove(new StringSpanOrdinalKey(v));
+                _unknownToken = value;
+                _unknownTokenId = null;
+                return;
             }
 
-
-            if (value is not null)
+            if (!_vocab.TryGetValue(value, out int id))
             {
-                _vocab[new StringSpanOrdinalKey(value)] = 0;
-                VocabReverse[0] = value;
+                throw new InvalidOperationException($"Unknown Token '{value}' was not present in '{nameof(Vocab)}'.");
             }
+
+            _unknownTokenId = id;
+            _unknownToken = value;
         }
     }
 
@@ -125,8 +122,7 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
                 VocabReverse.Add(kvp.Value, kvp.Key.Data!);
             }
 
-
-            UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);
+            UnknownToken = unknownToken;
 
             int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
 
@@ -169,11 +165,6 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
             }
         }
 
-        /// <summary>
-        /// Gets the Bpe decoder object.
-        /// </summary>
-        public static TokenizerDecoder Decoder { get; } = new BpeDecoder();
-
         /// <summary>
         /// Encode a text string to a list of tokens.
         /// </summary>
@@ -244,6 +235,71 @@ public override IReadOnlyList<Token> Encode(ReadOnlySpan<char> text)
         /// </summary>
         public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
 
+        /// <summary>
+        /// Decode the given ids, back to a String.
+        /// </summary>
+        /// <param name="ids">The list of ids that we want to decode.</param>
+        /// <returns>The decoded string.</returns>
+        public override string? Decode(IEnumerable<int> ids) => Decode(ids, considerSpecialTokens: true);
+
+        /// <summary>
+        /// Decode the given ids, back to a String.
+        /// </summary>
+        /// <param name="ids">The list of ids that we want to decode.</param>
+        /// <param name="considerSpecialTokens">Indicate whether to consider special tokens or not.</param>
+        /// <returns>The decoded string.</returns>
+        public string? Decode(IEnumerable<int> ids, bool considerSpecialTokens)
+        {
+            if (ids is null)
+            {
+                throw new ArgumentNullException(nameof(ids));
+            }
+
+            ValueStringBuilder sb = new ValueStringBuilder();
+
+            bool decodeUnknownToken = _unknownTokenId.HasValue && considerSpecialTokens;
+
+            if (decodeUnknownToken)
+            {
+                foreach (int id in ids)
+                {
+                    if (MapIdToToken(id) is string s)
+                    {
+                        sb.Append(s);
+                    }
+                }
+            }
+            else
+            {
+                foreach (int id in ids)
+                {
+                    if (id == _unknownTokenId)
+                    {
+                        continue;
+                    }
+
+                    if (MapIdToToken(id) is string s)
+                    {
+                        sb.Append(s);
+                    }
+                }
+            }
+
+            if (EndOfWordSuffix is not null)
+            {
+                sb.RemoveSuffix(EndOfWordSuffix);
+
+                sb.Replace(EndOfWordSuffix, " ");
+            }
+
+            if (ContinuingSubwordPrefix is not null)
+            {
+                sb.Replace(ContinuingSubwordPrefix, string.Empty);
+            }
+
+            return sb.ToString();
+        }
+
         /// Read the given files to extract the vocab and merges
         internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
         {
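To see what the EndOfWordSuffix / ContinuingSubwordPrefix post-processing at the end of Decode does, here is the same cleanup over a plain string. This is a sketch with assumed marker values ("</w>" and "##" are illustrative; a real vocabulary typically uses only one of the two styles).

using System;

// Assumed markers for illustration; not taken from the commit.
const string EndOfWordSuffix = "</w>";
const string ContinuingSubwordPrefix = "##";

string concatenated = "hello</w>##wor##ld</w>"; // made-up concatenation of token strings

// 1. RemoveSuffix: drop one trailing end-of-word marker so no trailing space remains.
if (concatenated.EndsWith(EndOfWordSuffix, StringComparison.Ordinal))
{
    concatenated = concatenated.Substring(0, concatenated.Length - EndOfWordSuffix.Length);
}

// 2. Replace: remaining end-of-word markers become word separators.
concatenated = concatenated.Replace(EndOfWordSuffix, " ");

// 3. Replace: continuing-subword prefixes are stripped entirely.
concatenated = concatenated.Replace(ContinuingSubwordPrefix, string.Empty);

Console.WriteLine(concatenated); // prints "hello world"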
50 changes: 0 additions & 50 deletions src/Microsoft.ML.Tokenizers/Model/BPEDecoder.cs

This file was deleted.

19 changes: 14 additions & 5 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
@@ -151,21 +151,30 @@ public virtual int CountTokensFromEnd(ReadOnlySpan<char> text, out int textIndex
         /// Decode the given ids, back to a String.
         /// </summary>
         /// <param name="ids">The list of ids that we want to decode.</param>
-        /// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
         /// <returns>The decoded string.</returns>
-        public virtual string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
+        /// <remarks>
+        /// This default implementation uses the MapIdToToken method to map each id to its token string.
+        /// Tokenizer models may override this method to ensure accurate results if the default implementation
+        /// provided here proves insufficient for the model's specific scenario.
+        /// </remarks>
+        public virtual string? Decode(IEnumerable<int> ids)
        {
-            List<string> tokens = new List<string>();
+            if (ids is null)
+            {
+                throw new ArgumentNullException(nameof(ids));
+            }
+
+            ValueStringBuilder sb = new ValueStringBuilder();
 
             foreach (int id in ids)
             {
                 if (MapIdToToken(id) is string s)
                 {
-                    tokens.Add(s);
+                    sb.Append(s);
                 }
             }
 
-            return decoder?.Decode(tokens) ?? string.Concat(tokens);
+            return sb.ToString();
         }
     }
 }
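The default implementation amounts to concatenating MapIdToToken(id) over the ids and silently skipping unmapped ones. An equivalent standalone sketch, where the delegate stands in for the model's MapIdToToken:

using System;
using System.Collections.Generic;
using System.Text;

// Equivalent of the default Model.Decode above: concatenate the token string
// for each id, skipping ids that have no mapping.
static string DecodeDefault(IEnumerable<int> ids, Func<int, string?> mapIdToToken)
{
    if (ids is null)
    {
        throw new ArgumentNullException(nameof(ids));
    }

    var sb = new StringBuilder();
    foreach (int id in ids)
    {
        if (mapIdToToken(id) is string s)
        {
            sb.Append(s);
        }
    }

    return sb.ToString();
}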
6 changes: 1 addition & 5 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceBpe.cs
@@ -892,12 +892,8 @@ bool SegmentFromEnd((int Index, int Length) pieceSpan, ReadOnlySpan<char> text,
         /// Decode the given ids, back to a String.
         /// </summary>
         /// <param name="ids">The list of ids that we want to decode.</param>
-        /// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
         /// <returns>The decoded string.</returns>
-        /// <remarks>
-        /// The decoder is not used here because the SentencePiece Bpe model knows how to decode the ids in additions to avoid any performance overhead.
-        /// </remarks>
-        public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
+        public override string? Decode(IEnumerable<int> ids)
         {
             if (ids is null)
             {
3 changes: 1 addition & 2 deletions src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -573,9 +573,8 @@ public override int CountTokensFromEnd(ReadOnlySpan<char> text, out int textInde
         /// Decode the given ids, back to a String.
         /// </summary>
         /// <param name="ids">The list of ids that we want to decode.</param>
-        /// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
         /// <returns>The decoded string.</returns>
-        public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null)
+        public override string? Decode(IEnumerable<int> ids)
         {
             // Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words.
             // Consequently, decoding individual IDs into UTF-16 string is not supported; instead, decoding all IDs must be performed collectively.
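The comment in this hunk is the key constraint: Tiktoken tokens carry raw byte payloads, so one character's UTF-8 bytes can straddle two ids. A sketch of why per-id decoding breaks — the byte payloads here are invented for illustration:

using System;
using System.Text;

// Hypothetical byte payloads of two adjacent token ids that split the
// UTF-8 encoding of 'é' (0xC3 0xA9) across the token boundary.
byte[] firstToken = { (byte)'c', (byte)'a', (byte)'f', 0xC3 };
byte[] secondToken = { 0xA9 };

// Decoding each id on its own yields a replacement character:
Console.WriteLine(Encoding.UTF8.GetString(firstToken));  // "caf\uFFFD"

// Decoding the accumulated bytes of all ids recovers the text:
byte[] all = new byte[firstToken.Length + secondToken.Length];
firstToken.CopyTo(all, 0);
secondToken.CopyTo(all, firstToken.Length);
Console.WriteLine(Encoding.UTF8.GetString(all));         // "café"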
1 change: 0 additions & 1 deletion src/Microsoft.ML.Tokenizers/PACKAGE.md
@@ -69,7 +69,6 @@ The main types provided by this library are:
 * `Microsoft.ML.Tokenizers.Bpe`
 * `Microsoft.ML.Tokenizers.EnglishRoberta`
 * `Microsoft.ML.Tokenizers.Tiktoken`
-* `Microsoft.ML.Tokenizers.TokenizerDecoder`
 * `Microsoft.ML.Tokenizers.Normalizer`
 * `Microsoft.ML.Tokenizers.PreTokenizer`
 
11 changes: 2 additions & 9 deletions src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -25,13 +25,11 @@ public partial class Tokenizer
         /// <param name="model">The Model in use by the Tokenizer.</param>
         /// <param name="preTokenizer">The optional PreTokenizer in use by the Tokenizer. WhiteSpace PreTokenizer will be used if this parameter is null.</param>
         /// <param name="normalizer">The optional Normalizer in use by the Tokenizer.</param>
-        /// <param name="decoder">The optional Decoder in use by the Tokenizer during the decoding operation to merge the given list of tokens in a string.</param>
-        public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, TokenizerDecoder? decoder = null)
+        public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
        {
            Model = model;
            PreTokenizer = preTokenizer ?? WhiteSpace.Instance;
            Normalizer = normalizer;
-            Decoder = decoder;
        }
 
         /// <summary>
@@ -49,11 +47,6 @@ public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? nor
         /// </summary>
         public Normalizer? Normalizer { get; }
 
-        /// <summary>
-        /// Gets or sets the Decoder in use by the Tokenizer.
-        /// </summary>
-        public TokenizerDecoder? Decoder { get; }
-
         /// <summary>
         /// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping.
         /// </summary>
@@ -278,7 +271,7 @@ private int LastIndexOf(string text, int maxTokenCount, out string processedText
         /// </summary>
         /// <param name="ids">The list of ids that we want to decode.</param>
         /// <returns>The decoded string.</returns>
-        public string? Decode(IEnumerable<int> ids) => Model.Decode(ids, Decoder);
+        public string? Decode(IEnumerable<int> ids) => Model.Decode(ids);
 
         /// <summary>
         /// Create a Tiktoken tokenizer based on model name and vocab file.
23 changes: 0 additions & 23 deletions src/Microsoft.ML.Tokenizers/TokenizerDecoder.cs

This file was deleted.

60 changes: 60 additions & 0 deletions src/Microsoft.ML.Tokenizers/Utils/ValueStringBuilder.cs
@@ -153,6 +153,66 @@ public bool TryCopyTo(Span<char> destination, out int charsWritten)
             }
         }
 
+        public void Replace(string oldValue, string newValue)
+        {
+            int oldLength = oldValue.Length;
+            int newLength = newValue.Length;
+            int index = 0;
+
+            while (index <= _pos - oldLength)
+            {
+                ReadOnlySpan<char> buffer = _chars.Slice(index);
+                int subIndex = buffer.IndexOf(oldValue.AsSpan(), StringComparison.Ordinal);
+                if (subIndex < 0)
+                {
+                    break;
+                }
+
+                index += subIndex;
+
+                if (oldLength >= newLength)
+                {
+                    newValue.AsSpan().CopyTo(_chars.Slice(index));
+                    if (oldLength > newLength)
+                    {
+                        _chars.Slice(index + oldLength).CopyTo(_chars.Slice(index + newLength));
+                        _pos -= oldLength - newLength;
+                    }
+                }
+                else
+                {
+                    Insert(index, newValue);
+
+                    _chars.Slice(index + newLength + oldLength).CopyTo(_chars.Slice(index + newLength));
+                    _pos -= oldLength;
+                }
+
+                index += newLength;
+            }
+        }
+
+        public bool RemoveSuffix(string value)
+        {
+            if (EndsWith(value))
+            {
+                _pos -= value.Length;
+                return true;
+            }
+
+            return false;
+        }
+
+        public bool EndsWith(string value)
+        {
+            int valueLength = value.Length;
+            if (valueLength > _pos)
+            {
+                return false;
+            }
+
+            return _chars.Slice(_pos - valueLength, valueLength).SequenceEqual(value.AsSpan());
+        }
+
         public void Insert(int index, char value, int count)
         {
             if (_pos > _chars.Length - count)
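The new Replace walks the buffer left to right and shifts the tail in place at each match. The same ordinal, non-overlapping scanning logic over an immutable string, as a standalone sketch (ReplaceOrdinal is illustrative, not part of the library):

using System;
using System.Text;

// Standalone sketch of the left-to-right, non-overlapping replacement
// the span-based Replace above performs in place.
static string ReplaceOrdinal(string text, string oldValue, string newValue)
{
    var sb = new StringBuilder();
    int index = 0;

    while (index <= text.Length - oldValue.Length)
    {
        int matchIndex = text.IndexOf(oldValue, index, StringComparison.Ordinal);
        if (matchIndex < 0)
        {
            break;
        }

        sb.Append(text, index, matchIndex - index); // copy the unmatched run
        sb.Append(newValue);                        // splice in the replacement
        index = matchIndex + oldValue.Length;       // resume after the match
    }

    sb.Append(text, index, text.Length - index);    // copy the tail
    return sb.ToString();
}

// Example: ReplaceOrdinal("hello</w>world</w>", "</w>", " ") returns "hello world ".

The trailing space in that example is exactly why Bpe.Decode calls RemoveSuffix(EndOfWordSuffix) before Replace(EndOfWordSuffix, " ").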
