mirror of
https://github.com/Mintplex-Labs/tiktoken.git
synced 2026-07-01 18:48:04 -04:00
80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
from typing import Callable
|
|
|
|
import hypothesis
|
|
import pytest
|
|
from hypothesis import strategies as st
|
|
|
|
import tiktoken
|
|
|
|
from .test_helpers import MAX_EXAMPLES, SOME_ENCODING_FACTORIES
|
|
|
|
|
|
def _common_prefix_len(a, b):
|
|
i = 0
|
|
while i < len(a) and i < len(b) and a[i] == b[i]:
|
|
i += 1
|
|
return i
|
|
|
|
|
|
def _token_offsets_reference(enc, tokens):
    """Compute one text offset per token using repeated prefix decoding.

    The full token list is decoded with ``errors="strict"``. Then, for every
    index ``end``, ``tokens[:end]`` is decoded with ``errors="ignore"`` and the
    offset recorded for that index is the length of the common prefix between
    the full text and that partial decoding.

    Returns a list with ``len(tokens)`` integers.
    """
    full_text = enc.decode(tokens, errors="strict")
    return [
        _common_prefix_len(full_text, enc.decode(tokens[:end], errors="ignore"))
        for end in range(len(tokens))
    ]
|
|
|
|
|
|
@pytest.mark.parametrize("make_enc", SOME_ENCODING_FACTORIES)
@hypothesis.given(data=st.data())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_offsets(make_enc: Callable[[], tiktoken.Encoding], data):
    """Property test: ``decode_with_offsets`` matches the reference offsets.

    Draws a list of 1 to 20 token ids that appear among the encoding's special
    tokens or mergeable ranks, round-trips them through decode/encode, and
    asserts that the offsets returned by ``enc.decode_with_offsets`` equal the
    ones computed by ``_token_offsets_reference``.

    Args:
        make_enc: Zero-argument factory returning the encoding under test.
        data: Hypothesis data object used to draw the token list.
    """
    enc = make_enc()

    # Collect the assigned token ids into a set once. Testing membership
    # against the dict value views directly is a linear scan per candidate
    # integer, which the filter below would repeat for every draw.
    valid_tokens = set(enc._special_tokens.values())
    valid_tokens.update(enc._mergeable_ranks.values())

    tokens_st = st.lists(
        st.integers(0, enc.n_vocab - 1).filter(lambda x: x in valid_tokens),
        min_size=1,
        max_size=20,
    )
    tokens = data.draw(tokens_st)

    # This is a dumb hack to make sure that our tokens are a valid UTF-8 string
    # We could potentially drop this, see the TODO in decode_with_offsets
    tokens = enc.encode(enc.decode(tokens, errors="ignore"), allowed_special="all")
    assert enc.decode_with_offsets(tokens)[1] == _token_offsets_reference(enc, tokens)
|
|
|
|
|
|
def test_basic_offsets():
    """Check ``decode_with_offsets`` against hand-written offsets for cl100k_base.

    Every case encodes a prompt, decodes it back with offsets, and asserts
    both the round-tripped text and the exact offset list.
    """
    enc = tiktoken.get_encoding("cl100k_base")

    # Each entry: (prompt, extra keyword arguments for encode, expected offsets).
    cases = [
        ("hello world", {}, [0, 5]),
        (
            "hello world<|endoftext|> green cow",
            {"allowed_special": "all"},
            [0, 5, 11, 24, 30],
        ),
        (
            "我非常渴望与人工智能一起工作",
            {},
            [0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13],
        ),
        # contains the interesting tokens b'\xe0\xae\xbf\xe0\xae' and b'\xe0\xaf\x8d\xe0\xae'
        # in which \xe0 is the start of a 3-byte UTF-8 character
        (
            "நடிகர் சூர்யா",
            {},
            [0, 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 9, 10, 11, 12, 12],
        ),
        # contains the interesting token b'\xa0\xe9\x99\xa4'
        # in which \xe9 is the start of a 3-byte UTF-8 character and \xa0 is a continuation byte
        (" Ġ除", {}, [0, 1]),
    ]

    for prompt, encode_kwargs, expected_offsets in cases:
        encoded = enc.encode(prompt, **encode_kwargs)
        decoded, offsets = enc.decode_with_offsets(encoded)
        assert decoded == prompt
        assert offsets == expected_offsets
|