mirror of
https://github.com/Mintplex-Labs/tiktoken.git
synced 2026-07-01 18:48:04 -04:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -33,6 +33,7 @@ MANIFEST
|
||||
# Tools
|
||||
.mypy_cache
|
||||
.coverage
|
||||
.hypothesis
|
||||
htmlcov
|
||||
|
||||
# General
|
||||
|
||||
+3
-3
@@ -10,10 +10,10 @@ crate-type = ["lib"]
|
||||
|
||||
[dependencies]
|
||||
# tiktoken dependencies
|
||||
fancy-regex = "0.10.0"
|
||||
regex = "1.7.0"
|
||||
fancy-regex = "0.11.0"
|
||||
regex = "1.8.3"
|
||||
rustc-hash = "1.1.0"
|
||||
bstr = "1.0.1"
|
||||
bstr = "1.5.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
+3
-2
@@ -41,5 +41,6 @@ macos.archs = ["x86_64", "arm64"]
|
||||
# Warnings will be silenced with following CIBW_TEST_SKIP
|
||||
test-skip = "*-macosx_arm64"
|
||||
|
||||
before-test = "pip install pytest"
|
||||
test-command = "pytest {project}/tests"
|
||||
before-test = "pip install pytest hypothesis"
|
||||
test-command = "pytest {project}/tests --import-mode=append"
|
||||
|
||||
|
||||
+1
-1
@@ -9,6 +9,6 @@ name = "_tiktoken"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
pyo3 = { version = "0.17.3", features = ["extension-module"] }
|
||||
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
||||
tiktoken_core = { path = "../core", features = ["multithreading"] }
|
||||
rustc-hash = "1.1.0"
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
# Note that there are more actual tests, they're just not currently public :-)
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import hypothesis
|
||||
import hypothesis.strategies as st
|
||||
import pytest
|
||||
|
||||
import tiktoken
|
||||
|
||||
from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
|
||||
|
||||
|
||||
def test_simple():
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
assert enc.encode("hello world") == [31373, 995]
|
||||
assert enc.decode([31373, 995]) == "hello world"
|
||||
assert enc.encode("hello <|endoftext|>", allowed_special="all") == [31373, 220, 50256]
|
||||
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
assert enc.encode("hello world") == [15339, 1917]
|
||||
assert enc.decode([15339, 1917]) == "hello world"
|
||||
assert enc.encode("hello <|endoftext|>", allowed_special="all") == [15339, 220, 100257]
|
||||
|
||||
for enc_name in tiktoken.list_encoding_names():
|
||||
enc = tiktoken.get_encoding(enc_name)
|
||||
for token in range(10_000):
|
||||
assert enc.encode_single_token(enc.decode_single_token_bytes(token)) == token
|
||||
|
||||
|
||||
def test_simple_repeated():
|
||||
enc = tiktoken.get_encoding("gpt2")
|
||||
assert enc.encode("0") == [15]
|
||||
assert enc.encode("00") == [405]
|
||||
assert enc.encode("000") == [830]
|
||||
assert enc.encode("0000") == [2388]
|
||||
assert enc.encode("00000") == [20483]
|
||||
assert enc.encode("000000") == [10535]
|
||||
assert enc.encode("0000000") == [24598]
|
||||
assert enc.encode("00000000") == [8269]
|
||||
assert enc.encode("000000000") == [10535, 830]
|
||||
assert enc.encode("0000000000") == [8269, 405]
|
||||
assert enc.encode("00000000000") == [8269, 830]
|
||||
assert enc.encode("000000000000") == [8269, 2388]
|
||||
assert enc.encode("0000000000000") == [8269, 20483]
|
||||
assert enc.encode("00000000000000") == [8269, 10535]
|
||||
assert enc.encode("000000000000000") == [8269, 24598]
|
||||
assert enc.encode("0000000000000000") == [25645]
|
||||
assert enc.encode("00000000000000000") == [8269, 10535, 830]
|
||||
|
||||
|
||||
def test_simple_regex():
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
assert enc.encode("rer") == [38149]
|
||||
assert enc.encode("'rer") == [2351, 81]
|
||||
assert enc.encode("today\n ") == [31213, 198, 220]
|
||||
assert enc.encode("today\n \n") == [31213, 27907]
|
||||
assert enc.encode("today\n \n") == [31213, 14211]
|
||||
|
||||
|
||||
def test_basic_encode():
|
||||
enc = tiktoken.get_encoding("r50k_base")
|
||||
assert enc.encode("hello world") == [31373, 995]
|
||||
|
||||
enc = tiktoken.get_encoding("p50k_base")
|
||||
assert enc.encode("hello world") == [31373, 995]
|
||||
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
assert enc.encode("hello world") == [15339, 1917]
|
||||
assert enc.encode(" \x850") == [220, 126, 227, 15]
|
||||
|
||||
|
||||
def test_encode_empty():
|
||||
enc = tiktoken.get_encoding("r50k_base")
|
||||
assert enc.encode("") == []
|
||||
|
||||
|
||||
def test_encode_bytes():
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
assert enc._encode_bytes(b" \xec\x8b\xa4\xed") == [62085]
|
||||
|
||||
|
||||
def test_encode_surrogate_pairs():
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
assert enc.encode("👍") == [9468, 239, 235]
|
||||
# surrogate pair gets converted to codepoint
|
||||
assert enc.encode("\ud83d\udc4d") == [9468, 239, 235]
|
||||
|
||||
# lone surrogate just gets replaced
|
||||
assert enc.encode("\ud83d") == enc.encode("�")
|
||||
|
||||
|
||||
# ====================
|
||||
# Roundtrip
|
||||
# ====================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
|
||||
def test_basic_roundtrip(make_enc):
|
||||
enc = make_enc()
|
||||
for value in (
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
):
|
||||
assert value == enc.decode(enc.encode(value))
|
||||
assert value == enc.decode(enc.encode_ordinary(value))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
|
||||
@hypothesis.given(text=st.text())
|
||||
@hypothesis.settings(deadline=None)
|
||||
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text):
|
||||
enc = make_enc()
|
||||
|
||||
assert text == enc.decode(enc.encode(text))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
|
||||
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
|
||||
enc = make_enc()
|
||||
|
||||
for token in range(enc.n_vocab):
|
||||
try:
|
||||
token_bytes = enc.decode_single_token_bytes(token)
|
||||
except KeyError:
|
||||
continue
|
||||
assert enc.encode_single_token(token_bytes) == token
|
||||
|
||||
|
||||
# ====================
|
||||
# Special tokens
|
||||
# ====================
|
||||
|
||||
|
||||
def test_special_token():
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
eot = enc.encode_single_token("<|endoftext|>")
|
||||
assert eot == enc.eot_token
|
||||
fip = enc.encode_single_token("<|fim_prefix|>")
|
||||
fim = enc.encode_single_token("<|fim_middle|>")
|
||||
|
||||
text = "<|endoftext|> hello <|fim_prefix|>"
|
||||
assert eot not in enc.encode(text, disallowed_special=())
|
||||
with pytest.raises(ValueError):
|
||||
enc.encode(text)
|
||||
with pytest.raises(ValueError):
|
||||
enc.encode(text, disallowed_special="all")
|
||||
with pytest.raises(ValueError):
|
||||
enc.encode(text, disallowed_special={"<|endoftext|>"})
|
||||
with pytest.raises(ValueError):
|
||||
enc.encode(text, disallowed_special={"<|fim_prefix|>"})
|
||||
|
||||
text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"
|
||||
tokens = enc.encode(text, disallowed_special=())
|
||||
assert eot not in tokens
|
||||
assert fip not in tokens
|
||||
assert fim not in tokens
|
||||
|
||||
tokens = enc.encode(text, allowed_special="all", disallowed_special=())
|
||||
assert eot in tokens
|
||||
assert fip in tokens
|
||||
assert fim in tokens
|
||||
|
||||
tokens = enc.encode(text, allowed_special="all", disallowed_special="all")
|
||||
assert eot in tokens
|
||||
assert fip in tokens
|
||||
assert fim in tokens
|
||||
|
||||
tokens = enc.encode(text, allowed_special={"<|fim_prefix|>"}, disallowed_special=())
|
||||
assert eot not in tokens
|
||||
assert fip in tokens
|
||||
assert fim not in tokens
|
||||
|
||||
tokens = enc.encode(text, allowed_special={"<|endoftext|>"}, disallowed_special=())
|
||||
assert eot in tokens
|
||||
assert fip not in tokens
|
||||
assert fim not in tokens
|
||||
|
||||
tokens = enc.encode(text, allowed_special={"<|fim_middle|>"}, disallowed_special=())
|
||||
assert eot not in tokens
|
||||
assert fip not in tokens
|
||||
assert fim in tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
|
||||
@hypothesis.given(text=st.text())
|
||||
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
|
||||
def test_hyp_special_ordinary(make_enc, text: str):
|
||||
enc = make_enc()
|
||||
assert enc.encode_ordinary(text) == enc.encode(text, disallowed_special=())
|
||||
|
||||
|
||||
# ====================
|
||||
# Batch encoding
|
||||
# ====================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
|
||||
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
|
||||
enc = make_enc()
|
||||
text1 = "hello world"
|
||||
text2 = "goodbye world"
|
||||
|
||||
assert enc.encode_batch([text1]) == [enc.encode(text1)]
|
||||
assert enc.encode_batch([text1, text2]) == [enc.encode(text1), enc.encode(text2)]
|
||||
|
||||
assert enc.encode_ordinary_batch([text1]) == [enc.encode_ordinary(text1)]
|
||||
assert enc.encode_ordinary_batch([text1, text2]) == [
|
||||
enc.encode_ordinary(text1),
|
||||
enc.encode_ordinary(text2),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
|
||||
@hypothesis.given(batch=st.lists(st.text()))
|
||||
@hypothesis.settings(deadline=None)
|
||||
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
|
||||
enc = make_enc()
|
||||
|
||||
encoded = enc.encode_batch(batch)
|
||||
assert encoded == [enc.encode(t) for t in batch]
|
||||
decoded = enc.decode_batch(encoded)
|
||||
assert decoded == batch
|
||||
@@ -0,0 +1,22 @@
|
||||
import bisect
|
||||
import functools
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import tiktoken
|
||||
|
||||
MAX_EXAMPLES: int = int(os.environ.get("TIKTOKEN_MAX_EXAMPLES", "100"))
|
||||
|
||||
ENCODINGS = ["r50k_base", "cl100k_base"]
|
||||
SOME_ENCODINGS = ["cl100k_base"]
|
||||
|
||||
|
||||
ENCODING_FACTORIES = [
|
||||
pytest.param(functools.partial(tiktoken.get_encoding, name), id=name) for name in ENCODINGS
|
||||
]
|
||||
SOME_ENCODING_FACTORIES = [
|
||||
pytest.param(functools.partial(tiktoken.get_encoding, name), id=name) for name in SOME_ENCODINGS
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import tiktoken
|
||||
|
||||
|
||||
def test_encoding_for_model():
|
||||
enc = tiktoken.encoding_for_model("gpt2")
|
||||
assert enc.name == "gpt2"
|
||||
enc = tiktoken.encoding_for_model("text-davinci-003")
|
||||
assert enc.name == "p50k_base"
|
||||
enc = tiktoken.encoding_for_model("text-davinci-edit-001")
|
||||
assert enc.name == "p50k_edit"
|
||||
enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0301")
|
||||
assert enc.name == "cl100k_base"
|
||||
|
||||
|
||||
def test_optional_blobfile_dependency():
|
||||
prog = """
|
||||
import tiktoken
|
||||
import sys
|
||||
assert "blobfile" not in sys.modules
|
||||
"""
|
||||
subprocess.check_call([sys.executable, "-c", prog])
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Callable
|
||||
|
||||
import hypothesis
|
||||
import pytest
|
||||
from hypothesis import strategies as st
|
||||
|
||||
import tiktoken
|
||||
|
||||
from .test_helpers import MAX_EXAMPLES, SOME_ENCODING_FACTORIES
|
||||
|
||||
|
||||
def _common_prefix_len(a, b):
|
||||
i = 0
|
||||
while i < len(a) and i < len(b) and a[i] == b[i]:
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _token_offsets_reference(enc, tokens):
|
||||
text = enc.decode(tokens, errors="strict")
|
||||
res = []
|
||||
for i in range(len(tokens)):
|
||||
prefix = enc.decode(tokens[:i], errors="ignore")
|
||||
res.append(_common_prefix_len(text, prefix))
|
||||
return res
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_enc", SOME_ENCODING_FACTORIES)
|
||||
@hypothesis.given(data=st.data())
|
||||
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
|
||||
def test_hyp_offsets(make_enc: Callable[[], tiktoken.Encoding], data):
|
||||
enc = make_enc()
|
||||
|
||||
tokens_st = st.lists(
|
||||
st.integers(0, enc.n_vocab - 1).filter(
|
||||
lambda x: x in enc._special_tokens.values() or x in enc._mergeable_ranks.values()
|
||||
),
|
||||
min_size=1,
|
||||
max_size=20,
|
||||
)
|
||||
tokens = data.draw(tokens_st)
|
||||
|
||||
# This is a dumb hack to make sure that our tokens are a valid UTF-8 string
|
||||
# We could potentially drop this, see the TODO in decode_with_offsets
|
||||
tokens = enc.encode(enc.decode(tokens, errors="ignore"), allowed_special="all")
|
||||
assert enc.decode_with_offsets(tokens)[1] == _token_offsets_reference(enc, tokens)
|
||||
|
||||
|
||||
def test_basic_offsets():
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
prompt = "hello world"
|
||||
p, o = enc.decode_with_offsets(enc.encode(prompt))
|
||||
assert p == prompt
|
||||
assert o == [0, 5]
|
||||
|
||||
prompt = "hello world<|endoftext|> green cow"
|
||||
p, o = enc.decode_with_offsets(enc.encode(prompt, allowed_special="all"))
|
||||
assert p == prompt
|
||||
assert o == [0, 5, 11, 24, 30]
|
||||
|
||||
prompt = "我非常渴望与人工智能一起工作"
|
||||
p, o = enc.decode_with_offsets(enc.encode(prompt))
|
||||
assert p == prompt
|
||||
assert o == [0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13]
|
||||
|
||||
# contains the interesting tokens b'\xe0\xae\xbf\xe0\xae' and b'\xe0\xaf\x8d\xe0\xae'
|
||||
# in which \xe0 is the start of a 3-byte UTF-8 character
|
||||
prompt = "நடிகர் சூர்யா"
|
||||
p, o = enc.decode_with_offsets(enc.encode(prompt))
|
||||
assert p == prompt
|
||||
assert o == [0, 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 9, 10, 11, 12, 12]
|
||||
|
||||
# contains the interesting token b'\xa0\xe9\x99\xa4'
|
||||
# in which \xe9 is the start of a 3-byte UTF-8 character and \xa0 is a continuation byte
|
||||
prompt = " Ġ除"
|
||||
p, o = enc.decode_with_offsets(enc.encode(prompt))
|
||||
assert p == prompt
|
||||
assert o == [0, 1]
|
||||
@@ -0,0 +1,212 @@
|
||||
"""This is an educational implementation of the byte pair encoding algorithm."""
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
|
||||
import tiktoken
|
||||
|
||||
|
||||
class SimpleBytePairEncoding:
|
||||
def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
|
||||
"""Creates an Encoding object."""
|
||||
# A regex pattern string that is used to split the input text
|
||||
self.pat_str = pat_str
|
||||
# A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
|
||||
self.mergeable_ranks = mergeable_ranks
|
||||
|
||||
self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()}
|
||||
self._pat = regex.compile(pat_str)
|
||||
|
||||
def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
|
||||
"""Encodes a string into tokens.
|
||||
|
||||
>>> enc.encode("hello world")
|
||||
[388, 372]
|
||||
"""
|
||||
# Use the regex to split the text into (approximately) words
|
||||
words = self._pat.findall(text)
|
||||
tokens = []
|
||||
for word in words:
|
||||
# Turn each word into tokens, using the byte pair encoding algorithm
|
||||
word_bytes = word.encode("utf-8")
|
||||
word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise)
|
||||
tokens.extend(word_tokens)
|
||||
return tokens
|
||||
|
||||
def decode_bytes(self, tokens: list[int]) -> bytes:
|
||||
"""Decodes a list of tokens into bytes.
|
||||
|
||||
>>> enc.decode_bytes([388, 372])
|
||||
b'hello world'
|
||||
"""
|
||||
return b"".join(self._decoder[token] for token in tokens)
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
"""Decodes a list of tokens into a string.
|
||||
|
||||
Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
|
||||
the invalid bytes with the replacement character "�".
|
||||
|
||||
>>> enc.decode([388, 372])
|
||||
'hello world'
|
||||
"""
|
||||
return self.decode_bytes(tokens).decode("utf-8", errors="replace")
|
||||
|
||||
def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
|
||||
"""Decodes a list of tokens into a list of bytes.
|
||||
|
||||
Useful for visualising how a string is tokenised.
|
||||
|
||||
>>> enc.decode_tokens_bytes([388, 372])
|
||||
[b'hello', b' world']
|
||||
"""
|
||||
return [self._decoder[token] for token in tokens]
|
||||
|
||||
@staticmethod
|
||||
def train(training_data: str, vocab_size: int, pat_str: str):
|
||||
"""Train a BPE tokeniser on some data!"""
|
||||
mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
|
||||
return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks)
|
||||
|
||||
@staticmethod
|
||||
def from_tiktoken(encoding):
|
||||
if isinstance(encoding, str):
|
||||
encoding = tiktoken.get_encoding(encoding)
|
||||
return SimpleBytePairEncoding(
|
||||
pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks
|
||||
)
|
||||
|
||||
|
||||
def bpe_encode(
|
||||
mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour"
|
||||
) -> list[int]:
|
||||
parts = [bytes([b]) for b in input]
|
||||
while True:
|
||||
# See the intermediate merges play out!
|
||||
if visualise:
|
||||
if visualise in ["colour", "color"]:
|
||||
visualise_tokens(parts)
|
||||
elif visualise == "simple":
|
||||
print(parts)
|
||||
|
||||
# Iterate over all pairs and find the pair we want to merge the most
|
||||
min_idx = None
|
||||
min_rank = None
|
||||
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
|
||||
rank = mergeable_ranks.get(pair[0] + pair[1])
|
||||
if rank is not None and (min_rank is None or rank < min_rank):
|
||||
min_idx = i
|
||||
min_rank = rank
|
||||
|
||||
# If there were no pairs we could merge, we're done!
|
||||
if min_rank is None:
|
||||
break
|
||||
assert min_idx is not None
|
||||
|
||||
# Otherwise, merge that pair and leave the rest unchanged. Then repeat.
|
||||
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
|
||||
|
||||
if visualise:
|
||||
print()
|
||||
|
||||
tokens = [mergeable_ranks[part] for part in parts]
|
||||
return tokens
|
||||
|
||||
|
||||
def bpe_train(
|
||||
data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour"
|
||||
) -> dict[bytes, int]:
|
||||
# First, add tokens for each individual byte value
|
||||
if vocab_size < 2**8:
|
||||
raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
|
||||
ranks = {}
|
||||
for i in range(2**8):
|
||||
ranks[bytes([i])] = i
|
||||
|
||||
# Splinter up our data into lists of bytes
|
||||
# data = "Hello world"
|
||||
# words = [
|
||||
# [b'H', b'e', b'l', b'l', b'o'],
|
||||
# [b' ', b'w', b'o', b'r', b'l', b'd']
|
||||
# ]
|
||||
words: list[list[bytes]] = [
|
||||
[bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
|
||||
]
|
||||
|
||||
# Now, use our data to figure out which merges we should make
|
||||
while len(ranks) < vocab_size:
|
||||
# Find the most common pair. This will become our next token
|
||||
stats = collections.Counter()
|
||||
for piece in words:
|
||||
for pair in zip(piece[:-1], piece[1:]):
|
||||
stats[pair] += 1
|
||||
|
||||
most_common_pair = max(stats, key=lambda x: stats[x])
|
||||
token_bytes = most_common_pair[0] + most_common_pair[1]
|
||||
token = len(ranks)
|
||||
# Add the new token!
|
||||
ranks[token_bytes] = token
|
||||
|
||||
# Now merge that most common pair in all the words. That is, update our training data
|
||||
# to reflect our decision to make that pair into a new token.
|
||||
new_words = []
|
||||
for word in words:
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word) - 1:
|
||||
if (word[i], word[i + 1]) == most_common_pair:
|
||||
# We found our pair! Merge it
|
||||
new_word.append(token_bytes)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
if i == len(word) - 1:
|
||||
new_word.append(word[i])
|
||||
new_words.append(new_word)
|
||||
words = new_words
|
||||
|
||||
# See the intermediate merges play out!
|
||||
if visualise:
|
||||
print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
|
||||
print(f"So we made {token_bytes} our {len(ranks)}th token")
|
||||
if visualise in ["colour", "color"]:
|
||||
print("Now the first fifty words in our training data look like:")
|
||||
visualise_tokens([token for word in words[:50] for token in word])
|
||||
elif visualise == "simple":
|
||||
print("Now the first twenty words in our training data look like:")
|
||||
for word in words[:20]:
|
||||
print(word)
|
||||
print("\n")
|
||||
|
||||
return ranks
|
||||
|
||||
|
||||
def visualise_tokens(token_values: list[bytes]) -> None:
|
||||
backgrounds = itertools.cycle(
|
||||
[f"\u001b[48;5;{i}m".encode() for i in [167, 179, 185, 77, 80, 68, 134]]
|
||||
)
|
||||
interleaved = itertools.chain.from_iterable(zip(backgrounds, token_values))
|
||||
print((b"".join(interleaved) + "\u001b[0m".encode()).decode("utf-8"))
|
||||
|
||||
|
||||
def train_simple_encoding():
|
||||
gpt2_pattern = (
|
||||
r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
)
|
||||
with open(__file__, "r") as f:
|
||||
data = f.read()
|
||||
|
||||
enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern)
|
||||
|
||||
print("This is the sequence of merges performed in order to encode 'hello world':")
|
||||
tokens = enc.encode("hello world")
|
||||
assert enc.decode(tokens) == "hello world"
|
||||
assert enc.decode_bytes(tokens) == b"hello world"
|
||||
assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"]
|
||||
|
||||
return enc
|
||||
@@ -276,6 +276,31 @@ class Encoding:
|
||||
"""
|
||||
return [self.decode_single_token_bytes(token) for token in tokens]
|
||||
|
||||
def decode_with_offsets(self, tokens: list[int]) -> tuple[str, list[int]]:
|
||||
"""Decodes a list of tokens into a string and a list of offsets.
|
||||
|
||||
Each offset is the index into text corresponding to the start of each token.
|
||||
If UTF-8 character boundaries do not line up with token boundaries, the offset is the index
|
||||
of the first character that contains bytes from the token.
|
||||
|
||||
This will currently raise if given tokens that decode to invalid UTF-8; this behaviour may
|
||||
change in the future to be more permissive.
|
||||
|
||||
>>> enc.decode_with_offsets([31373, 995])
|
||||
('hello world', [0, 5])
|
||||
"""
|
||||
token_bytes = self.decode_tokens_bytes(tokens)
|
||||
|
||||
text_len = 0
|
||||
offsets = []
|
||||
for token in token_bytes:
|
||||
offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
|
||||
text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
|
||||
|
||||
# TODO: assess correctness for errors="ignore" and errors="replace"
|
||||
text = b"".join(token_bytes).decode("utf-8", errors="strict")
|
||||
return text, offsets
|
||||
|
||||
def decode_batch(
|
||||
self, batch: list[list[int]], *, errors: str = "replace", num_threads: int = 8
|
||||
) -> list[str]:
|
||||
|
||||
+2
-1
@@ -15,6 +15,7 @@ MODEL_PREFIX_TO_ENCODING: dict[str, str] = {
|
||||
# chat
|
||||
"gpt-4-": "cl100k_base", # e.g., gpt-4-0314, etc., plus gpt-4-32k
|
||||
"gpt-3.5-turbo-": "cl100k_base", # e.g, gpt-3.5-turbo-0301, -0401, etc.
|
||||
"gpt-35-turbo": "cl100k_base", # Azure deployment name
|
||||
}
|
||||
|
||||
MODEL_TO_ENCODING: dict[str, str] = json.loads(pkg_resources.read_text("tiktoken", "model_to_encoding.json"))
|
||||
@@ -36,7 +37,7 @@ def encoding_for_model(model_name: str) -> Encoding:
|
||||
if encoding_name is None:
|
||||
raise KeyError(
|
||||
f"Could not automatically map {model_name} to a tokeniser. "
|
||||
"Please use `tiktok.get_encoding` to explicitly get the tokeniser you expect."
|
||||
"Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect."
|
||||
) from None
|
||||
|
||||
return get_encoding(encoding_name)
|
||||
|
||||
+3
-3
@@ -11,10 +11,10 @@ crate-type = ["cdylib"]
|
||||
[dependencies]
|
||||
tiktoken_core = { path = "../core", features = [] }
|
||||
# tiktoken dependencies
|
||||
fancy-regex = "0.10.0"
|
||||
regex = "1.7.0"
|
||||
fancy-regex = "0.11.0"
|
||||
regex = "1.8.3"
|
||||
rustc-hash = "1.1.0"
|
||||
bstr = "1.0.1"
|
||||
bstr = "1.5.0"
|
||||
wasm-bindgen = "0.2.83"
|
||||
anyhow = "1.0.69"
|
||||
base64 = "0.21.0"
|
||||
|
||||
Reference in New Issue
Block a user