mirror of
https://github.com/langchain-ai/arrow-rs.git
synced 2026-07-01 21:34:01 -04:00
[Variant] Introduce new type over &str for ShortString (#7718)
# Which issue does this PR close? - Closes https://github.com/apache/arrow-rs/issues/7700 This commit introduces `ShortString`, a newtype wrapper around `&str` that enforces a maximum length constraint. This also allows us to perform validation once and removes a superfluous validation check in `append_value`. The now-superfluous validation check was needed because users could construct `Variant::ShortString`s directly, without doing input validation. This meant you could have a short string variant that actually contained a string longer than 63 bytes. But since we now enforce this check upon construction, we can directly match against the `Variant::String` and `Variant::ShortString` arms with their respective appending functions (`append_string` and `append_short_string`).
This commit is contained in:
@@ -15,11 +15,10 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
use crate::decoder::{VariantBasicType, VariantPrimitiveType};
|
||||
use crate::Variant;
|
||||
use crate::{ShortString, Variant};
|
||||
use std::collections::HashMap;
|
||||
|
||||
const BASIC_TYPE_BITS: u8 = 2;
|
||||
const MAX_SHORT_STRING_SIZE: usize = 0x3F;
|
||||
const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
|
||||
|
||||
fn primitive_header(primitive_type: VariantPrimitiveType) -> u8 {
|
||||
@@ -114,11 +113,11 @@ fn make_room_for_header(buffer: &mut Vec<u8>, start_pos: usize, header_size: usi
|
||||
/// };
|
||||
/// assert_eq!(
|
||||
/// variant_object.field_by_name("first_name").unwrap(),
|
||||
/// Some(Variant::ShortString("Jiaying"))
|
||||
/// Some(Variant::from("Jiaying"))
|
||||
/// );
|
||||
/// assert_eq!(
|
||||
/// variant_object.field_by_name("last_name").unwrap(),
|
||||
/// Some(Variant::ShortString("Li"))
|
||||
/// Some(Variant::from("Li"))
|
||||
/// );
|
||||
/// ```
|
||||
///
|
||||
@@ -281,17 +280,18 @@ impl VariantBuilder {
|
||||
self.buffer.extend_from_slice(value);
|
||||
}
|
||||
|
||||
fn append_short_string(&mut self, value: ShortString) {
|
||||
let inner = value.0;
|
||||
self.buffer.push(short_string_header(inner.len()));
|
||||
self.buffer.extend_from_slice(inner.as_bytes());
|
||||
}
|
||||
|
||||
fn append_string(&mut self, value: &str) {
|
||||
if value.len() <= MAX_SHORT_STRING_SIZE {
|
||||
self.buffer.push(short_string_header(value.len()));
|
||||
self.buffer.extend_from_slice(value.as_bytes());
|
||||
} else {
|
||||
self.buffer
|
||||
.push(primitive_header(VariantPrimitiveType::String));
|
||||
self.buffer
|
||||
.extend_from_slice(&(value.len() as u32).to_le_bytes());
|
||||
self.buffer.extend_from_slice(value.as_bytes());
|
||||
}
|
||||
self.buffer
|
||||
.push(primitive_header(VariantPrimitiveType::String));
|
||||
self.buffer
|
||||
.extend_from_slice(&(value.len() as u32).to_le_bytes());
|
||||
self.buffer.extend_from_slice(value.as_bytes());
|
||||
}
|
||||
|
||||
/// Add key to dictionary, return its ID
|
||||
@@ -390,7 +390,8 @@ impl VariantBuilder {
|
||||
Variant::Float(v) => self.append_float(v),
|
||||
Variant::Double(v) => self.append_double(v),
|
||||
Variant::Binary(v) => self.append_binary(v),
|
||||
Variant::String(s) | Variant::ShortString(s) => self.append_string(s),
|
||||
Variant::String(s) => self.append_string(s),
|
||||
Variant::ShortString(s) => self.append_short_string(s),
|
||||
Variant::Object(_) | Variant::List(_) => {
|
||||
unreachable!("Object and List variants cannot be created through Into<Variant>")
|
||||
}
|
||||
@@ -639,7 +640,7 @@ mod tests {
|
||||
builder.append_value("hello");
|
||||
let (metadata, value) = builder.finish();
|
||||
let variant = Variant::try_new(&metadata, &value).unwrap();
|
||||
assert_eq!(variant, Variant::ShortString("hello"));
|
||||
assert_eq!(variant, Variant::ShortString(ShortString("hello")));
|
||||
}
|
||||
|
||||
{
|
||||
@@ -688,7 +689,7 @@ mod tests {
|
||||
assert_eq!(val1, Variant::Int8(2));
|
||||
|
||||
let val2 = list.get(2).unwrap();
|
||||
assert_eq!(val2, Variant::ShortString("test"));
|
||||
assert_eq!(val2, Variant::ShortString(ShortString("test")));
|
||||
}
|
||||
_ => panic!("Expected an array variant, got: {:?}", variant),
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
use crate::utils::{array_from_slice, slice_from_slice, string_from_slice};
|
||||
use crate::ShortString;
|
||||
|
||||
use arrow_schema::ArrowError;
|
||||
use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, Utc};
|
||||
@@ -273,10 +274,10 @@ pub(crate) fn decode_long_string(data: &[u8]) -> Result<&str, ArrowError> {
|
||||
}
|
||||
|
||||
/// Decodes a short string from the value section of a variant.
|
||||
pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<&str, ArrowError> {
|
||||
pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<ShortString, ArrowError> {
|
||||
let len = (metadata >> 2) as usize;
|
||||
let string = string_from_slice(data, 0..len)?;
|
||||
Ok(string)
|
||||
ShortString::try_new(string)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -420,7 +421,7 @@ mod tests {
|
||||
fn test_short_string() -> Result<(), ArrowError> {
|
||||
let data = [b'H', b'e', b'l', b'l', b'o', b'o'];
|
||||
let result = decode_short_string(1 | 5 << 2, &data)?;
|
||||
assert_eq!(result, "Hello");
|
||||
assert_eq!(result.0, "Hello");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::ops::Deref;
|
||||
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
@@ -29,6 +31,65 @@ mod list;
|
||||
mod metadata;
|
||||
mod object;
|
||||
|
||||
const MAX_SHORT_STRING_BYTES: usize = 0x3F;
|
||||
|
||||
/// A Variant [`ShortString`]
|
||||
///
|
||||
/// This implementation is a zero cost wrapper over `&str` that ensures
|
||||
/// the length of the underlying string is a valid Variant short string (63 bytes or less)
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct ShortString<'a>(pub(crate) &'a str);
|
||||
|
||||
impl<'a> ShortString<'a> {
|
||||
/// Attempts to interpret `value` as a variant short string value.
|
||||
///
|
||||
/// # Validation
|
||||
///
|
||||
/// This constructor verifies that `value` is shorter than or equal to `MAX_SHORT_STRING_BYTES`
|
||||
pub fn try_new(value: &'a str) -> Result<Self, ArrowError> {
|
||||
if value.len() > MAX_SHORT_STRING_BYTES {
|
||||
return Err(ArrowError::InvalidArgumentError(format!(
|
||||
"value is larger than {MAX_SHORT_STRING_BYTES} bytes"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self(value))
|
||||
}
|
||||
|
||||
/// Returns the underlying Variant short string as a &str
|
||||
pub fn as_str(&self) -> &'a str {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<ShortString<'a>> for &'a str {
|
||||
fn from(value: ShortString<'a>) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a str> for ShortString<'a> {
|
||||
type Error = ArrowError;
|
||||
|
||||
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
|
||||
Self::try_new(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsRef<str> for ShortString<'a> {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for ShortString<'a> {
|
||||
type Target = str;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a [Parquet Variant]
|
||||
///
|
||||
/// The lifetimes `'m` and `'v` are for metadata and value buffers, respectively.
|
||||
@@ -85,7 +146,7 @@ mod object;
|
||||
///
|
||||
/// ## Creating `Variant` from Rust Types
|
||||
/// ```
|
||||
/// # use parquet_variant::Variant;
|
||||
/// use parquet_variant::Variant;
|
||||
/// // variants can be directly constructed
|
||||
/// let variant = Variant::Int32(123);
|
||||
/// // or constructed via `From` impls
|
||||
@@ -98,7 +159,7 @@ mod object;
|
||||
/// let value = [0x09, 0x48, 0x49];
|
||||
/// // parse the header metadata
|
||||
/// assert_eq!(
|
||||
/// Variant::ShortString("HI"),
|
||||
/// Variant::from("HI"),
|
||||
/// Variant::try_new(&metadata, &value).unwrap()
|
||||
/// );
|
||||
/// ```
|
||||
@@ -152,7 +213,7 @@ pub enum Variant<'m, 'v> {
|
||||
/// Primitive (type_id=1): STRING
|
||||
String(&'v str),
|
||||
/// Short String (type_id=2): STRING
|
||||
ShortString(&'v str),
|
||||
ShortString(ShortString<'v>),
|
||||
// need both metadata & value
|
||||
/// Object (type_id=3): N/A
|
||||
Object(VariantObject<'m, 'v>),
|
||||
@@ -165,12 +226,12 @@ impl<'m, 'v> Variant<'m, 'v> {
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use parquet_variant::{Variant, VariantMetadata};
|
||||
/// use parquet_variant::{Variant, VariantMetadata};
|
||||
/// let metadata = [0x01, 0x00, 0x00];
|
||||
/// let value = [0x09, 0x48, 0x49];
|
||||
/// // parse the header metadata
|
||||
/// assert_eq!(
|
||||
/// Variant::ShortString("HI"),
|
||||
/// Variant::from("HI"),
|
||||
/// Variant::try_new(&metadata, &value).unwrap()
|
||||
/// );
|
||||
/// ```
|
||||
@@ -189,7 +250,7 @@ impl<'m, 'v> Variant<'m, 'v> {
|
||||
/// // parse the header metadata first
|
||||
/// let metadata = VariantMetadata::try_new(&metadata).unwrap();
|
||||
/// assert_eq!(
|
||||
/// Variant::ShortString("HI"),
|
||||
/// Variant::from("HI"),
|
||||
/// Variant::try_new_with_metadata(metadata, &value).unwrap()
|
||||
/// );
|
||||
/// ```
|
||||
@@ -432,7 +493,7 @@ impl<'m, 'v> Variant<'m, 'v> {
|
||||
///
|
||||
/// // you can extract a string from string variants
|
||||
/// let s = "hello!";
|
||||
/// let v1 = Variant::ShortString(s);
|
||||
/// let v1 = Variant::from(s);
|
||||
/// assert_eq!(v1.as_string(), Some(s));
|
||||
///
|
||||
/// // but not from other variants
|
||||
@@ -441,7 +502,7 @@ impl<'m, 'v> Variant<'m, 'v> {
|
||||
/// ```
|
||||
pub fn as_string(&'v self) -> Option<&'v str> {
|
||||
match self {
|
||||
Variant::String(s) | Variant::ShortString(s) => Some(s),
|
||||
Variant::String(s) | Variant::ShortString(ShortString(s)) => Some(s),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -861,10 +922,25 @@ impl<'v> From<&'v [u8]> for Variant<'_, 'v> {
|
||||
|
||||
impl<'v> From<&'v str> for Variant<'_, 'v> {
|
||||
fn from(value: &'v str) -> Self {
|
||||
if value.len() < 64 {
|
||||
Variant::ShortString(value)
|
||||
} else {
|
||||
if value.len() > MAX_SHORT_STRING_BYTES {
|
||||
Variant::String(value)
|
||||
} else {
|
||||
Variant::ShortString(ShortString(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_construct_short_string() {
|
||||
let short_string = ShortString::try_new("norm").expect("should fit in short string");
|
||||
assert_eq!(short_string.as_str(), "norm");
|
||||
|
||||
let long_string = "a".repeat(MAX_SHORT_STRING_BYTES + 1);
|
||||
let res = ShortString::try_new(&long_string);
|
||||
assert!(res.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use chrono::NaiveDate;
|
||||
use parquet_variant::{Variant, VariantBuilder};
|
||||
use parquet_variant::{ShortString, Variant, VariantBuilder};
|
||||
|
||||
fn cases_dir() -> PathBuf {
|
||||
Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
@@ -76,7 +76,7 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> {
|
||||
("primitive_string", Variant::String("This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥\u{fe0f}, 🎣 and 🤦!!")),
|
||||
("primitive_timestamp", Variant::TimestampMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(16, 34, 56, 780).unwrap().and_utc())),
|
||||
("primitive_timestampntz", Variant::TimestampNtzMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap())),
|
||||
("short_string", Variant::ShortString("Less than 64 bytes (❤\u{fe0f} with utf8)")),
|
||||
("short_string", Variant::ShortString(ShortString::try_new("Less than 64 bytes (❤\u{fe0f} with utf8)").unwrap())),
|
||||
]
|
||||
}
|
||||
#[test]
|
||||
@@ -130,11 +130,20 @@ fn variant_object_primitive() {
|
||||
),
|
||||
("int_field", Variant::Int8(1)),
|
||||
("null_field", Variant::Null),
|
||||
("string_field", Variant::ShortString("Apache Parquet")),
|
||||
(
|
||||
"string_field",
|
||||
Variant::ShortString(
|
||||
ShortString::try_new("Apache Parquet")
|
||||
.expect("value should fit inside a short string"),
|
||||
),
|
||||
),
|
||||
(
|
||||
// apparently spark wrote this as a string (not a timestamp)
|
||||
"timestamp_field",
|
||||
Variant::ShortString("2025-04-16T12:34:56.78"),
|
||||
Variant::ShortString(
|
||||
ShortString::try_new("2025-04-16T12:34:56.78")
|
||||
.expect("value should fit inside a short string"),
|
||||
),
|
||||
),
|
||||
];
|
||||
let actual_fields: Vec<_> = variant_object.iter().collect();
|
||||
|
||||
Reference in New Issue
Block a user