Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Tat Dat Duong
2023-06-15 15:01:42 +02:00
13 changed files with 606 additions and 10 deletions
+1
View File
@@ -33,6 +33,7 @@ MANIFEST
# Tools
.mypy_cache
.coverage
.hypothesis
htmlcov
# General
+3 -3
View File
@@ -10,10 +10,10 @@ crate-type = ["lib"]
[dependencies]
# tiktoken dependencies
fancy-regex = "0.10.0"
regex = "1.7.0"
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.0.1"
bstr = "1.5.0"
[features]
default = []
+3 -2
View File
@@ -41,5 +41,6 @@ macos.archs = ["x86_64", "arm64"]
# Warnings will be silenced with the following CIBW_TEST_SKIP
test-skip = "*-macosx_arm64"
before-test = "pip install pytest"
test-command = "pytest {project}/tests"
before-test = "pip install pytest hypothesis"
test-command = "pytest {project}/tests --import-mode=append"
+1 -1
View File
@@ -9,6 +9,6 @@ name = "_tiktoken"
crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.17.3", features = ["extension-module"] }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
tiktoken_core = { path = "../core", features = ["multithreading"] }
rustc-hash = "1.1.0"
View File
+231
View File
@@ -0,0 +1,231 @@
# Note that there are more actual tests, they're just not currently public :-)
from typing import Callable
import hypothesis
import hypothesis.strategies as st
import pytest
import tiktoken
from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES
def test_simple():
    """Smoke-test known token ids and single-token round trips.

    Checks fixed expectations for the "gpt2" and "cl100k_base" encodings, then
    verifies for every listed encoding that the first 10,000 token ids survive
    a decode-to-bytes / encode-single-token round trip.
    """
    # encoding name -> (tokens for "hello world", tokens for "hello <|endoftext|>")
    known = {
        "gpt2": ([31373, 995], [31373, 220, 50256]),
        "cl100k_base": ([15339, 1917], [15339, 220, 100257]),
    }
    for name, (hello_tokens, eot_tokens) in known.items():
        enc = tiktoken.get_encoding(name)
        assert enc.encode("hello world") == hello_tokens
        assert enc.decode(hello_tokens) == "hello world"
        assert enc.encode("hello <|endoftext|>", allowed_special="all") == eot_tokens

    for enc_name in tiktoken.list_encoding_names():
        enc = tiktoken.get_encoding(enc_name)
        for token in range(10_000):
            token_bytes = enc.decode_single_token_bytes(token)
            assert enc.encode_single_token(token_bytes) == token
def test_simple_repeated():
    """Check the gpt2 tokenisation of runs of 1 to 17 consecutive "0" characters."""
    enc = tiktoken.get_encoding("gpt2")

    # expected_by_length[k - 1] is the expected encoding of "0" repeated k times
    expected_by_length = [
        [15],
        [405],
        [830],
        [2388],
        [20483],
        [10535],
        [24598],
        [8269],
        [10535, 830],
        [8269, 405],
        [8269, 830],
        [8269, 2388],
        [8269, 20483],
        [8269, 10535],
        [8269, 24598],
        [25645],
        [8269, 10535, 830],
    ]
    for length, expected in enumerate(expected_by_length, start=1):
        assert enc.encode("0" * length) == expected
def test_simple_regex():
    """Check cl100k_base splitting around apostrophes and trailing whitespace.

    The last three cases differ only in the whitespace that follows "today":
    a newline and one space, a newline/space/newline, and a newline followed by
    two spaces and a newline. Each must be a distinct input, otherwise two of
    the assertions contradict each other.
    """
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("rer") == [38149]
    assert enc.encode("'rer") == [2351, 81]
    assert enc.encode("today\n ") == [31213, 198, 220]
    assert enc.encode("today\n \n") == [31213, 27907]
    # Two spaces between the newlines here (one space is the case above).
    assert enc.encode("today\n  \n") == [31213, 14211]
def test_basic_encode():
    """Check "hello world" across three encodings, plus a U+0085 edge case."""
    greeting = "hello world"

    r50k = tiktoken.get_encoding("r50k_base")
    assert r50k.encode(greeting) == [31373, 995]

    p50k = tiktoken.get_encoding("p50k_base")
    assert p50k.encode(greeting) == [31373, 995]

    cl100k = tiktoken.get_encoding("cl100k_base")
    assert cl100k.encode(greeting) == [15339, 1917]
    # Space, then U+0085, then the digit "0".
    assert cl100k.encode(" \x850") == [220, 126, 227, 15]
def test_encode_empty():
    """Encoding the empty string with r50k_base yields an empty token list."""
    tokens = tiktoken.get_encoding("r50k_base").encode("")
    assert tokens == []
def test_encode_bytes():
    """Check that `_encode_bytes` maps a fixed raw byte string to one token."""
    raw = b" \xec\x8b\xa4\xed"
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc._encode_bytes(raw) == [62085]
def test_encode_surrogate_pairs():
    """Check how cl100k_base encodes surrogate code points.

    A valid surrogate pair must encode like the code point it represents, and a
    lone surrogate must encode like the Unicode replacement character U+FFFD.
    """
    enc = tiktoken.get_encoding("cl100k_base")

    assert enc.encode("👍") == [9468, 239, 235]
    # surrogate pair gets converted to codepoint
    assert enc.encode("\ud83d\udc4d") == [9468, 239, 235]

    # lone surrogate just gets replaced
    # Written as an escape so the replacement character cannot be silently
    # dropped by tooling; comparing against "" would compare against no tokens.
    assert enc.encode("\ud83d") == enc.encode("\ufffd")
# ====================
# Roundtrip
# ====================
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_basic_roundtrip(make_enc):
    """Encode-then-decode must reproduce a set of fixed strings exactly.

    The strings cover leading and trailing whitespace (including two trailing
    spaces) and non-ASCII text. Both `encode` and `encode_ordinary` are
    checked.
    """
    enc = make_enc()
    for value in (
        "hello",
        "hello ",
        # Two trailing spaces: distinct from the single-space case above.
        "hello  ",
        " hello",
        " hello ",
        # Leading space plus two trailing spaces.
        " hello  ",
        "hello world",
        "请考试我的软件!12345",
    ):
        assert value == enc.decode(enc.encode(value))
        assert value == enc.decode(enc.encode_ordinary(value))
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None)
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text):
    """Property test: decoding the encoding of any generated text returns that text."""
    enc = make_enc()

    tokens = enc.encode(text)
    assert enc.decode(tokens) == text
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    """Every decodable token id must map back to itself via its bytes.

    Ids for which `decode_single_token_bytes` raises KeyError are skipped.
    """
    enc = make_enc()

    for token_id in range(enc.n_vocab):
        try:
            raw = enc.decode_single_token_bytes(token_id)
        except KeyError:
            pass
        else:
            assert enc.encode_single_token(raw) == token_id
# ====================
# Special tokens
# ====================
def test_special_token():
    """Exercise `allowed_special` / `disallowed_special` handling on cl100k_base.

    First checks which calls raise ValueError on text containing special
    tokens, then checks, for several argument combinations, exactly which of
    the three special token ids appear in the encoded output.
    """
    enc = tiktoken.get_encoding("cl100k_base")

    eot = enc.encode_single_token("<|endoftext|>")
    assert eot == enc.eot_token
    fip = enc.encode_single_token("<|fim_prefix|>")
    fim = enc.encode_single_token("<|fim_middle|>")

    text = "<|endoftext|> hello <|fim_prefix|>"
    assert eot not in enc.encode(text, disallowed_special=())

    # Each of these keyword sets must make `encode` reject the text.
    rejecting_kwargs = (
        {},
        {"disallowed_special": "all"},
        {"disallowed_special": {"<|endoftext|>"}},
        {"disallowed_special": {"<|fim_prefix|>"}},
    )
    for kwargs in rejecting_kwargs:
        with pytest.raises(ValueError):
            enc.encode(text, **kwargs)

    text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"

    # (encode keyword arguments, special token ids expected in the output)
    scenarios = (
        ({"disallowed_special": ()}, set()),
        ({"allowed_special": "all", "disallowed_special": ()}, {eot, fip, fim}),
        ({"allowed_special": "all", "disallowed_special": "all"}, {eot, fip, fim}),
        ({"allowed_special": {"<|fim_prefix|>"}, "disallowed_special": ()}, {fip}),
        ({"allowed_special": {"<|endoftext|>"}, "disallowed_special": ()}, {eot}),
        ({"allowed_special": {"<|fim_middle|>"}, "disallowed_special": ()}, {fim}),
    )
    for kwargs, expected_specials in scenarios:
        tokens = enc.encode(text, **kwargs)
        for special in (eot, fip, fim):
            assert (special in tokens) == (special in expected_specials)
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_special_ordinary(make_enc, text: str):
    """Property test: `encode_ordinary` matches `encode` with nothing disallowed."""
    enc = make_enc()

    ordinary = enc.encode_ordinary(text)
    assert ordinary == enc.encode(text, disallowed_special=())
# ====================
# Batch encoding
# ====================
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
    """Batch encoding must match encoding each text individually.

    Checked for one- and two-element batches, for both `encode_batch` and
    `encode_ordinary_batch`.
    """
    enc = make_enc()
    texts = ["hello world", "goodbye world"]
    batches = (texts[:1], texts)

    for batch in batches:
        assert enc.encode_batch(batch) == [enc.encode(item) for item in batch]

    for batch in batches:
        expected = [enc.encode_ordinary(item) for item in batch]
        assert enc.encode_ordinary_batch(batch) == expected
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(batch=st.lists(st.text()))
@hypothesis.settings(deadline=None)
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
    """Property test: batch encode agrees with per-item encode and round-trips."""
    enc = make_enc()

    encoded = enc.encode_batch(batch)
    one_by_one = [enc.encode(item) for item in batch]
    assert encoded == one_by_one

    assert enc.decode_batch(encoded) == batch
+22
View File
@@ -0,0 +1,22 @@
import bisect
import functools
import os
import pytest
import tiktoken
# Value passed as `max_examples` to hypothesis settings by the tests; read from
# the TIKTOKEN_MAX_EXAMPLES environment variable, defaulting to 100.
MAX_EXAMPLES: int = int(os.environ.get("TIKTOKEN_MAX_EXAMPLES", "100"))

ENCODINGS = ["r50k_base", "cl100k_base"]
SOME_ENCODINGS = ["cl100k_base"]


def _as_factories(names):
    """Turn encoding names into pytest params holding zero-argument loaders.

    Each param wraps `functools.partial(tiktoken.get_encoding, name)` and uses
    the encoding name as its test id.
    """
    params = []
    for name in names:
        loader = functools.partial(tiktoken.get_encoding, name)
        params.append(pytest.param(loader, id=name))
    return params


ENCODING_FACTORIES = _as_factories(ENCODINGS)
SOME_ENCODING_FACTORIES = _as_factories(SOME_ENCODINGS)
+24
View File
@@ -0,0 +1,24 @@
import subprocess
import sys
import tiktoken
def test_encoding_for_model():
    """Check that known model names resolve to the expected encoding names."""
    expected_encoding = {
        "gpt2": "gpt2",
        "text-davinci-003": "p50k_base",
        "text-davinci-edit-001": "p50k_edit",
        "gpt-3.5-turbo-0301": "cl100k_base",
    }
    for model_name, encoding_name in expected_encoding.items():
        assert tiktoken.encoding_for_model(model_name).name == encoding_name
def test_optional_blobfile_dependency():
    """Importing tiktoken in a fresh interpreter must not import blobfile.

    The check runs in a subprocess so that modules already imported by the
    test session cannot influence `sys.modules`.
    """
    prog = """
import tiktoken
import sys
assert "blobfile" not in sys.modules
"""
    command = [sys.executable, "-c", prog]
    subprocess.check_call(command)
+79
View File
@@ -0,0 +1,79 @@
from typing import Callable
import hypothesis
import pytest
from hypothesis import strategies as st
import tiktoken
from .test_helpers import MAX_EXAMPLES, SOME_ENCODING_FACTORIES
def _common_prefix_len(a, b):
i = 0
while i < len(a) and i < len(b) and a[i] == b[i]:
i += 1
return i
def _token_offsets_reference(enc, tokens):
    """Compute reference token offsets by decoding every strict prefix of `tokens`.

    For each index i, the offset is the length of the common prefix between the
    fully decoded text and the decoding of `tokens[:i]` (with undecodable bytes
    ignored). The full decode uses errors="strict", so it raises if `tokens` do
    not decode cleanly.
    """
    text = enc.decode(tokens, errors="strict")
    return [
        _common_prefix_len(text, enc.decode(tokens[:end], errors="ignore"))
        for end in range(len(tokens))
    ]
@pytest.mark.parametrize("make_enc", SOME_ENCODING_FACTORIES)
@hypothesis.given(data=st.data())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_offsets(make_enc: Callable[[], tiktoken.Encoding], data):
    """Property test: `decode_with_offsets` agrees with the prefix-decoding reference.

    Draws between 1 and 20 valid token ids, normalises them through a
    decode/encode round trip so they form valid UTF-8, then compares the
    offsets against `_token_offsets_reference`.
    """
    enc = make_enc()

    # Build the set of valid ids once: the filter below runs for every drawn
    # integer, and membership tests on dict `.values()` views are linear scans.
    valid_tokens = set(enc._special_tokens.values())
    valid_tokens.update(enc._mergeable_ranks.values())

    tokens_st = st.lists(
        st.integers(0, enc.n_vocab - 1).filter(lambda x: x in valid_tokens),
        min_size=1,
        max_size=20,
    )
    tokens = data.draw(tokens_st)

    # This is a dumb hack to make sure that our tokens are a valid UTF-8 string
    # We could potentially drop this, see the TODO in decode_with_offsets
    tokens = enc.encode(enc.decode(tokens, errors="ignore"), allowed_special="all")
    assert enc.decode_with_offsets(tokens)[1] == _token_offsets_reference(enc, tokens)
def test_basic_offsets():
    """Check `decode_with_offsets` against hand-written offsets for fixed prompts."""
    enc = tiktoken.get_encoding("cl100k_base")

    # (prompt, extra keyword arguments for `encode`, expected offsets)
    cases = [
        ("hello world", {}, [0, 5]),
        (
            "hello world<|endoftext|> green cow",
            {"allowed_special": "all"},
            [0, 5, 11, 24, 30],
        ),
        (
            "我非常渴望与人工智能一起工作",
            {},
            [0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13],
        ),
        # contains the interesting tokens b'\xe0\xae\xbf\xe0\xae' and b'\xe0\xaf\x8d\xe0\xae'
        # in which \xe0 is the start of a 3-byte UTF-8 character
        (
            "நடிகர் சூர்யா",
            {},
            [0, 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 9, 10, 11, 12, 12],
        ),
        # contains the interesting token b'\xa0\xe9\x99\xa4'
        # in which \xe9 is the start of a 3-byte UTF-8 character and \xa0 is a continuation byte
        (" Ġ除", {}, [0, 1]),
    ]

    for prompt, encode_kwargs, expected_offsets in cases:
        tokens = enc.encode(prompt, **encode_kwargs)
        decoded, offsets = enc.decode_with_offsets(tokens)
        assert decoded == prompt
        assert offsets == expected_offsets
+212
View File
@@ -0,0 +1,212 @@
"""This is an educational implementation of the byte pair encoding algorithm."""
from __future__ import annotations
import collections
import itertools
from typing import Optional
import regex
import tiktoken
class SimpleBytePairEncoding:
    """A small byte pair encoder: regex splitting plus a table of merge ranks."""

    def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
        """Creates an Encoding object.

        Args:
            pat_str: Regex pattern string used to split input text into pieces.
            mergeable_ranks: Maps token bytes to their rank. The rank is both
                the token id and the merge priority.
        """
        self.pat_str = pat_str
        self.mergeable_ranks = mergeable_ranks

        # Reverse lookup (token id -> token bytes), used by the decode methods.
        self._decoder = {rank: piece for piece, rank in mergeable_ranks.items()}
        self._pat = regex.compile(pat_str)

    def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
        """Encodes a string into tokens.

        The text is split with the regex pattern, and each piece is encoded
        independently with `bpe_encode`. `visualise` is forwarded to
        `bpe_encode` unchanged.

        >>> enc.encode("hello world")
        [388, 372]
        """
        tokens: list[int] = []
        for piece in self._pat.findall(text):
            piece_bytes = piece.encode("utf-8")
            tokens += bpe_encode(self.mergeable_ranks, piece_bytes, visualise=visualise)
        return tokens

    def decode_bytes(self, tokens: list[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        Raises KeyError for a token id that is not in the vocabulary.

        >>> enc.decode_bytes([388, 372])
        b'hello world'
        """
        lookup = self._decoder
        return b"".join(lookup[token] for token in tokens)

    def decode(self, tokens: list[int]) -> str:
        """Decodes a list of tokens into a string.

        The decoded bytes are not guaranteed to be valid UTF-8. Invalid bytes
        are replaced with U+FFFD, the Unicode replacement character.

        >>> enc.decode([388, 372])
        'hello world'
        """
        raw = self.decode_bytes(tokens)
        return raw.decode("utf-8", errors="replace")

    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes, one entry per token.

        Useful for visualising how a string is tokenised.

        >>> enc.decode_tokens_bytes([388, 372])
        [b'hello', b' world']
        """
        lookup = self._decoder
        return [lookup[token] for token in tokens]

    @staticmethod
    def train(training_data: str, vocab_size: int, pat_str: str):
        """Train a BPE tokeniser on some data!

        Learns the merge table with `bpe_train` and wraps it in a new
        SimpleBytePairEncoding that uses the same split pattern.
        """
        learned_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
        return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=learned_ranks)

    @staticmethod
    def from_tiktoken(encoding):
        """Build an instance from a tiktoken encoding object or encoding name.

        A string is first resolved with `tiktoken.get_encoding`; the pattern
        and merge table are then read from the encoding's private attributes.
        """
        resolved = tiktoken.get_encoding(encoding) if isinstance(encoding, str) else encoding
        return SimpleBytePairEncoding(
            pat_str=resolved._pat_str, mergeable_ranks=resolved._mergeable_ranks
        )
def bpe_encode(
    mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour"
) -> list[int]:
    """Encode `input` by repeatedly merging the adjacent pair with the lowest rank.

    Starts from one part per byte. On every round the adjacent pair whose
    concatenation has the smallest rank is merged (the leftmost wins a tie);
    encoding stops when no adjacent pair is in `mergeable_ranks`.

    Args:
        mergeable_ranks: Maps token bytes to rank / token id.
        input: The bytes to encode.
        visualise: "colour"/"color" prints each round with `visualise_tokens`,
            "simple" prints the raw list of parts, and a falsy value prints
            nothing.

    Returns:
        The token id of each remaining part, in order.

    Raises:
        KeyError: If a remaining part is not a key of `mergeable_ranks`.
    """
    pieces = [bytes([byte]) for byte in input]

    while True:
        # See the intermediate merges play out!
        if visualise:
            if visualise in ["colour", "color"]:
                visualise_tokens(pieces)
            elif visualise == "simple":
                print(pieces)

        # Scan every adjacent pair, remembering the best (rank, position) so far.
        best = None
        for position in range(len(pieces) - 1):
            rank = mergeable_ranks.get(pieces[position] + pieces[position + 1])
            if rank is None:
                continue
            if best is None or rank < best[0]:
                best = (rank, position)

        # Nothing left that can be merged: we are done.
        if best is None:
            break

        # Replace the winning pair with its concatenation; the rest is untouched.
        at = best[1]
        pieces[at : at + 2] = [pieces[at] + pieces[at + 1]]

    if visualise:
        print()

    return [mergeable_ranks[piece] for piece in pieces]
def bpe_train(
    data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour"
) -> dict[bytes, int]:
    """Learn a byte pair encoding merge table from `data`.

    The vocabulary starts with one token per byte value. The most frequent
    adjacent pair in the training data is then repeatedly turned into a new
    token until the vocabulary holds `vocab_size` entries.

    Args:
        data: Training text.
        vocab_size: Target number of tokens; must be at least 256.
        pat_str: Regex pattern used to split `data` into words. Pairs are never
            merged across word boundaries.
        visualise: "colour"/"color" or "simple" prints progress after every
            merge; a falsy value prints nothing.

    Returns:
        A dict mapping token bytes to rank / token id.

    Raises:
        ValueError: If `vocab_size` is below 256, or if the training data has
            no adjacent pairs left to merge before `vocab_size` is reached.
    """
    # First, add tokens for each individual byte value
    if vocab_size < 2**8:
        raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
    ranks = {bytes([i]): i for i in range(2**8)}

    # Splinter up our data into lists of bytes
    # data = "Hello world"
    # words = [
    #     [b'H', b'e', b'l', b'l', b'o'],
    #     [b' ', b'w', b'o', b'r', b'l', b'd']
    # ]
    words: list[list[bytes]] = [
        [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
    ]

    # Now, use our data to figure out which merges we should make
    while len(ranks) < vocab_size:
        # Find the most common pair. This will become our next token
        stats = collections.Counter()
        for piece in words:
            stats.update(zip(piece[:-1], piece[1:]))

        # Every word is down to a single token (or there was no data at all):
        # fail with an explicit message rather than max()'s "empty sequence".
        if not stats:
            raise ValueError(
                f"Not enough training data to reach vocab_size={vocab_size}: "
                f"no pairs left to merge after {len(ranks)} tokens"
            )

        most_common_pair = max(stats, key=lambda x: stats[x])
        token_bytes = most_common_pair[0] + most_common_pair[1]
        token = len(ranks)
        # Add the new token!
        ranks[token_bytes] = token

        # Now merge that most common pair in all the words. That is, update our training data
        # to reflect our decision to make that pair into a new token.
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == most_common_pair:
                    # We found our pair! Merge it
                    new_word.append(token_bytes)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)
        words = new_words

        # See the intermediate merges play out!
        if visualise:
            print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
            print(f"So we made {token_bytes} our {len(ranks)}th token")
            if visualise in ["colour", "color"]:
                print("Now the first fifty words in our training data look like:")
                visualise_tokens([token for word in words[:50] for token in word])
            elif visualise == "simple":
                print("Now the first twenty words in our training data look like:")
                for word in words[:20]:
                    print(word)
            print("\n")

    return ranks
def visualise_tokens(token_values: list[bytes]) -> None:
    """Print the tokens on one line, each on its own ANSI background colour.

    Colours cycle through a fixed palette of seven 256-colour backgrounds, and
    the line ends with the ANSI reset sequence.

    Token boundaries need not fall on UTF-8 character boundaries, so each
    token is decoded on its own with errors="replace". Bytes that are not a
    complete character within their token are shown as U+FFFD instead of
    raising UnicodeDecodeError.
    """
    backgrounds = itertools.cycle(
        f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]
    )
    # Decode per token, before joining: a colour code placed between two halves
    # of a multi-byte character would make the joined bytes invalid UTF-8.
    coloured = [
        background + token.decode("utf-8", errors="replace")
        for background, token in zip(backgrounds, token_values)
    ]
    print("".join(coloured) + "\u001b[0m")
def train_simple_encoding():
    """Train a 600-token encoder on this source file and sanity-check it.

    Returns:
        The trained SimpleBytePairEncoding.
    """
    gpt2_pattern = (
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )
    # Read as UTF-8 explicitly: the platform default encoding is not
    # necessarily UTF-8, and the training data should not depend on it.
    with open(__file__, "r", encoding="utf-8") as f:
        data = f.read()

    enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern)

    print("This is the sequence of merges performed in order to encode 'hello world':")
    tokens = enc.encode("hello world")
    assert enc.decode(tokens) == "hello world"
    assert enc.decode_bytes(tokens) == b"hello world"
    assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"]

    return enc
+25
View File
@@ -276,6 +276,31 @@ class Encoding:
"""
return [self.decode_single_token_bytes(token) for token in tokens]
def decode_with_offsets(self, tokens: list[int]) -> tuple[str, list[int]]:
    """Decodes a list of tokens into a string and a list of offsets.

    Each offset is the index into the text at which the corresponding token
    starts. When a token begins in the middle of a UTF-8 character (its first
    byte is a continuation byte), its offset is the index of that character,
    i.e. the first character containing bytes from the token.

    The joined bytes are decoded with errors="strict", so tokens that decode
    to invalid UTF-8 currently raise; this behaviour may become more
    permissive in the future.

    >>> enc.decode_with_offsets([31373, 995])
    ('hello world', [0, 5])
    """

    def is_continuation(byte: int) -> bool:
        # UTF-8 continuation bytes have the bit pattern 10xxxxxx.
        return 0x80 <= byte < 0xC0

    pieces = self.decode_tokens_bytes(tokens)

    offsets = []
    chars_so_far = 0
    for piece in pieces:
        # A piece that opens with a continuation byte belongs to the character
        # already counted, so step back by one (never below zero).
        start = chars_so_far - 1 if is_continuation(piece[0]) else chars_so_far
        offsets.append(max(0, start))
        # Every non-continuation byte begins a new character.
        chars_so_far += sum(not is_continuation(byte) for byte in piece)

    # TODO: assess correctness for errors="ignore" and errors="replace"
    text = b"".join(pieces).decode("utf-8", errors="strict")
    return text, offsets
def decode_batch(
self, batch: list[list[int]], *, errors: str = "replace", num_threads: int = 8
) -> list[str]:
+2 -1
View File
@@ -15,6 +15,7 @@ MODEL_PREFIX_TO_ENCODING: dict[str, str] = {
# chat
"gpt-4-": "cl100k_base", # e.g., gpt-4-0314, etc., plus gpt-4-32k
"gpt-3.5-turbo-": "cl100k_base", # e.g, gpt-3.5-turbo-0301, -0401, etc.
"gpt-35-turbo": "cl100k_base", # Azure deployment name
}
MODEL_TO_ENCODING: dict[str, str] = json.loads(pkg_resources.read_text("tiktoken", "model_to_encoding.json"))
@@ -36,7 +37,7 @@ def encoding_for_model(model_name: str) -> Encoding:
if encoding_name is None:
raise KeyError(
f"Could not automatically map {model_name} to a tokeniser. "
"Please use `tiktok.get_encoding` to explicitly get the tokeniser you expect."
"Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect."
) from None
return get_encoding(encoding_name)
+3 -3
View File
@@ -11,10 +11,10 @@ crate-type = ["cdylib"]
[dependencies]
tiktoken_core = { path = "../core", features = [] }
# tiktoken dependencies
fancy-regex = "0.10.0"
regex = "1.7.0"
fancy-regex = "0.11.0"
regex = "1.8.3"
rustc-hash = "1.1.0"
bstr = "1.0.1"
bstr = "1.5.0"
wasm-bindgen = "0.2.83"
anyhow = "1.0.69"
base64 = "0.21.0"