mirror of
https://github.com/langchain-ai/datafusion.git
synced 2026-07-01 21:24:06 -04:00
perf: Optimize initcap() (#20352)
## Which issue does this PR close? - Closes #20351. ## Rationale for this change When all values in a `Utf8`/`LargeUtf8` array are ASCII, we can skip using `GenericStringBuilder` and instead process the entire input buffer in a single pass using byte-level operations. This also avoids recomputing the offsets and nulls arrays. A similar optimization is already used for lower() and upper(). Along the way, optimize `initcap_string()` for ASCII-only inputs. It already had an ASCII-only fastpath but there was room for further optimization, by iterating over bytes rather than characters. ## What changes are included in this PR? * Cleanup benchmarks: we ran the scalar benchmark for different array sizes, despite the fact that it is invariant to the array size * Add benchmark for different string lengths * Add benchmark for Unicode array input * Optimize for ASCII-only inputs as described above * Add test case for ASCII-only input that is a sliced array * Add test case variants for `LargeStringArray` ## Are these changes tested? Yes, plus an additional test added. ## Are there any user-facing changes? No.
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
use arrow::array::OffsetSizeTrait;
|
||||
use arrow::array::{ArrayRef, OffsetSizeTrait, StringArray, StringViewBuilder};
|
||||
use arrow::datatypes::{DataType, Field};
|
||||
use arrow::util::bench_util::{
|
||||
create_string_array_with_len, create_string_view_array_with_len,
|
||||
@@ -47,52 +47,124 @@ fn create_args<O: OffsetSizeTrait>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Utf8 array where every value contains non-ASCII Unicode text.
|
||||
fn create_unicode_utf8_args(size: usize) -> Vec<ColumnarValue> {
|
||||
let array = Arc::new(StringArray::from_iter_values(std::iter::repeat_n(
|
||||
"ñAnDÚ ÁrBOL ОлЕГ ÍslENsku",
|
||||
size,
|
||||
))) as ArrayRef;
|
||||
vec![ColumnarValue::Array(array)]
|
||||
}
|
||||
|
||||
/// Create a Utf8View array where every value contains non-ASCII Unicode text.
|
||||
fn create_unicode_utf8view_args(size: usize) -> Vec<ColumnarValue> {
|
||||
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
for _ in 0..size {
|
||||
builder.append_value("ñAnDÚ ÁrBOL ОлЕГ ÍslENsku");
|
||||
}
|
||||
let array = Arc::new(builder.finish()) as ArrayRef;
|
||||
vec![ColumnarValue::Array(array)]
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let initcap = unicode::initcap();
|
||||
let config_options = Arc::new(ConfigOptions::default());
|
||||
|
||||
// Grouped benchmarks for array sizes - to compare with scalar performance
|
||||
// Array benchmarks: vary both row count and string length
|
||||
for size in [1024, 4096, 8192] {
|
||||
let mut group = c.benchmark_group(format!("initcap size={size}"));
|
||||
for str_len in [16, 128] {
|
||||
let mut group =
|
||||
c.benchmark_group(format!("initcap size={size} str_len={str_len}"));
|
||||
group.sampling_mode(SamplingMode::Flat);
|
||||
group.sample_size(10);
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
// Utf8
|
||||
let array_args = create_args::<i32>(size, str_len, false);
|
||||
let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()];
|
||||
|
||||
group.bench_function("array_utf8", |b| {
|
||||
b.iter(|| {
|
||||
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
|
||||
args: array_args.clone(),
|
||||
arg_fields: array_arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
});
|
||||
|
||||
// Utf8View
|
||||
let array_view_args = create_args::<i32>(size, str_len, true);
|
||||
let array_view_arg_fields =
|
||||
vec![Field::new("arg_0", DataType::Utf8View, true).into()];
|
||||
|
||||
group.bench_function("array_utf8view", |b| {
|
||||
b.iter(|| {
|
||||
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
|
||||
args: array_view_args.clone(),
|
||||
arg_fields: array_view_arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8View, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
}
|
||||
|
||||
// Unicode array benchmarks
|
||||
for size in [1024, 4096, 8192] {
|
||||
let mut group = c.benchmark_group(format!("initcap unicode size={size}"));
|
||||
group.sampling_mode(SamplingMode::Flat);
|
||||
group.sample_size(10);
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
// Array benchmark - Utf8
|
||||
let array_args = create_args::<i32>(size, 16, false);
|
||||
let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()];
|
||||
let batch_len = size;
|
||||
let unicode_args = create_unicode_utf8_args(size);
|
||||
let unicode_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()];
|
||||
|
||||
group.bench_function("array_utf8", |b| {
|
||||
b.iter(|| {
|
||||
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
|
||||
args: array_args.clone(),
|
||||
arg_fields: array_arg_fields.clone(),
|
||||
number_rows: batch_len,
|
||||
args: unicode_args.clone(),
|
||||
arg_fields: unicode_arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
});
|
||||
|
||||
// Array benchmark - Utf8View
|
||||
let array_view_args = create_args::<i32>(size, 16, true);
|
||||
let array_view_arg_fields =
|
||||
let unicode_view_args = create_unicode_utf8view_args(size);
|
||||
let unicode_view_arg_fields =
|
||||
vec![Field::new("arg_0", DataType::Utf8View, true).into()];
|
||||
|
||||
group.bench_function("array_utf8view", |b| {
|
||||
b.iter(|| {
|
||||
black_box(initcap.invoke_with_args(ScalarFunctionArgs {
|
||||
args: array_view_args.clone(),
|
||||
arg_fields: array_view_arg_fields.clone(),
|
||||
number_rows: batch_len,
|
||||
args: unicode_view_args.clone(),
|
||||
arg_fields: unicode_view_arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8View, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
});
|
||||
|
||||
// Scalar benchmark - Utf8 (the optimization we added)
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Scalar benchmarks: independent of array size, run once
|
||||
{
|
||||
let mut group = c.benchmark_group("initcap scalar");
|
||||
group.sampling_mode(SamplingMode::Flat);
|
||||
group.sample_size(10);
|
||||
group.measurement_time(Duration::from_secs(10));
|
||||
|
||||
// Utf8
|
||||
let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
|
||||
"hello world test string".to_string(),
|
||||
)))];
|
||||
@@ -110,7 +182,7 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
})
|
||||
});
|
||||
|
||||
// Scalar benchmark - Utf8View
|
||||
// Utf8View
|
||||
let scalar_view_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
|
||||
"hello world test string".to_string(),
|
||||
)))];
|
||||
|
||||
@@ -19,8 +19,10 @@ use std::any::Any;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::array::{
|
||||
Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder,
|
||||
Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait,
|
||||
StringViewBuilder,
|
||||
};
|
||||
use arrow::buffer::{Buffer, OffsetBuffer};
|
||||
use arrow::datatypes::DataType;
|
||||
|
||||
use crate::utils::{make_scalar_function, utf8_to_str_type};
|
||||
@@ -148,8 +150,8 @@ impl ScalarUDFImpl for InitcapFunc {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the first letter of each word to upper case and the rest to lower
|
||||
/// case. Words are sequences of alphanumeric characters separated by
|
||||
/// Converts the first letter of each word to uppercase and the rest to
|
||||
/// lowercase. Words are sequences of alphanumeric characters separated by
|
||||
/// non-alphanumeric characters.
|
||||
///
|
||||
/// Example:
|
||||
@@ -159,6 +161,10 @@ impl ScalarUDFImpl for InitcapFunc {
|
||||
fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
|
||||
let string_array = as_generic_string_array::<T>(&args[0])?;
|
||||
|
||||
if string_array.is_ascii() {
|
||||
return Ok(initcap_ascii_array(string_array));
|
||||
}
|
||||
|
||||
let mut builder = GenericStringBuilder::<T>::with_capacity(
|
||||
string_array.len(),
|
||||
string_array.value_data().len(),
|
||||
@@ -176,12 +182,67 @@ fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
|
||||
Ok(Arc::new(builder.finish()) as ArrayRef)
|
||||
}
|
||||
|
||||
/// Fast path for `Utf8` or `LargeUtf8` arrays that are ASCII-only. We can use a
|
||||
/// single pass over the buffer and operate directly on bytes.
|
||||
fn initcap_ascii_array<T: OffsetSizeTrait>(
|
||||
string_array: &GenericStringArray<T>,
|
||||
) -> ArrayRef {
|
||||
let offsets = string_array.offsets();
|
||||
let src = string_array.value_data();
|
||||
let first_offset = offsets.first().unwrap().as_usize();
|
||||
let last_offset = offsets.last().unwrap().as_usize();
|
||||
|
||||
// For sliced arrays, only convert the visible bytes, not the entire input
|
||||
// buffer.
|
||||
let mut out = Vec::with_capacity(last_offset - first_offset);
|
||||
|
||||
for window in offsets.windows(2) {
|
||||
let start = window[0].as_usize();
|
||||
let end = window[1].as_usize();
|
||||
|
||||
let mut prev_is_alnum = false;
|
||||
for &b in &src[start..end] {
|
||||
let converted = if prev_is_alnum {
|
||||
b.to_ascii_lowercase()
|
||||
} else {
|
||||
b.to_ascii_uppercase()
|
||||
};
|
||||
out.push(converted);
|
||||
prev_is_alnum = b.is_ascii_alphanumeric();
|
||||
}
|
||||
}
|
||||
|
||||
let values = Buffer::from_vec(out);
|
||||
let out_offsets = if first_offset == 0 {
|
||||
offsets.clone()
|
||||
} else {
|
||||
// For sliced arrays, we need to rebase the offsets to reflect that the
|
||||
// output only contains the bytes in the visible slice.
|
||||
let rebased_offsets = offsets
|
||||
.iter()
|
||||
.map(|offset| T::usize_as(offset.as_usize() - first_offset))
|
||||
.collect::<Vec<_>>();
|
||||
OffsetBuffer::<T>::new(rebased_offsets.into())
|
||||
};
|
||||
|
||||
// SAFETY: ASCII case conversion preserves byte length, so the original
|
||||
// string boundaries are preserved. `out_offsets` is either identical to
|
||||
// the input offsets or a rebased version relative to the compacted values
|
||||
// buffer.
|
||||
Arc::new(unsafe {
|
||||
GenericStringArray::<T>::new_unchecked(
|
||||
out_offsets,
|
||||
values,
|
||||
string_array.nulls().cloned(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
|
||||
let string_view_array = as_string_view_array(&args[0])?;
|
||||
|
||||
let mut builder = StringViewBuilder::with_capacity(string_view_array.len());
|
||||
|
||||
let mut container = String::new();
|
||||
|
||||
string_view_array.iter().for_each(|str| match str {
|
||||
Some(s) => {
|
||||
initcap_string(s, &mut container);
|
||||
@@ -198,13 +259,16 @@ fn initcap_string(input: &str, container: &mut String) {
|
||||
let mut prev_is_alphanumeric = false;
|
||||
|
||||
if input.is_ascii() {
|
||||
for c in input.chars() {
|
||||
container.reserve(input.len());
|
||||
// SAFETY: each byte is ASCII, so the result is valid UTF-8.
|
||||
let out = unsafe { container.as_mut_vec() };
|
||||
for &b in input.as_bytes() {
|
||||
if prev_is_alphanumeric {
|
||||
container.push(c.to_ascii_lowercase());
|
||||
out.push(b.to_ascii_lowercase());
|
||||
} else {
|
||||
container.push(c.to_ascii_uppercase());
|
||||
};
|
||||
prev_is_alphanumeric = c.is_ascii_alphanumeric();
|
||||
out.push(b.to_ascii_uppercase());
|
||||
}
|
||||
prev_is_alphanumeric = b.is_ascii_alphanumeric();
|
||||
}
|
||||
} else {
|
||||
for c in input.chars() {
|
||||
@@ -222,10 +286,11 @@ fn initcap_string(input: &str, container: &mut String) {
|
||||
mod tests {
|
||||
use crate::unicode::initcap::InitcapFunc;
|
||||
use crate::utils::test::test_function;
|
||||
use arrow::array::{Array, StringArray, StringViewArray};
|
||||
use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray};
|
||||
use arrow::datatypes::DataType::{Utf8, Utf8View};
|
||||
use datafusion_common::{Result, ScalarValue};
|
||||
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_functions() -> Result<()> {
|
||||
@@ -329,4 +394,114 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initcap_ascii_array() -> Result<()> {
|
||||
let array = StringArray::from(vec![
|
||||
Some("hello world"),
|
||||
None,
|
||||
Some("foo-bar_baz/baX"),
|
||||
Some(""),
|
||||
Some("123 abc 456DEF"),
|
||||
Some("ALL CAPS"),
|
||||
Some("already correct"),
|
||||
]);
|
||||
let args: Vec<ArrayRef> = vec![Arc::new(array)];
|
||||
let result = super::initcap::<i32>(&args)?;
|
||||
let result = result.as_any().downcast_ref::<StringArray>().unwrap();
|
||||
|
||||
assert_eq!(result.len(), 7);
|
||||
assert_eq!(result.value(0), "Hello World");
|
||||
assert!(result.is_null(1));
|
||||
assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
|
||||
assert_eq!(result.value(3), "");
|
||||
assert_eq!(result.value(4), "123 Abc 456def");
|
||||
assert_eq!(result.value(5), "All Caps");
|
||||
assert_eq!(result.value(6), "Already Correct");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initcap_ascii_large_array() -> Result<()> {
|
||||
let array = LargeStringArray::from(vec![
|
||||
Some("hello world"),
|
||||
None,
|
||||
Some("foo-bar_baz/baX"),
|
||||
Some(""),
|
||||
Some("123 abc 456DEF"),
|
||||
Some("ALL CAPS"),
|
||||
Some("already correct"),
|
||||
]);
|
||||
let args: Vec<ArrayRef> = vec![Arc::new(array)];
|
||||
let result = super::initcap::<i64>(&args)?;
|
||||
let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();
|
||||
|
||||
assert_eq!(result.len(), 7);
|
||||
assert_eq!(result.value(0), "Hello World");
|
||||
assert!(result.is_null(1));
|
||||
assert_eq!(result.value(2), "Foo-Bar_Baz/Bax");
|
||||
assert_eq!(result.value(3), "");
|
||||
assert_eq!(result.value(4), "123 Abc 456def");
|
||||
assert_eq!(result.value(5), "All Caps");
|
||||
assert_eq!(result.value(6), "Already Correct");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test that initcap works correctly on a sliced ASCII StringArray.
|
||||
#[test]
|
||||
fn test_initcap_sliced_ascii_array() -> Result<()> {
|
||||
let array = StringArray::from(vec![
|
||||
Some("hello world"),
|
||||
Some("foo bar"),
|
||||
Some("baz qux"),
|
||||
]);
|
||||
// Slice to get only the last two elements. The resulting array's
|
||||
// offsets are [11, 18, 25] (non-zero start), but value_data still
|
||||
// contains the full original buffer.
|
||||
let sliced = array.slice(1, 2);
|
||||
let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
|
||||
let result = super::initcap::<i32>(&args)?;
|
||||
let result = result.as_any().downcast_ref::<StringArray>().unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result.value(0), "Foo Bar");
|
||||
assert_eq!(result.value(1), "Baz Qux");
|
||||
|
||||
// The output values buffer should be compact
|
||||
assert_eq!(*result.offsets().first().unwrap(), 0);
|
||||
assert_eq!(
|
||||
result.value_data().len(),
|
||||
*result.offsets().last().unwrap() as usize
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test that initcap works correctly on a sliced ASCII LargeStringArray.
|
||||
#[test]
|
||||
fn test_initcap_sliced_ascii_large_array() -> Result<()> {
|
||||
let array = LargeStringArray::from(vec![
|
||||
Some("hello world"),
|
||||
Some("foo bar"),
|
||||
Some("baz qux"),
|
||||
]);
|
||||
// Slice to get only the last two elements. The resulting array's
|
||||
// offsets are [11, 18, 25] (non-zero start), but value_data still
|
||||
// contains the full original buffer.
|
||||
let sliced = array.slice(1, 2);
|
||||
let args: Vec<ArrayRef> = vec![Arc::new(sliced)];
|
||||
let result = super::initcap::<i64>(&args)?;
|
||||
let result = result.as_any().downcast_ref::<LargeStringArray>().unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result.value(0), "Foo Bar");
|
||||
assert_eq!(result.value(1), "Baz Qux");
|
||||
|
||||
// The output values buffer should be compact
|
||||
assert_eq!(*result.offsets().first().unwrap(), 0);
|
||||
assert_eq!(
|
||||
result.value_data().len(),
|
||||
*result.offsets().last().unwrap() as usize
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user