diff --git a/Tokenizer_C#/TokenizerLib/ITokenizer.cs b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
index 08ff714..d4faa02 100644
--- a/Tokenizer_C#/TokenizerLib/ITokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
@@ -43,5 +43,15 @@ public interface ITokenizer
/// Decode an array of integer token ids
/// </summary>
public string Decode(int[] tokens);
+
+ /// <summary>
+ /// Count the tokens in a string, with or without the special tokens set through the constructor.
+ /// </summary>
+ public int Count(string text, bool applySpecialTokens = true, int max = int.MaxValue);
+
+ /// <summary>
+ /// Count the tokens in a string, treating the given allowed special tokens as atomic units that are never broken apart.
+ /// </summary>
+ public int Count(string text, IReadOnlyCollection<string> allowedSpecial, int max = int.MaxValue);
}
}
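For reference, a minimal usage sketch of the two new Count overloads (C# top-level statements). The .tiktoken file name, the special-token mapping, the split pattern, and the omitted library namespace using are illustrative assumptions, not part of this change:

using System.Collections.Generic;
using System.IO;
// (the library's namespace using is omitted here; it is not shown in this diff)

// Hypothetical setup: file name, special tokens, and regex pattern are placeholders.
string pattern = "...";                       // the model's token-splitting regex
using Stream bpeStream = File.OpenRead("r50k_base.tiktoken");
var specialTokens = new Dictionary<string, int> { { "<|endoftext|>", 50256 } };
ITokenizer tokenizer = new TikTokenizer(bpeStream, specialTokens, pattern);

// Overload 1: count using the special tokens supplied at construction time.
int total = tokenizer.Count("Hello <|endoftext|> world");

// Overload 2: count with an explicit allow-list of special tokens,
// capped at 100 tokens; the return value never exceeds the cap.
int capped = tokenizer.Count(
    "Hello <|endoftext|> world",
    new HashSet<string> { "<|endoftext|>" },
    max: 100);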
diff --git a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
index c0bba10..d52c7bf 100644
--- a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
@@ -33,7 +33,7 @@ public class TikTokenizer : ITokenizer
/// </summary>
public const int DefaultCacheSize = 4096;
- private readonly LruCache<string, int[]> Cache;
+ private readonly LruCache<string, List<int>> Cache;
public int NumOfCacheEntries => this.Cache.Count;
@@ -47,7 +47,7 @@ public class TikTokenizer : ITokenizer
/// <param name="pattern">Regex pattern to break a string to be encoded</param>
public TikTokenizer(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<string, int> specialTokensEncoder, string pattern, int cacheSize = DefaultCacheSize)
{
- Cache = new LruCache<string, int[]>(cacheSize);
+ Cache = new LruCache<string, List<int>>(cacheSize);
Init(encoder, specialTokensEncoder, pattern);
}
@@ -59,7 +59,7 @@ public TikTokenizer(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionar
/// <param name="pattern">Regex pattern to break a string to be encoded</param>
public TikTokenizer(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int> specialTokensEncoder, string pattern, int cacheSize = DefaultCacheSize)
{
- Cache = new LruCache<string, int[]>(cacheSize);
+ Cache = new LruCache<string, List<int>>(cacheSize);
var encoder = LoadTikTokenBpe(tikTokenBpeFileStream);
Init(encoder, specialTokensEncoder, pattern);
}
@@ -251,7 +251,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
{
foreach (Match match in Regex.Matches(text[start..end]))
{
- if (this.Cache.Lookup(match.Value, out int[] tokens))
+ if (this.Cache.Lookup(match.Value, out List<int> tokens))
{
tokenIds.AddRange(tokens);
}
@@ -267,7 +267,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
{
var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
tokenIds.AddRange(encodedTokens);
- this.Cache.Add(match.Value, encodedTokens.ToArray());
+ this.Cache.Add(match.Value, encodedTokens);
}
}
}
@@ -290,9 +290,9 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
foreach (Match match in Regex.Matches(text[start..end]))
{
var piece = match.Value;
- if (this.Cache.Lookup(piece, out int[] tokens))
+ if (this.Cache.Lookup(piece, out List<int> tokens))
{
- tokenCount += tokens.Length;
+ tokenCount += tokens.Count;
if (tokenCount <= maxTokenCount)
{
encodeLength += piece.Length;
@@ -323,7 +323,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
else
{
var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
- this.Cache.Add(piece, encodedTokens.ToArray());
+ this.Cache.Add(piece, encodedTokens);
tokenCount += encodedTokens.Count;
if (tokenCount <= maxTokenCount)
{
@@ -505,9 +505,9 @@ private void Encode(string text, List<int> tokenIds, int start, ref int tokenCou
{
var piece = match.Value;
- if (this.Cache.Lookup(match.Value, out int[] tokens))
+ if (this.Cache.Lookup(match.Value, out List<int> tokens))
{
- tokenCount += tokens.Length;
+ tokenCount += tokens.Count;
encodeLength += piece.Length;
tokenIds.AddRange(tokens);
tokenCountMap[tokenCount] = encodeLength;
@@ -526,7 +526,7 @@ private void Encode(string text, List<int> tokenIds, int start, ref int tokenCou
else
{
var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
- this.Cache.Add(piece, encodedTokens.ToArray());
+ this.Cache.Add(piece, encodedTokens);
tokenCount += encodedTokens.Count;
encodeLength += piece.Length;
tokenIds.AddRange(encodedTokens);
@@ -602,6 +602,99 @@ public string Decode(int[] tokens)
return Encoding.UTF8.GetString(decoded.ToArray());
}
- }
+ public int Count(string text, bool applySpecialTokens = true, int max = int.MaxValue)
+ {
+ if (applySpecialTokens && SpecialTokens.Count > 0)
+ {
+ return CountInternal(text, SpecialTokens, max);
+ }
+
+ return CountTokens(text, max);
+ }
+
+ public int Count(string text, IReadOnlyCollection<string> allowedSpecial, int max = int.MaxValue)
+ {
+ if (allowedSpecial is null || allowedSpecial.Count == 0)
+ {
+ return CountTokens(text, max);
+ }
+
+ return CountInternal(text, allowedSpecial, max);
+ }
+
+ private int CountInternal(string text, IReadOnlyCollection<string> allowedSpecial, int max)
+ {
+ int tokenCount = 0;
+ int start = 0;
+ while (true)
+ {
+ Match nextSpecial;
+ int end;
+ FindNextSpecialToken(text, allowedSpecial, start, out nextSpecial, out end);
+ if (end > start)
+ {
+ tokenCount += CountTokens(text[start..end], max - tokenCount);
+ if (tokenCount >= max)
+ {
+ return max;
+ }
+ }
+
+ if (nextSpecial.Success)
+ {
+ tokenCount++;
+ if (tokenCount >= max)
+ {
+ return max;
+ }
+ start = nextSpecial.Index + nextSpecial.Length;
+ if (start >= text.Length)
+ {
+ break;
+ }
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ return tokenCount;
+ }
+
+ private int CountTokens(string text, int max)
+ {
+ int tokenCount = 0;
+ foreach (Match match in Regex.Matches(text))
+ {
+ var piece = match.Value;
+ if (this.Cache.Lookup(piece, out List<int> tokens))
+ {
+ tokenCount += tokens.Count;
+ }
+ else
+ {
+ var bytes = Encoding.UTF8.GetBytes(match.Value);
+ if (Encoder.TryGetValue(bytes, out int token))
+ {
+ tokenCount++;
+ }
+ else
+ {
+ var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
+ this.Cache.Add(piece, encodedTokens);
+ tokenCount += encodedTokens.Count;
+ }
+ }
+
+ if (tokenCount >= max)
+ {
+ return max;
+ }
+ }
+
+ return tokenCount;
+ }
+ }
}
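One behavioural detail the tests below rely on: every early-exit path in CountInternal and CountTokens returns exactly max once the running total reaches the cap, so the result is min(true token count, max). A sketch of the resulting caller contract, reusing the hypothetical tokenizer from the sketch above (someLongText stands for any string that encodes to more than cap tokens):

// Count never exceeds the cap and returns exactly the cap once it is hit,
// so equality with the cap signals truncation rather than an exact total.
int cap = 10;
int count = tokenizer.Count(someLongText, new HashSet<string>(), cap);
if (count == cap)
{
    // someLongText encodes to at least `cap` tokens;
    // the exact total was not computed past the cap.
}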
diff --git a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
index 43b7d2e..332a4c5 100644
--- a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
+++ b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
@@ -304,5 +304,27 @@ public void TestEncodeR50kbase()
Assert.AreEqual(text, decoded);
}
+ [TestMethod]
+ public void TestCountR50kbase()
+ {
+ var text = File.ReadAllText("./testData/lib.rs.txt");
+ var count = Tokenizer_r50k_base.Count(text, new HashSet<string>());
+ Assert.AreEqual(11378, count);
+ }
+
+ [TestMethod]
+ public void TestCountR50kbaseSetMaxTokens()
+ {
+ var text = File.ReadAllText("./testData/lib.rs.txt");
+ var count = Tokenizer_r50k_base.Count(text, new HashSet<string>(), 10000);
+ Assert.AreEqual(10000, count);
+ }
+
+ [TestMethod]
+ public void TestCount0Tokens()
+ {
+ var count = Tokenizer_r50k_base.Count("", new HashSet<string>());
+ Assert.AreEqual(0, count);
+ }
}
}