Merge pull request #1 from eisber/error_handling

Improve error handling in JNI functions
This commit is contained in:
Markus Cozowicz
2023-02-27 08:51:11 +01:00
committed by GitHub
+16 -19
View File
@@ -15,11 +15,11 @@ use jni::sys::{jarray, jlong};
use _tiktoken_core::{self, CoreBPENative};
use jni::errors::Error;
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T>, default: T) -> T {
// Check if an exception is already thrown
if env.exception_check().unwrap() {
if env.exception_check().expect("exception_check() failed") {
return default;
}
@@ -28,9 +28,9 @@ fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
Err(error) => {
let exception_class = env
.find_class("java/lang/Exception")
.unwrap();
.expect("Unable to find exception class");
env.throw_new(exception_class, format!("{}", error))
.unwrap();
.expect("Unable to throw exception");
default
}
}
@@ -38,7 +38,7 @@ fn unwrap_or_throw<T>(env: &JNIEnv, result: Result<T, Error>, default: T) -> T {
#[no_mangle]
pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, model_name: JString) {
let result = || -> Result<(), Error> {
let result = || -> Result<()> {
// First, we have to get the string out of Java. Check out the `strings`
// module for more info on how this works.
let model_name: String = env
@@ -46,27 +46,24 @@ pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, mo
.into();
let encoding_name = _tiktoken_core::openai_public::MODEL_TO_ENCODING
.get(&model_name)
.expect("Unable to find model");
.get(&model_name).ok_or("Unable to find model")?;
// TODO: this is actually mergeable_ranks (lazy)
let mut encoding = _tiktoken_core::openai_public::REGISTRY
.get(encoding_name)
.expect("Unable to find encoding");
let encoding = _tiktoken_core::openai_public::REGISTRY
.get(encoding_name).ok_or("Unable to find encoding")?;
// TODO: initialize the CoreBPE object
// TODO: this should be CoreBPE
let bpe_native = CoreBPENative::new(
encoding.get().unwrap(),
encoding.get()?,
encoding.special_tokens.clone(),
&encoding.pat_str,
)
.unwrap();
)?;
Ok(unsafe {
env.set_rust_field(obj, "handle", bpe_native).unwrap();
env.set_rust_field(obj, "handle", bpe_native)?;
})
}();
@@ -76,7 +73,7 @@ pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, mo
/// Exported under the JNI-style symbol name for `destroy` on `tiktoken.Encoding`.
///
/// Takes the `CoreBPENative` value stored in the "handle" field of `obj` out of the
/// object and binds it to `_`, so the value is dropped immediately.
///
/// # Panics
///
/// Panics if `take_rust_field` returns an error.
#[no_mangle]
pub extern "system" fn Java_tiktoken_Encoding_destroy(env: JNIEnv, obj: JObject) {
unsafe {
// NOTE(review): this listing is a rendered diff. The `.unwrap()` line below appears to be
// the pre-change version and the `.expect(...)` line its replacement; presumably only one
// of the two exists in the actual file — confirm against the repository.
let _: CoreBPENative = env.take_rust_field(obj, "handle").unwrap();
let _: CoreBPENative = env.take_rust_field(obj, "handle").expect("Unable to get handle during destruction");
}
}
@@ -88,7 +85,7 @@ pub extern "system" fn Java_tiktoken_Encoding_encode(
allowedSpecialTokens: jarray,
maxTokenLength: jlong,
) -> jarray {
let result = || -> Result<jarray, Error> {
let result = || -> Result<jarray> {
let encoding: MutexGuard<CoreBPENative> = unsafe { env.get_rust_field(obj, "handle")? };
let enc = encoding;
@@ -109,8 +106,8 @@ pub extern "system" fn Java_tiktoken_Encoding_encode(
let (tokens, _, _) = enc._encode_native(&input, &v2, Some(maxTokenLength as usize));
let mut output = env
.new_long_array(tokens.len().try_into().unwrap())?;
let output = env
.new_long_array(tokens.len().try_into()?)?;
let array_of_u64 = tokens.iter().map(|x| *x as i64).collect::<Vec<i64>>();
env.set_long_array_region(output, 0, array_of_u64.as_slice())?;