Skip to content

Commit f8403ed

Browse files
cccclai authored and facebook-github-bot committed
Error out when token is outside of vocab size (#3535)
Summary: Pull Request resolved: #3535 Ideally it shouldn't happen, but if we post process the weight somehow too much it might happen. In Android, it just seg fault directly if it's outside of the range without error message. After this change, it's clearer: ``` E 00:00:00.180911 executorch:bpe_tokenizer.cpp:155] token 18446744073709551615 is out side of vacab range 512 Aborted ``` Reviewed By: larryliu0820 Differential Revision: D57057026 fbshipit-source-id: 838260d60b75e7c392d7f496d7cdf6f81957f56c
1 parent 2d68bd3 commit f8403ed

5 files changed

Lines changed: 35 additions & 7 deletions

File tree

examples/models/llama2/tokenizer/bpe_tokenizer.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,7 @@ BPETokenizer::~BPETokenizer() {
146146
* token.
147147
*/
148148
Result<std::string> BPETokenizer::decode(uint64_t prev_token, uint64_t token) {
149-
if (!initialized_) {
150-
ET_LOG(Error, "Tokenizer not initialized");
151-
return Error::NotSupported;
152-
}
149+
ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token));
153150
const char* piece = vocab_[token];
154151
// following BOS token, sentencepiece decoder strips any leading
155152
// whitespace

examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) {
3939
EXPECT_EQ(result.error(), Error::NotSupported);
4040
}
4141

42+
TEST_F(TokenizerExtensionTest, DecodeOutOfRangeFails) {
43+
Error res = tokenizer_->load(modelPath_.c_str());
44+
EXPECT_EQ(res, Error::Ok);
45+
auto result = tokenizer_->decode(0, 64000);
46+
// The vocab size is 32000, and token 64000 is out of vocab range.
47+
EXPECT_EQ(result.error(), Error::NotSupported);
48+
}
49+
4250
TEST_F(TokenizerExtensionTest, TokenizerVocabSizeIsExpected) {
4351
Error res = tokenizer_->load(modelPath_.c_str());
4452
EXPECT_EQ(res, Error::Ok);

examples/models/llama2/tokenizer/test/test_tiktoken.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,14 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
7777
}
7878
}
7979

80+
TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
81+
Error res = tokenizer_->load(modelPath_.c_str());
82+
EXPECT_EQ(res, Error::Ok);
83+
// The vocab size is 128256, addes 256 just so the token is out of vocab
84+
// range.
85+
Result<std::string> out = tokenizer_->decode(0, 128256 + 256);
86+
EXPECT_EQ(out.error(), Error::NotSupported);
87+
}
88+
8089
} // namespace executor
8190
} // namespace torch

examples/models/llama2/tokenizer/tiktoken.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) {
364364

365365
Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) {
366366
(void)prev;
367-
if (!initialized_) {
368-
return Error::NotSupported;
369-
}
367+
ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur));
370368
std::string ret;
371369

372370
std::string token_bytes;

examples/models/llama2/tokenizer/tokenizer.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,22 @@ class Tokenizer {
4040
virtual Result<std::vector<uint64_t>>
4141
encode(const std::string& input, int8_t bos, int8_t eos) = 0;
4242

43+
Error decode_verify(uint64_t token) const {
44+
if (!initialized_) {
45+
ET_LOG(Error, "Tokenizer not initialized");
46+
return Error::NotSupported;
47+
}
48+
if (token >= vocab_size_) {
49+
ET_LOG(
50+
Error,
51+
"token %" PRIu64 " is out side of vacab range %d",
52+
token,
53+
vocab_size_);
54+
return Error::NotSupported;
55+
}
56+
return Error::Ok;
57+
}
58+
4359
virtual Result<std::string> decode(uint64_t prev_token, uint64_t token) = 0;
4460

4561
// getters

0 commit comments

Comments (0)