diff --git a/Tokenizer_C#/TokenizerLib/ITokenizer.cs b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
index 08ff714..d4faa02 100644
--- a/Tokenizer_C#/TokenizerLib/ITokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
@@ -43,5 +43,15 @@ public interface ITokenizer
         /// Decode an array of integer token ids
         /// </summary>
         public string Decode(int[] tokens);
+
+        /// <summary>
+        /// Count tokens in a string, with or without the special tokens set through the constructor.
+        /// </summary>
+        public int Count(string text, bool applySpecialTokens = true, int max = int.MaxValue);
+
+        /// <summary>
+        /// Count tokens in a string, with a set of allowed special tokens that are not broken apart.
+        /// </summary>
+        public int Count(string text, IReadOnlyCollection<string> allowedSpecial, int max = int.MaxValue);
     }
 }
diff --git a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
index c0bba10..d52c7bf 100644
--- a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
@@ -33,7 +33,7 @@ public class TikTokenizer : ITokenizer
         /// </summary>
         public const int DefaultCacheSize = 4096;
 
-        private readonly LruCache<string, int[]> Cache;
+        private readonly LruCache<string, List<int>> Cache;
 
         public int NumOfCacheEntries => this.Cache.Count;
 
@@ -47,7 +47,7 @@ public class TikTokenizer : ITokenizer
         /// <param name="pattern">Regex pattern to break a string to be encoded</param>
         public TikTokenizer(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<string, int> specialTokensEncoder, string pattern, int cacheSize = DefaultCacheSize)
         {
-            Cache = new LruCache<string, int[]>(cacheSize);
+            Cache = new LruCache<string, List<int>>(cacheSize);
             Init(encoder, specialTokensEncoder, pattern);
         }
 
@@ -59,7 +59,7 @@ public TikTokenizer(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionar
         /// <param name="pattern">Regex pattern to break a string to be encoded</param>
         public TikTokenizer(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int> specialTokensEncoder, string pattern, int cacheSize = DefaultCacheSize)
         {
-            Cache = new LruCache<string, int[]>(cacheSize);
+            Cache = new LruCache<string, List<int>>(cacheSize);
             var encoder = LoadTikTokenBpe(tikTokenBpeFileStream);
             Init(encoder, specialTokensEncoder, pattern);
         }
@@ -251,7 +251,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
         {
             foreach (Match match in Regex.Matches(text[start..end]))
             {
-                if (this.Cache.Lookup(match.Value, out int[] tokens))
+                if (this.Cache.Lookup(match.Value, out List<int> tokens))
                 {
                     tokenIds.AddRange(tokens);
                 }
@@ -267,7 +267,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
                     {
                         var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
                         tokenIds.AddRange(encodedTokens);
-                        this.Cache.Add(match.Value, encodedTokens.ToArray());
+                        this.Cache.Add(match.Value, encodedTokens);
                     }
                 }
             }
@@ -290,9 +290,9 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
             foreach (Match match in Regex.Matches(text[start..end]))
             {
                 var piece = match.Value;
-                if (this.Cache.Lookup(piece, out int[] tokens))
+                if (this.Cache.Lookup(piece, out List<int> tokens))
                 {
-                    tokenCount += tokens.Length;
+                    tokenCount += tokens.Count;
                     if (tokenCount <= maxTokenCount)
                     {
                         encodeLength += piece.Length;
@@ -323,7 +323,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
                 else
                 {
                     var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
-                    this.Cache.Add(piece, encodedTokens.ToArray());
+                    this.Cache.Add(piece, encodedTokens);
                     tokenCount += encodedTokens.Count;
                     if (tokenCount <= maxTokenCount)
                     {
@@ -505,9 +505,9 @@ private void Encode(string text, List<int> tokenIds, int start, ref int tokenCou
             {
                 var piece = match.Value;
 
-                if (this.Cache.Lookup(match.Value, out int[] tokens))
+                if (this.Cache.Lookup(match.Value, out List<int> tokens))
                 {
-                    tokenCount += tokens.Length;
+                    tokenCount += tokens.Count;
                     encodeLength += piece.Length;
                     tokenIds.AddRange(tokens);
                     tokenCountMap[tokenCount] = encodeLength;
@@ -526,7 +526,7 @@ private void Encode(string text, List<int> tokenIds, int start, ref int tokenCou
                 else
                 {
                     var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
-                    this.Cache.Add(piece, encodedTokens.ToArray());
+                    this.Cache.Add(piece, encodedTokens);
                     tokenCount += encodedTokens.Count;
                     encodeLength += piece.Length;
                     tokenIds.AddRange(encodedTokens);
@@ -602,6 +602,99 @@ public string Decode(int[] tokens)
 
             return Encoding.UTF8.GetString(decoded.ToArray());
         }
 
-    }
+        public int Count(string text, bool applySpecialTokens = true, int max = int.MaxValue)
+        {
+            if (applySpecialTokens && SpecialTokens.Count > 0)
+            {
+                return CountInternal(text, SpecialTokens, max);
+            }
+
+            return CountTokens(text, max);
+        }
+
+        public int Count(string text, IReadOnlyCollection<string> allowedSpecial, int max = int.MaxValue)
+        {
+            if (allowedSpecial is null || allowedSpecial.Count == 0)
+            {
+                return CountTokens(text, max);
+            }
+
+            return CountInternal(text, allowedSpecial, max);
+        }
+
+        private int CountInternal(string text, IReadOnlyCollection<string> allowedSpecial, int max)
+        {
+            int tokenCount = 0;
+            int start = 0;
+            while (true)
+            {
+                Match nextSpecial;
+                int end;
+                FindNextSpecialToken(text, allowedSpecial, start, out nextSpecial, out end);
+                if (end > start)
+                {
+                    tokenCount += CountTokens(text[start..end], max - tokenCount);
+                    if (tokenCount >= max)
+                    {
+                        return max;
+                    }
+                }
+
+                if (nextSpecial.Success)
+                {
+                    tokenCount++;
+                    if (tokenCount >= max)
+                    {
+                        return max;
+                    }
+                    start = nextSpecial.Index + nextSpecial.Length;
+                    if (start >= text.Length)
+                    {
+                        break;
+                    }
+                }
+                else
+                {
+                    break;
+                }
+            }
+
+            return tokenCount;
+        }
+
+        private int CountTokens(string text, int max)
+        {
+            int tokenCount = 0;
+            foreach (Match match in Regex.Matches(text))
+            {
+                var piece = match.Value;
+                if (this.Cache.Lookup(piece, out List<int> tokens))
+                {
+                    tokenCount += tokens.Count;
+                }
+                else
+                {
+                    var bytes = Encoding.UTF8.GetBytes(match.Value);
+                    if (Encoder.TryGetValue(bytes, out int token))
+                    {
+                        tokenCount++;
+                    }
+                    else
+                    {
+                        var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
+                        this.Cache.Add(piece, encodedTokens);
+                        tokenCount += encodedTokens.Count;
+                    }
+                }
+
+                if (tokenCount >= max)
+                {
+                    return max;
+                }
+            }
+
+            return tokenCount;
+        }
+    }
 }
diff --git a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
index 43b7d2e..332a4c5 100644
--- a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
+++ b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
@@ -304,5 +304,27 @@ public void TestEncodeR50kbase()
         Assert.AreEqual(text, decoded);
     }
 
+        [TestMethod]
+        public void TestCountR50kbase()
+        {
+            var text = File.ReadAllText("./testData/lib.rs.txt");
+            var count = Tokenizer_r50k_base.Count(text, new HashSet<string>());
+            Assert.AreEqual(11378, count);
+        }
+
+        [TestMethod]
+        public void TestCountR50kbaseSetMaxTokens()
+        {
+            var text = File.ReadAllText("./testData/lib.rs.txt");
+            var count = Tokenizer_r50k_base.Count(text, new HashSet<string>(), 10000);
+            Assert.AreEqual(10000, count);
+        }
+
+        [TestMethod]
+        public void TestCount0Tokens()
+        {
+            var count = Tokenizer_r50k_base.Count("", new HashSet<string>());
+            Assert.AreEqual(0, count);
+        }
     }
 }
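
Usage note for reviewers: the sketch below shows how the new Count overloads are meant to be called. It is a minimal example, not part of the patch; the BPE file path, the <|endoftext|> id, and the r50k-style regex pattern are placeholders assumed from the existing test setup, and Microsoft.DeepDev is assumed to be TokenizerLib's namespace.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.DeepDev; // assumed TokenizerLib namespace

public static class CountExample
{
    // r50k/GPT-2-style split pattern; illustrative only, pass whatever pattern
    // the tokenizer is already constructed with elsewhere.
    private const string Pattern =
        @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";

    public static void Main()
    {
        // Placeholder inputs mirroring the r50k_base setup in the unit tests.
        using Stream bpeStream = File.OpenRead("./testData/r50k_base.tiktoken");
        var specialTokens = new Dictionary<string, int> { ["<|endoftext|>"] = 50256 };
        ITokenizer tokenizer = new TikTokenizer(bpeStream, specialTokens, Pattern);

        string text = File.ReadAllText("./testData/lib.rs.txt");

        // Default overload: special tokens set through the constructor
        // count as one token each.
        int full = tokenizer.Count(text);

        // Allow-list overload: an empty set disables special-token handling,
        // as in the tests above.
        int plain = tokenizer.Count(text, new HashSet<string>());

        // max caps the work: Count returns exactly max once the cap is reached,
        // so a token-budget check does not pay for a full count of a long input.
        bool fits = tokenizer.Count(text, new HashSet<string>(), max: 10000) < 10000;

        Console.WriteLine($"{full} {plain} fits={fits}");
    }
}

Related design note: switching the cache value type from int[] to List<int> lets BytePairEncode results be cached without the per-entry ToArray() copy, which both the existing Encode paths and the new CountTokens path benefit from.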