[Variant] Introduce new type over &str for ShortString (#7718)

# Which issue does this PR close?

- Closes https://github.com/apache/arrow-rs/issues/7700

This commit introduces `ShortString`, a newtype wrapping `&str` that
enforces a maximum length constraint. This also allows us to perform
validation once, at construction, and removes a superfluous validation
check in `append_value`.

The now-superfluous validation check was needed since users could
construct `Variant::ShortString`s directly, without doing input
validation. This means you could have a short string variant which
actually contains a string that is longer than 63 bytes.

But since we enforce this check upon construction, we can directly match
against `Variant::String` and `Variant::ShortString` arms with their
respective appending functions (`append_string` and
`append_short_string`).
This commit is contained in:
Matthew Kim
2025-06-21 07:16:51 -04:00
committed by GitHub
parent 7b374b9b7a
commit 1ededfe024
4 changed files with 122 additions and 35 deletions
+18 -17
View File
@@ -15,11 +15,10 @@
// specific language governing permissions and limitations
// under the License.
use crate::decoder::{VariantBasicType, VariantPrimitiveType};
use crate::Variant;
use crate::{ShortString, Variant};
use std::collections::HashMap;
const BASIC_TYPE_BITS: u8 = 2;
const MAX_SHORT_STRING_SIZE: usize = 0x3F;
const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
fn primitive_header(primitive_type: VariantPrimitiveType) -> u8 {
@@ -114,11 +113,11 @@ fn make_room_for_header(buffer: &mut Vec<u8>, start_pos: usize, header_size: usi
/// };
/// assert_eq!(
/// variant_object.field_by_name("first_name").unwrap(),
/// Some(Variant::ShortString("Jiaying"))
/// Some(Variant::from("Jiaying"))
/// );
/// assert_eq!(
/// variant_object.field_by_name("last_name").unwrap(),
/// Some(Variant::ShortString("Li"))
/// Some(Variant::from("Li"))
/// );
/// ```
///
@@ -281,17 +280,18 @@ impl VariantBuilder {
self.buffer.extend_from_slice(value);
}
fn append_short_string(&mut self, value: ShortString) {
let inner = value.0;
self.buffer.push(short_string_header(inner.len()));
self.buffer.extend_from_slice(inner.as_bytes());
}
fn append_string(&mut self, value: &str) {
if value.len() <= MAX_SHORT_STRING_SIZE {
self.buffer.push(short_string_header(value.len()));
self.buffer.extend_from_slice(value.as_bytes());
} else {
self.buffer
.push(primitive_header(VariantPrimitiveType::String));
self.buffer
.extend_from_slice(&(value.len() as u32).to_le_bytes());
self.buffer.extend_from_slice(value.as_bytes());
}
self.buffer
.push(primitive_header(VariantPrimitiveType::String));
self.buffer
.extend_from_slice(&(value.len() as u32).to_le_bytes());
self.buffer.extend_from_slice(value.as_bytes());
}
/// Add key to dictionary, return its ID
@@ -390,7 +390,8 @@ impl VariantBuilder {
Variant::Float(v) => self.append_float(v),
Variant::Double(v) => self.append_double(v),
Variant::Binary(v) => self.append_binary(v),
Variant::String(s) | Variant::ShortString(s) => self.append_string(s),
Variant::String(s) => self.append_string(s),
Variant::ShortString(s) => self.append_short_string(s),
Variant::Object(_) | Variant::List(_) => {
unreachable!("Object and List variants cannot be created through Into<Variant>")
}
@@ -639,7 +640,7 @@ mod tests {
builder.append_value("hello");
let (metadata, value) = builder.finish();
let variant = Variant::try_new(&metadata, &value).unwrap();
assert_eq!(variant, Variant::ShortString("hello"));
assert_eq!(variant, Variant::ShortString(ShortString("hello")));
}
{
@@ -688,7 +689,7 @@ mod tests {
assert_eq!(val1, Variant::Int8(2));
let val2 = list.get(2).unwrap();
assert_eq!(val2, Variant::ShortString("test"));
assert_eq!(val2, Variant::ShortString(ShortString("test")));
}
_ => panic!("Expected an array variant, got: {:?}", variant),
}
+4 -3
View File
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
use crate::utils::{array_from_slice, slice_from_slice, string_from_slice};
use crate::ShortString;
use arrow_schema::ArrowError;
use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, Utc};
@@ -273,10 +274,10 @@ pub(crate) fn decode_long_string(data: &[u8]) -> Result<&str, ArrowError> {
}
/// Decodes a short string from the value section of a variant.
pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<&str, ArrowError> {
pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<ShortString, ArrowError> {
let len = (metadata >> 2) as usize;
let string = string_from_slice(data, 0..len)?;
Ok(string)
ShortString::try_new(string)
}
#[cfg(test)]
@@ -420,7 +421,7 @@ mod tests {
fn test_short_string() -> Result<(), ArrowError> {
let data = [b'H', b'e', b'l', b'l', b'o', b'o'];
let result = decode_short_string(1 | 5 << 2, &data)?;
assert_eq!(result, "Hello");
assert_eq!(result.0, "Hello");
Ok(())
}
+87 -11
View File
@@ -1,3 +1,5 @@
use std::ops::Deref;
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
@@ -29,6 +31,65 @@ mod list;
mod metadata;
mod object;
/// Maximum length, in bytes, accepted for a Variant short string (`0x3F` = 63).
const MAX_SHORT_STRING_BYTES: usize = 0x3F;
/// A Variant [`ShortString`]
///
/// This implementation is a zero cost wrapper over `&str` that ensures
/// the length of the underlying string is a valid Variant short string (63 bytes or fewer).
///
/// The inner field is `pub(crate)`, so code outside this crate has to build a
/// value through [`ShortString::try_new`] (or the `TryFrom<&str>` impl), both of
/// which perform the length check.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ShortString<'a>(pub(crate) &'a str);
impl<'a> ShortString<'a> {
    /// Tries to wrap `value` as a variant short string.
    ///
    /// # Validation
    ///
    /// The byte length of `value` must not exceed `MAX_SHORT_STRING_BYTES`.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when `value` is longer than
    /// `MAX_SHORT_STRING_BYTES` bytes.
    pub fn try_new(value: &'a str) -> Result<Self, ArrowError> {
        if value.len() <= MAX_SHORT_STRING_BYTES {
            Ok(Self(value))
        } else {
            Err(ArrowError::InvalidArgumentError(format!(
                "value is larger than {MAX_SHORT_STRING_BYTES} bytes"
            )))
        }
    }

    /// Borrows the wrapped string slice for the full lifetime `'a`.
    pub fn as_str(&self) -> &'a str {
        self.0
    }
}
impl<'a> From<ShortString<'a>> for &'a str {
fn from(value: ShortString<'a>) -> Self {
value.0
}
}
impl<'a> TryFrom<&'a str> for ShortString<'a> {
type Error = ArrowError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Self::try_new(value)
}
}
impl<'a> AsRef<str> for ShortString<'a> {
fn as_ref(&self) -> &str {
self.0
}
}
impl<'a> Deref for ShortString<'a> {
type Target = str;
fn deref(&self) -> &Self::Target {
self.0
}
}
/// Represents a [Parquet Variant]
///
/// The lifetimes `'m` and `'v` are for metadata and value buffers, respectively.
@@ -85,7 +146,7 @@ mod object;
///
/// ## Creating `Variant` from Rust Types
/// ```
/// # use parquet_variant::Variant;
/// use parquet_variant::Variant;
/// // variants can be directly constructed
/// let variant = Variant::Int32(123);
/// // or constructed via `From` impls
@@ -98,7 +159,7 @@ mod object;
/// let value = [0x09, 0x48, 0x49];
/// // parse the header metadata
/// assert_eq!(
/// Variant::ShortString("HI"),
/// Variant::from("HI"),
/// Variant::try_new(&metadata, &value).unwrap()
/// );
/// ```
@@ -152,7 +213,7 @@ pub enum Variant<'m, 'v> {
/// Primitive (type_id=1): STRING
String(&'v str),
/// Short String (type_id=2): STRING
ShortString(&'v str),
ShortString(ShortString<'v>),
// need both metadata & value
/// Object (type_id=3): N/A
Object(VariantObject<'m, 'v>),
@@ -165,12 +226,12 @@ impl<'m, 'v> Variant<'m, 'v> {
///
/// # Example
/// ```
/// # use parquet_variant::{Variant, VariantMetadata};
/// use parquet_variant::{Variant, VariantMetadata};
/// let metadata = [0x01, 0x00, 0x00];
/// let value = [0x09, 0x48, 0x49];
/// // parse the header metadata
/// assert_eq!(
/// Variant::ShortString("HI"),
/// Variant::from("HI"),
/// Variant::try_new(&metadata, &value).unwrap()
/// );
/// ```
@@ -189,7 +250,7 @@ impl<'m, 'v> Variant<'m, 'v> {
/// // parse the header metadata first
/// let metadata = VariantMetadata::try_new(&metadata).unwrap();
/// assert_eq!(
/// Variant::ShortString("HI"),
/// Variant::from("HI"),
/// Variant::try_new_with_metadata(metadata, &value).unwrap()
/// );
/// ```
@@ -432,7 +493,7 @@ impl<'m, 'v> Variant<'m, 'v> {
///
/// // you can extract a string from string variants
/// let s = "hello!";
/// let v1 = Variant::ShortString(s);
/// let v1 = Variant::from(s);
/// assert_eq!(v1.as_string(), Some(s));
///
/// // but not from other variants
@@ -441,7 +502,7 @@ impl<'m, 'v> Variant<'m, 'v> {
/// ```
pub fn as_string(&'v self) -> Option<&'v str> {
match self {
Variant::String(s) | Variant::ShortString(s) => Some(s),
Variant::String(s) | Variant::ShortString(ShortString(s)) => Some(s),
_ => None,
}
}
@@ -861,10 +922,25 @@ impl<'v> From<&'v [u8]> for Variant<'_, 'v> {
impl<'v> From<&'v str> for Variant<'_, 'v> {
fn from(value: &'v str) -> Self {
if value.len() < 64 {
Variant::ShortString(value)
} else {
if value.len() > MAX_SHORT_STRING_BYTES {
Variant::String(value)
} else {
Variant::ShortString(ShortString(value))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A string within the limit is accepted; one byte past the limit is rejected.
    #[test]
    fn test_construct_short_string() {
        let short_string = ShortString::try_new("norm").expect("should fit in short string");
        assert_eq!(short_string.as_str(), "norm");

        let long_string = "a".repeat(MAX_SHORT_STRING_BYTES + 1);
        let res = ShortString::try_new(&long_string);
        assert!(res.is_err());
    }

    /// Pins the inclusive upper bound: exactly `MAX_SHORT_STRING_BYTES` bytes
    /// must be accepted, as must the empty string.
    #[test]
    fn test_short_string_length_boundary() {
        let empty = ShortString::try_new("").expect("empty string should fit in short string");
        assert_eq!(empty.as_str(), "");

        let max_string = "a".repeat(MAX_SHORT_STRING_BYTES);
        let at_limit = ShortString::try_new(&max_string)
            .expect("a string of exactly the maximum length should fit in short string");
        assert_eq!(at_limit.as_str(), max_string.as_str());
        assert_eq!(at_limit.len(), MAX_SHORT_STRING_BYTES);
    }
}
+13 -4
View File
@@ -24,7 +24,7 @@ use std::fs;
use std::path::{Path, PathBuf};
use chrono::NaiveDate;
use parquet_variant::{Variant, VariantBuilder};
use parquet_variant::{ShortString, Variant, VariantBuilder};
fn cases_dir() -> PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR"))
@@ -76,7 +76,7 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> {
("primitive_string", Variant::String("This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥\u{fe0f}, 🎣 and 🤦!!")),
("primitive_timestamp", Variant::TimestampMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(16, 34, 56, 780).unwrap().and_utc())),
("primitive_timestampntz", Variant::TimestampNtzMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap())),
("short_string", Variant::ShortString("Less than 64 bytes (❤\u{fe0f} with utf8)")),
("short_string", Variant::ShortString(ShortString::try_new("Less than 64 bytes (❤\u{fe0f} with utf8)").unwrap())),
]
}
#[test]
@@ -130,11 +130,20 @@ fn variant_object_primitive() {
),
("int_field", Variant::Int8(1)),
("null_field", Variant::Null),
("string_field", Variant::ShortString("Apache Parquet")),
(
"string_field",
Variant::ShortString(
ShortString::try_new("Apache Parquet")
.expect("value should fit inside a short string"),
),
),
(
// apparently spark wrote this as a string (not a timestamp)
"timestamp_field",
Variant::ShortString("2025-04-16T12:34:56.78"),
Variant::ShortString(
ShortString::try_new("2025-04-16T12:34:56.78")
.expect("value should fit inside a short string"),
),
),
];
let actual_fields: Vec<_> = variant_object.iter().collect();