mirror of
https://github.com/Mintplex-Labs/tiktoken.git
synced 2026-07-01 18:48:04 -04:00
Merge pull request #1 from eisber/error_handling
Improve error handling in JNI functions
This commit is contained in:
+16
-19
@@ -15,11 +15,11 @@ use jni::sys::{jarray, jlong};
|
||||
|
||||
use _tiktoken_core::{self, CoreBPENative};
|
||||
|
||||
use jni::errors::Error;
|
||||
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
|
||||
|
||||
fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
|
||||
fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T>, default: T) -> T {
|
||||
// Check if an exception is already thrown
|
||||
if env.exception_check().unwrap() {
|
||||
if env.exception_check().expect("exception_check() failed") {
|
||||
return default;
|
||||
}
|
||||
|
||||
@@ -28,9 +28,9 @@ fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
|
||||
Err(error) => {
|
||||
let exception_class = env
|
||||
.find_class("java/lang/Exception")
|
||||
.unwrap();
|
||||
.expect("Unable to find exception class");
|
||||
env.throw_new(exception_class, format!("{}", error))
|
||||
.unwrap();
|
||||
.expect("Unable to throw exception");
|
||||
default
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,7 @@ fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, model_name: JString) {
|
||||
let result = || -> Result<(), Error> {
|
||||
let result = || -> Result<()> {
|
||||
// First, we have to get the string out of Java. Check out the `strings`
|
||||
// module for more info on how this works.
|
||||
let model_name: String = env
|
||||
@@ -46,27 +46,24 @@ pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, mo
|
||||
.into();
|
||||
|
||||
let encoding_name = _tiktoken_core::openai_public::MODEL_TO_ENCODING
|
||||
.get(&model_name)
|
||||
.expect("Unable to find model");
|
||||
.get(&model_name).ok_or("Unable to find model")?;
|
||||
|
||||
// TODO: this is actually mergeable_ranks (lazy)
|
||||
let mut encoding = _tiktoken_core::openai_public::REGISTRY
|
||||
.get(encoding_name)
|
||||
.expect("Unable to find encoding");
|
||||
let encoding = _tiktoken_core::openai_public::REGISTRY
|
||||
.get(encoding_name).ok_or("Unable to find encoding")?;
|
||||
|
||||
// TODO: initialize the CoreBPE object
|
||||
|
||||
// TODO: this should be CoreBPE
|
||||
|
||||
let bpe_native = CoreBPENative::new(
|
||||
encoding.get().unwrap(),
|
||||
encoding.get()?,
|
||||
encoding.special_tokens.clone(),
|
||||
&encoding.pat_str,
|
||||
)
|
||||
.unwrap();
|
||||
)?;
|
||||
|
||||
Ok(unsafe {
|
||||
env.set_rust_field(obj, "handle", bpe_native).unwrap();
|
||||
env.set_rust_field(obj, "handle", bpe_native)?;
|
||||
})
|
||||
}();
|
||||
|
||||
@@ -76,7 +73,7 @@ pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, mo
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_tiktoken_Encoding_destroy(env: JNIEnv, obj: JObject) {
|
||||
unsafe {
|
||||
let _: CoreBPENative = env.take_rust_field(obj, "handle").unwrap();
|
||||
let _: CoreBPENative = env.take_rust_field(obj, "handle").expect("Unable to get handle during destruction");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +85,7 @@ pub extern "system" fn Java_tiktoken_Encoding_encode(
|
||||
allowedSpecialTokens: jarray,
|
||||
maxTokenLength: jlong,
|
||||
) -> jarray {
|
||||
let result = || -> Result<jarray, Error> {
|
||||
let result = || -> Result<jarray> {
|
||||
let encoding: MutexGuard<CoreBPENative> = unsafe { env.get_rust_field(obj, "handle")? };
|
||||
|
||||
let enc = encoding;
|
||||
@@ -109,8 +106,8 @@ pub extern "system" fn Java_tiktoken_Encoding_encode(
|
||||
|
||||
let (tokens, _, _) = enc._encode_native(&input, &v2, Some(maxTokenLength as usize));
|
||||
|
||||
let mut output = env
|
||||
.new_long_array(tokens.len().try_into().unwrap())?;
|
||||
let output = env
|
||||
.new_long_array(tokens.len().try_into()?)?;
|
||||
|
||||
let array_of_u64 = tokens.iter().map(|x| *x as i64).collect::<Vec<i64>>();
|
||||
env.set_long_array_region(output, 0, array_of_u64.as_slice())?;
|
||||
|
||||
Reference in New Issue
Block a user