From e3ab3f643e79ef560823e4161f380926fd6edf4c Mon Sep 17 00:00:00 2001 From: Markus Cozowicz Date: Mon, 27 Feb 2023 12:17:37 +0100 Subject: [PATCH] moved config into json --- MANIFEST.in | 1 + core/src/openai_public.rs | 137 +++++------------- java/pom.xml | 43 ++++-- java/src/main/java/tiktoken/Encoding.java | 15 +- ...{EncodingTest.java => EncodingTestIT.java} | 5 +- java/src/tiktoken_Encoding.h | 29 ++++ jni/Cargo.toml | 5 + setup.py | 5 +- tests/test_simple_public.py | 15 -- tiktoken/load.py | 4 - tiktoken/model.py | 47 +----- tiktoken/model_to_encoding.json | 32 ++++ tiktoken/registry.json | 50 +++++++ tiktoken/registry.py | 72 ++++----- tiktoken_ext/openai_public.py | 87 ----------- 15 files changed, 240 insertions(+), 307 deletions(-) rename java/src/test/java/tiktoken/{EncodingTest.java => EncodingTestIT.java} (79%) create mode 100644 java/src/tiktoken_Encoding.h create mode 100644 tiktoken/model_to_encoding.json create mode 100644 tiktoken/registry.json delete mode 100644 tiktoken_ext/openai_public.py diff --git a/MANIFEST.in b/MANIFEST.in index 7f25b27..321b66e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,3 +6,4 @@ global-include py.typed recursive-include scripts *.py recursive-include tests *.py recursive-include src *.rs +include tiktoken *.json \ No newline at end of file diff --git a/core/src/openai_public.rs b/core/src/openai_public.rs index 17ff7e2..2a89843 100644 --- a/core/src/openai_public.rs +++ b/core/src/openai_public.rs @@ -2,6 +2,7 @@ use rustc_hash::FxHashMap as HashMap; use std::error::Error; use std::sync::RwLock; +use json; #[path = "load.rs"] mod load; @@ -9,105 +10,47 @@ mod load; type Result = std::result::Result>; lazy_static! { - pub static ref REGISTRY: HashMap = [ - EncodingLazy::new( - "gpt2".into(), - Some(50257), - r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+".into(), - [ ("<|endoftext|>".into(), 50256), ].into_iter().collect(), - EncoderLoadingStrategy::DataGym( - DataGymDef { - vocab_bpe_file: "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe".into(), - encoder_json_file: "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json".into() - } - )), - EncodingLazy::new( - "r50k_base".into(), - Some(50257), - r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+".into(), - [ ("<|endoftext|>".into(), 50256), ].into_iter().collect(), - EncoderLoadingStrategy::BPE("https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken".into()) - ), - EncodingLazy::new( - "p50k_base".into(), - Some(50281), - r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+".into(), - [ ("<|endoftext|>".into(), 50256), ].into_iter().collect(), - EncoderLoadingStrategy::BPE("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken".into()) - ), - EncodingLazy::new( - "p50k_edit".into(), - Some(50281), - r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+".into(), - [ - ("<|endoftext|>".into(), 50256), - ("<|fim_prefix|>".into(), 50281), - ("<|fim_middle|>".into(), 50282), - ("<|fim_suffix|>".into(), 50283), - ].into_iter().collect(), - EncoderLoadingStrategy::BPE("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken".into()) - ), - EncodingLazy::new( - "cl100k_base".into(), - None, - r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+".into(), - [ - ("<|endoftext|>".into(), 100257), - ("<|fim_prefix|>".into(), 100258), - ("<|fim_middle|>".into(), 100259), - ("<|fim_suffix|>".into(), 100260), - ("<|endofprompt|>".into(), 100276), - ].into_iter().collect(), - EncoderLoadingStrategy::BPE("https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken".into()) - ), - ] - .into_iter() + pub static ref REGISTRY: HashMap = { + // TODO: error handling + json::parse(include_str!("../../tiktoken/registry.json")) + .expect("Failed to parse internal JSON") + .entries() + .map(|(key, value)| { + let loading_strategy = if value.has_key("data_gym_to_mergeable_bpe_ranks") { + EncoderLoadingStrategy::DataGym( + DataGymDef { + vocab_bpe_file: value["data_gym_to_mergeable_bpe_ranks"]["vocab_bpe_file"].as_str().expect("error").into(), + encoder_json_file: value["data_gym_to_mergeable_bpe_ranks"]["encoder_json_file"].as_str().expect("error").into() + }) + } + else if value.has_key("load_tiktoken_bpe") { + EncoderLoadingStrategy::BPE(value["load_tiktoken_bpe"].as_str().expect("fail").into()) + } + else { + panic!("Invalid encoding"); + }; + + EncodingLazy::new( + key.into(), + value["explicit_n_vocab"].as_usize(), + value["pat_str"].as_str().expect("foo").into(), + value["special_tokens"].entries() + .map(|(key, value)| (key.into(), value.as_usize().expect("foo"))) + .collect::>(), + loading_strategy + ) + }) + .map(|enc| (enc.name.clone(), enc)) - .collect::>(); + .collect::>() + }; - - - pub static ref MODEL_TO_ENCODING: HashMap = [ - // text - ("text-davinci-003", "p50k_base"), - ("text-davinci-002", "p50k_base"), - ("text-davinci-001", "r50k_base"), - ("text-curie-001", "r50k_base"), - ("text-babbage-001", "r50k_base"), - ("text-ada-001", "r50k_base"), - ("davinci", "r50k_base"), - ("curie", "r50k_base"), - ("babbage", "r50k_base"), - ("ada", "r50k_base"), - // code - ("code-davinci-002", "p50k_base"), - ("code-davinci-001", "p50k_base"), - ("code-cushman-002", "p50k_base"), - ("code-cushman-001", "p50k_base"), - ("davinci-codex", "p50k_base"), - ("cushman-codex", "p50k_base"), - // edit - ("text-davinci-edit-001", "p50k_edit"), - ("code-davinci-edit-001", "p50k_edit"), - // embeddings - ("text-embedding-ada-002", "cl100k_base"), - // old embeddings - ("text-similarity-davinci-001", "r50k_base"), - ("text-similarity-curie-001", "r50k_base"), - ("text-similarity-babbage-001", "r50k_base"), - ("text-similarity-ada-001", "r50k_base"), - ("text-search-davinci-doc-001", "r50k_base"), - ("text-search-curie-doc-001", "r50k_base"), - ("text-search-babbage-doc-001", "r50k_base"), - ("text-search-ada-doc-001", "r50k_base"), - ("code-search-babbage-code-001", "r50k_base"), - ("code-search-ada-code-001", "r50k_base"), - // open source - ("gpt2", "gpt2"), - ] - .into_iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect::>(); + pub static ref MODEL_TO_ENCODING: HashMap = + json::parse(include_str!("../../tiktoken/model_to_encoding.json")) + .expect("Failed to parse internal JSON") + .entries() + .map(|(k, v)| (k.into(), v.as_str().expect("foo").into())) + .collect::>(); } #[derive(Clone, PartialEq, Eq, Hash)] diff --git a/java/pom.xml b/java/pom.xml index c3a3288..b93feb9 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -10,6 +10,7 @@ tiktoken https://github.com/openai/tiktoken + jar UTF-8 @@ -24,9 +25,23 @@ 4.11 test + + org.scijava + native-lib-loader + 2.4.0 + + + + ${project.build.directory}/../../target/release/ + ${project.build.directory}/classes/natives/linux_64 + + lib_tiktoken_jni.so + + + @@ -69,22 +84,18 @@ 3.0.0 - org.apache.maven.plugins - maven-surefire-plugin - 2.17 - - - surefire-test - test - - test - - - - - -Djava.library.path=${project.build.directory}/../../target/debug/ - - + org.apache.maven.plugins + maven-failsafe-plugin + 2.22.1 + + + + integration-test + verify + + + + diff --git a/java/src/main/java/tiktoken/Encoding.java b/java/src/main/java/tiktoken/Encoding.java index cc041e4..bfc7586 100644 --- a/java/src/main/java/tiktoken/Encoding.java +++ b/java/src/main/java/tiktoken/Encoding.java @@ -1,9 +1,19 @@ package tiktoken; +import org.scijava.nativelib.NativeLoader; +import java.io.IOException; + public class Encoding implements AutoCloseable { static { - System.loadLibrary("_tiktoken_jni"); + // TODO: unpack the library from the jar + // System.loadLibrary("_tiktoken_jni"); + try { + NativeLoader.loadLibrary("_tiktoken_jni"); + } + catch(IOException e) { + throw new RuntimeException(e); + } } // initialized by init @@ -11,10 +21,9 @@ public class Encoding implements AutoCloseable private native void init(String modelName); - public native long[] encode(String text, String[] allowedSpecialTokens, long maxTokenLength); - private native void destroy(); + public native long[] encode(String text, String[] allowedSpecialTokens, long maxTokenLength); public Encoding(String modelName) { this.init(modelName); diff --git a/java/src/test/java/tiktoken/EncodingTest.java b/java/src/test/java/tiktoken/EncodingTestIT.java similarity index 79% rename from java/src/test/java/tiktoken/EncodingTest.java rename to java/src/test/java/tiktoken/EncodingTestIT.java index 591c526..602a1ef 100644 --- a/java/src/test/java/tiktoken/EncodingTest.java +++ b/java/src/test/java/tiktoken/EncodingTestIT.java @@ -1,11 +1,11 @@ package tiktoken; import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertTrue; import org.junit.Test; -public class EncodingTest +// run test: mvn failsafe:integration-test +public class EncodingTestIT { @Test public void shouldAnswerWithTrue() throws Exception @@ -16,7 +16,6 @@ public class EncodingTest encoding.close(); - assertTrue( true ); assertArrayEquals(new long[] {9288}, a); } } diff --git a/java/src/tiktoken_Encoding.h b/java/src/tiktoken_Encoding.h new file mode 100644 index 0000000..030d77f --- /dev/null +++ b/java/src/tiktoken_Encoding.h @@ -0,0 +1,29 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class tiktoken_Encoding */ + +#ifndef _Included_tiktoken_Encoding +#define _Included_tiktoken_Encoding +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: tiktoken_Encoding + * Method: init + * Signature: (Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_tiktoken_Encoding_init + (JNIEnv *, jobject, jstring); + +/* + * Class: tiktoken_Encoding + * Method: encode + * Signature: (Ljava/lang/String;[Ljava/lang/String;J)[J + */ +JNIEXPORT jlongArray JNICALL Java_tiktoken_Encoding_encode + (JNIEnv *, jobject, jstring, jobjectArray, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/jni/Cargo.toml b/jni/Cargo.toml index 0a7651d..9ac05ad 100644 --- a/jni/Cargo.toml +++ b/jni/Cargo.toml @@ -15,3 +15,8 @@ jni = "0.20.0" [profile.release] incremental = true +opt-level = 'z' # Optimize for size +lto = true # Enable link-time optimization +codegen-units = 1 # Reduce number of codegen units to increase optimizations +panic = 'abort' # Abort on panic +strip = true # Strip symbols from binary* \ No newline at end of file diff --git a/setup.py b/setup.py index 96ad5d6..246487b 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,8 @@ setup( debug=False, ) ], - package_data={"tiktoken": ["py.typed"]}, - packages=["tiktoken", "tiktoken_ext"], + include_package_data=True, + package_data={ "tiktoken": ["py.typed", "registry.json", "model_to_encoding.json"] }, + packages=["tiktoken"], zip_safe=False, ) diff --git a/tests/test_simple_public.py b/tests/test_simple_public.py index ab63bab..4410923 100644 --- a/tests/test_simple_public.py +++ b/tests/test_simple_public.py @@ -24,18 +24,3 @@ def test_encoding_for_model(): assert enc.name == "gpt2" enc = tiktoken.encoding_for_model("text-davinci-003") assert enc.name == "p50k_base" - -def test_loading(): - x = tiktoken.load.data_gym_to_mergeable_bpe_ranks( - vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe", - encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json", - ) - - print(len(x)) - - y = tiktoken._tiktoken.py_data_gym_to_mergable_bpe_ranks( - vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe", - encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json", - ) - - print(len(y)) \ No newline at end of file diff --git a/tiktoken/load.py b/tiktoken/load.py index c8f3dbd..5537ecf 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -55,7 +55,6 @@ def data_gym_to_mergeable_bpe_ranks( # NB: do not add caching to this function rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "] - print(f"rank_to_intbyte: {len(rank_to_intbyte)}") data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte} n = 0 for b in range(2**8): @@ -75,9 +74,6 @@ def data_gym_to_mergeable_bpe_ranks( # add the single byte tokens bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)} - # print(len(rank_to_intbyte)) - print(f"py data gym: {len(data_gym_byte_to_byte)} '{data_gym_byte_to_byte[chr(288)]}'") - # add the merged tokens n = len(bpe_ranks) for first, second in bpe_merges: diff --git a/tiktoken/model.py b/tiktoken/model.py index 66e9e04..b3d3ba5 100644 --- a/tiktoken/model.py +++ b/tiktoken/model.py @@ -2,47 +2,16 @@ from __future__ import annotations from .core import Encoding from .registry import get_encoding +import json + +try: + import importlib.resources as pkg_resources +except ImportError: + # Try backported to PY<37 `importlib_resources`. + import importlib_resources as pkg_resources # TODO: this will likely be replaced by an API endpoint -MODEL_TO_ENCODING: dict[str, str] = { - # text - "text-davinci-003": "p50k_base", - "text-davinci-002": "p50k_base", - "text-davinci-001": "r50k_base", - "text-curie-001": "r50k_base", - "text-babbage-001": "r50k_base", - "text-ada-001": "r50k_base", - "davinci": "r50k_base", - "curie": "r50k_base", - "babbage": "r50k_base", - "ada": "r50k_base", - # code - "code-davinci-002": "p50k_base", - "code-davinci-001": "p50k_base", - "code-cushman-002": "p50k_base", - "code-cushman-001": "p50k_base", - "davinci-codex": "p50k_base", - "cushman-codex": "p50k_base", - # edit - "text-davinci-edit-001": "p50k_edit", - "code-davinci-edit-001": "p50k_edit", - # embeddings - "text-embedding-ada-002": "cl100k_base", - # old embeddings - "text-similarity-davinci-001": "r50k_base", - "text-similarity-curie-001": "r50k_base", - "text-similarity-babbage-001": "r50k_base", - "text-similarity-ada-001": "r50k_base", - "text-search-davinci-doc-001": "r50k_base", - "text-search-curie-doc-001": "r50k_base", - "text-search-babbage-doc-001": "r50k_base", - "text-search-ada-doc-001": "r50k_base", - "code-search-babbage-code-001": "r50k_base", - "code-search-ada-code-001": "r50k_base", - # open source - "gpt2": "gpt2", -} - +MODEL_TO_ENCODING: dict[str, str] = json.loads(pkg_resources.read_text("tiktoken", "model_to_encoding.json")) def encoding_for_model(model_name: str) -> Encoding: try: diff --git a/tiktoken/model_to_encoding.json b/tiktoken/model_to_encoding.json new file mode 100644 index 0000000..987ba14 --- /dev/null +++ b/tiktoken/model_to_encoding.json @@ -0,0 +1,32 @@ +{ + "text-davinci-003": "p50k_base", + "text-davinci-002": "p50k_base", + "text-davinci-001": "r50k_base", + "text-curie-001": "r50k_base", + "text-babbage-001": "r50k_base", + "text-ada-001": "r50k_base", + "davinci": "r50k_base", + "curie": "r50k_base", + "babbage": "r50k_base", + "ada": "r50k_base", + "code-davinci-002": "p50k_base", + "code-davinci-001": "p50k_base", + "code-cushman-002": "p50k_base", + "code-cushman-001": "p50k_base", + "davinci-codex": "p50k_base", + "cushman-codex": "p50k_base", + "text-davinci-edit-001": "p50k_edit", + "code-davinci-edit-001": "p50k_edit", + "text-embedding-ada-002": "cl100k_base", + "text-similarity-davinci-001": "r50k_base", + "text-similarity-curie-001": "r50k_base", + "text-similarity-babbage-001": "r50k_base", + "text-similarity-ada-001": "r50k_base", + "text-search-davinci-doc-001": "r50k_base", + "text-search-curie-doc-001": "r50k_base", + "text-search-babbage-doc-001": "r50k_base", + "text-search-ada-doc-001": "r50k_base", + "code-search-babbage-code-001": "r50k_base", + "code-search-ada-code-001": "r50k_base", + "gpt2": "gpt2" +} \ No newline at end of file diff --git a/tiktoken/registry.json b/tiktoken/registry.json new file mode 100644 index 0000000..aa3ee53 --- /dev/null +++ b/tiktoken/registry.json @@ -0,0 +1,50 @@ +{ + "gpt2": { + "data_gym_to_mergeable_bpe_ranks": { + "vocab_bpe_file": "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe", + "encoder_json_file": "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json" + }, + "explicit_n_vocab": 50257, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + "special_tokens": { + "<|endoftext|>": 50256 + } + }, + "r50k_base": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken", + "explicit_n_vocab": 50257, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + "special_tokens": { + "<|endoftext|>": 50256 + } + }, + "p50k_base": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken", + "explicit_n_vocab": 50281, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + "special_tokens": { + "<|endoftext|>": 50256 + } + }, + "p50k_edit": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken", + "special_tokens": { + "<|endoftext|>": 50256, + "<|fim_prefix|>": 50281, + "<|fim_middle|>": 50282, + "<|fim_suffix|>": 50283 + }, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+" + }, + "cl100k_base": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken", + "special_tokens": { + "<|endoftext|>": 100257, + "<|fim_prefix|>": 100258, + "<|fim_middle|>": 100259, + "<|fim_suffix|>": 100260, + "<|endofprompt|>": 100276 + }, + "pat_str": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + } +} \ No newline at end of file diff --git a/tiktoken/registry.py b/tiktoken/registry.py index 52d8ec2..0a55d27 100644 --- a/tiktoken/registry.py +++ b/tiktoken/registry.py @@ -3,46 +3,32 @@ from __future__ import annotations import importlib import pkgutil import threading +import json from typing import Any, Callable, Optional -import tiktoken_ext - from tiktoken.core import Encoding +from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe _lock = threading.RLock() ENCODINGS: dict[str, Encoding] = {} -ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None +ENCODING_DEFS: dict[str, Any] = None +def _load_encoding_defs(): + global ENCODING_DEFS + if not ENCODING_DEFS is None: + return ENCODING_DEFS -def _find_constructors() -> None: - global ENCODING_CONSTRUCTORS - with _lock: - if ENCODING_CONSTRUCTORS is not None: - return - ENCODING_CONSTRUCTORS = {} + try: + import importlib.resources as pkg_resources + except ImportError: + # Try backported to PY<37 `importlib_resources`. + import importlib_resources as pkg_resources - # tiktoken_ext is a namespace package - # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes - # - we use namespace package pattern so `pkgutil.iter_modules` is fast - # - it's a separate top-level package because namespace subpackages of non-namespace - # packages don't quite do what you want with editable installs - plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".") - - for _, mod_name, _ in plugin_mods: - mod = importlib.import_module(mod_name) - try: - constructors = mod.ENCODING_CONSTRUCTORS - except AttributeError as e: - raise ValueError( - f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS" - ) from e - for enc_name, constructor in constructors.items(): - if enc_name in ENCODING_CONSTRUCTORS: - raise ValueError( - f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}" - ) - ENCODING_CONSTRUCTORS[enc_name] = constructor + # read registry.json + # note: was trying to place it into /data/registry.json but python packaging is always unhappy + ENCODING_DEFS = json.loads(pkg_resources.read_text("tiktoken", "registry.json")) + return ENCODING_DEFS def get_encoding(encoding_name: str) -> Encoding: if encoding_name in ENCODINGS: @@ -52,22 +38,26 @@ def get_encoding(encoding_name: str) -> Encoding: if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] - if ENCODING_CONSTRUCTORS is None: - _find_constructors() - assert ENCODING_CONSTRUCTORS is not None - - if encoding_name not in ENCODING_CONSTRUCTORS: + _load_encoding_defs() + if encoding_name not in ENCODING_DEFS: raise ValueError(f"Unknown encoding {encoding_name}") - constructor = ENCODING_CONSTRUCTORS[encoding_name] - enc = Encoding(**constructor()) + encoding_def = dict(ENCODING_DEFS[encoding_name]) + encoding_def["name"] = encoding_name + + if "load_tiktoken_bpe" in encoding_def: + encoding_def["mergeable_ranks"] = load_tiktoken_bpe(encoding_def["load_tiktoken_bpe"]) + del encoding_def["load_tiktoken_bpe"] + elif "data_gym_to_mergeable_bpe_ranks" in encoding_def: + encoding_def["mergeable_ranks"] = data_gym_to_mergeable_bpe_ranks(**encoding_def["data_gym_to_mergeable_bpe_ranks"]) + del encoding_def["data_gym_to_mergeable_bpe_ranks"] + else: + raise ValueError(f"Unknown loader {encoding_name}") + enc = Encoding(**encoding_def) ENCODINGS[encoding_name] = enc return enc def list_encoding_names() -> list[str]: with _lock: - if ENCODING_CONSTRUCTORS is None: - _find_constructors() - assert ENCODING_CONSTRUCTORS is not None - return list(ENCODING_CONSTRUCTORS) + return list(_load_encoding_defs().keys()) diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py deleted file mode 100644 index a64db9f..0000000 --- a/tiktoken_ext/openai_public.py +++ /dev/null @@ -1,87 +0,0 @@ -from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe - -ENDOFTEXT = "<|endoftext|>" -FIM_PREFIX = "<|fim_prefix|>" -FIM_MIDDLE = "<|fim_middle|>" -FIM_SUFFIX = "<|fim_suffix|>" -ENDOFPROMPT = "<|endofprompt|>" - - -def gpt2(): - mergeable_ranks = data_gym_to_mergeable_bpe_ranks( - vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe", - encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json", - ) - return { - "name": "gpt2", - "explicit_n_vocab": 50257, - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": {"<|endoftext|>": 50256}, - } - - -def r50k_base(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken" - ) - return { - "name": "r50k_base", - "explicit_n_vocab": 50257, - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": {ENDOFTEXT: 50256}, - } - - -def p50k_base(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken" - ) - return { - "name": "p50k_base", - "explicit_n_vocab": 50281, - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": {ENDOFTEXT: 50256}, - } - - -def p50k_edit(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken" - ) - special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283} - return { - "name": "p50k_edit", - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": special_tokens, - } - - -def cl100k_base(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken" - ) - special_tokens = { - ENDOFTEXT: 100257, - FIM_PREFIX: 100258, - FIM_MIDDLE: 100259, - FIM_SUFFIX: 100260, - ENDOFPROMPT: 100276, - } - return { - "name": "cl100k_base", - "pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": special_tokens, - } - - -ENCODING_CONSTRUCTORS = { - "gpt2": gpt2, - "r50k_base": r50k_base, - "p50k_base": p50k_base, - "cl100k_base": cl100k_base, -}