mirror of
https://github.com/langchain-ai/datafusion.git
synced 2026-07-01 21:24:06 -04:00
perf: Optimize lpad, rpad for ASCII strings (#20278)
The previous implementation incurred the overhead of Unicode machinery,
even for the common case that both the input string and the fill string
consist only of ASCII characters. For the ASCII-only case, we can
assume that the length in bytes equals the length in characters, and
avoid expensive grapheme-based segmentation. This follows similar
optimizations applied elsewhere in the codebase.
Benchmarks indicate this is a significant performance win for ASCII-only
input (4x-10x faster) but only a mild regression for Unicode input (2-5%
slower).
Along the way:
* Replace a few instances of `write_str(str)?` followed by `append_value("")`
with a single `append_value(str)`, which saves a few cycles
* Add a missing test case for truncating the input string
* Add benchmarks for Unicode input
## Which issue does this PR close?
- Closes #20277.
## Are these changes tested?
Covered by existing tests. Added new benchmarks for Unicode inputs.
## Are there any user-facing changes?
No.
---------
Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com>
This commit is contained in:
@@ -15,7 +15,10 @@
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
use arrow::array::{ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray};
|
||||
use arrow::array::{
|
||||
ArrowPrimitiveType, GenericStringBuilder, OffsetSizeTrait, PrimitiveArray,
|
||||
StringViewBuilder,
|
||||
};
|
||||
use arrow::datatypes::{DataType, Field, Int64Type};
|
||||
use arrow::util::bench_util::{
|
||||
create_string_array_with_len, create_string_view_array_with_len,
|
||||
@@ -30,6 +33,51 @@ use std::hint::black_box;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Sample strings that each contain at least one non-ASCII character.
/// The Unicode array builders in this file cycle through this list by row
/// index to produce benchmark input.
const UNICODE_STRINGS: &[&str] = &[
|
||||
"Ñandú",
|
||||
"Íslensku",
|
||||
"Þjóðarinnar",
|
||||
"Ελληνική",
|
||||
"Иванович",
|
||||
"データフュージョン",
|
||||
"José García",
|
||||
"Ölçü bïrïmï",
|
||||
"Ÿéšṱëṟḏàÿ",
|
||||
"Ährenstraße",
|
||||
];
|
||||
|
||||
/// Builds a `GenericStringArray<O>` with `size` rows of Unicode sample data.
///
/// For each row a random `f32` is drawn; when it is below `null_density` the
/// row is null, otherwise the row holds `UNICODE_STRINGS[i % len]`, so the
/// non-null values cycle through the sample list in row order.
fn create_unicode_string_array<O: OffsetSizeTrait>(
|
||||
size: usize,
|
||||
null_density: f32,
|
||||
) -> arrow::array::GenericStringArray<O> {
|
||||
let mut rng = rand::rng();
|
||||
let mut builder = GenericStringBuilder::<O>::new();
|
||||
for i in 0..size {
|
||||
if rng.random::<f32>() < null_density {
|
||||
builder.append_null();
|
||||
} else {
|
||||
// The value is chosen by row index (not randomly), wrapping around the list.
builder.append_value(UNICODE_STRINGS[i % UNICODE_STRINGS.len()]);
|
||||
}
|
||||
}
|
||||
builder.finish()
|
||||
}
|
||||
|
||||
/// Builds a `StringViewArray` with `size` rows of Unicode sample data.
///
/// For each row a random `f32` is drawn; when it is below `null_density` the
/// row is null, otherwise the row holds `UNICODE_STRINGS[i % len]`, so the
/// non-null values cycle through the sample list in row order.
fn create_unicode_string_view_array(
|
||||
size: usize,
|
||||
null_density: f32,
|
||||
) -> arrow::array::StringViewArray {
|
||||
let mut rng = rand::rng();
|
||||
// The builder is created with capacity `size`, the number of rows appended below.
let mut builder = StringViewBuilder::with_capacity(size);
|
||||
for i in 0..size {
|
||||
if rng.random::<f32>() < null_density {
|
||||
builder.append_null();
|
||||
} else {
|
||||
// The value is chosen by row index (not randomly), wrapping around the list.
builder.append_value(UNICODE_STRINGS[i % UNICODE_STRINGS.len()]);
|
||||
}
|
||||
}
|
||||
builder.finish()
|
||||
}
|
||||
|
||||
struct Filter<Dist> {
|
||||
dist: Dist,
|
||||
}
|
||||
@@ -67,6 +115,34 @@ where
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Create args for pad benchmark with Unicode strings
///
/// Returns three `ColumnarValue::Array` arguments in the order
/// `[string, length, fill]`. The string and fill arrays each have `size`
/// rows drawn from `UNICODE_STRINGS` and are built with a null density of
/// 0.1. When `use_string_view` is true they are `StringViewArray`s; otherwise
/// they are string arrays with `i32` offsets.
///
/// The length array comes from `create_primitive_array::<Int64Type>(size, 0.0, target_len)`.
/// NOTE(review): that helper is defined elsewhere; presumably `0.0` is its
/// null density and `target_len` bounds the generated values — confirm.
fn create_unicode_pad_args(
|
||||
size: usize,
|
||||
target_len: usize,
|
||||
use_string_view: bool,
|
||||
) -> Vec<ColumnarValue> {
|
||||
let length_array =
|
||||
Arc::new(create_primitive_array::<Int64Type>(size, 0.0, target_len));
|
||||
|
||||
if use_string_view {
|
||||
let string_array = create_unicode_string_view_array(size, 0.1);
|
||||
// The fill values are also Unicode strings, generated independently of the input column.
let fill_array = create_unicode_string_view_array(size, 0.1);
|
||||
vec![
|
||||
ColumnarValue::Array(Arc::new(string_array)),
|
||||
ColumnarValue::Array(length_array),
|
||||
ColumnarValue::Array(Arc::new(fill_array)),
|
||||
]
|
||||
} else {
|
||||
let string_array = create_unicode_string_array::<i32>(size, 0.1);
|
||||
// The fill values are also Unicode strings, generated independently of the input column.
let fill_array = create_unicode_string_array::<i32>(size, 0.1);
|
||||
vec![
|
||||
ColumnarValue::Array(Arc::new(string_array)),
|
||||
ColumnarValue::Array(length_array),
|
||||
ColumnarValue::Array(Arc::new(fill_array)),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Create args for pad benchmark
|
||||
fn create_pad_args<O: OffsetSizeTrait>(
|
||||
size: usize,
|
||||
@@ -208,6 +284,58 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
},
|
||||
);
|
||||
|
||||
// Utf8 type with Unicode strings
|
||||
let args = create_unicode_pad_args(size, 20, false);
|
||||
let arg_fields = args
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, arg)| {
|
||||
Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
group.bench_function(
|
||||
format!("lpad utf8 unicode [size={size}, target=20]"),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let args_cloned = args.clone();
|
||||
black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs {
|
||||
args: args_cloned,
|
||||
arg_fields: arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
// StringView type with Unicode strings
|
||||
let args = create_unicode_pad_args(size, 20, true);
|
||||
let arg_fields = args
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, arg)| {
|
||||
Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
group.bench_function(
|
||||
format!("lpad stringview unicode [size={size}, target=20]"),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let args_cloned = args.clone();
|
||||
black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs {
|
||||
args: args_cloned,
|
||||
arg_fields: arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8View, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
@@ -322,6 +450,58 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
},
|
||||
);
|
||||
|
||||
// Utf8 type with Unicode strings
|
||||
let args = create_unicode_pad_args(size, 20, false);
|
||||
let arg_fields = args
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, arg)| {
|
||||
Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
group.bench_function(
|
||||
format!("rpad utf8 unicode [size={size}, target=20]"),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let args_cloned = args.clone();
|
||||
black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs {
|
||||
args: args_cloned,
|
||||
arg_fields: arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
// StringView type with Unicode strings
|
||||
let args = create_unicode_pad_args(size, 20, true);
|
||||
let arg_fields = args
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, arg)| {
|
||||
Field::new(format!("arg_{idx}"), arg.data_type(), true).into()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
group.bench_function(
|
||||
format!("rpad stringview unicode [size={size}, target=20]"),
|
||||
|b| {
|
||||
b.iter(|| {
|
||||
let args_cloned = args.clone();
|
||||
black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs {
|
||||
args: args_cloned,
|
||||
arg_fields: arg_fields.clone(),
|
||||
number_rows: size,
|
||||
return_field: Field::new("f", DataType::Utf8View, true).into(),
|
||||
config_options: Arc::clone(&config_options),
|
||||
}))
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
group.finish();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,10 @@ use datafusion_macros::user_doc;
|
||||
+---------------------------------------------+
|
||||
```"#,
|
||||
standard_argument(name = "str", prefix = "String"),
|
||||
argument(name = "n", description = "String length to pad to."),
|
||||
argument(
|
||||
name = "n",
|
||||
description = "String length to pad to. If the input string is longer than this length, it is truncated (on the right)."
|
||||
),
|
||||
argument(
|
||||
name = "padding_str",
|
||||
description = "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
|
||||
@@ -225,24 +228,47 @@ where
|
||||
continue;
|
||||
}
|
||||
|
||||
// Reuse buffers by clearing and refilling
|
||||
graphemes_buf.clear();
|
||||
graphemes_buf.extend(string.graphemes(true));
|
||||
|
||||
fill_chars_buf.clear();
|
||||
fill_chars_buf.extend(fill.chars());
|
||||
|
||||
if length < graphemes_buf.len() {
|
||||
builder.append_value(graphemes_buf[..length].concat());
|
||||
} else if fill_chars_buf.is_empty() {
|
||||
builder.append_value(string);
|
||||
} else {
|
||||
for l in 0..length - graphemes_buf.len() {
|
||||
let c = *fill_chars_buf.get(l % fill_chars_buf.len()).unwrap();
|
||||
builder.write_char(c)?;
|
||||
if string.is_ascii() && fill.is_ascii() {
|
||||
// ASCII fast path: byte length == character length,
|
||||
// so we skip expensive grapheme segmentation.
|
||||
let str_len = string.len();
|
||||
if length < str_len {
|
||||
builder.append_value(&string[..length]);
|
||||
} else if fill.is_empty() {
|
||||
builder.append_value(string);
|
||||
} else {
|
||||
let pad_len = length - str_len;
|
||||
let fill_len = fill.len();
|
||||
let full_reps = pad_len / fill_len;
|
||||
let remainder = pad_len % fill_len;
|
||||
for _ in 0..full_reps {
|
||||
builder.write_str(fill)?;
|
||||
}
|
||||
if remainder > 0 {
|
||||
builder.write_str(&fill[..remainder])?;
|
||||
}
|
||||
builder.append_value(string);
|
||||
}
|
||||
} else {
|
||||
// Reuse buffers by clearing and refilling
|
||||
graphemes_buf.clear();
|
||||
graphemes_buf.extend(string.graphemes(true));
|
||||
|
||||
fill_chars_buf.clear();
|
||||
fill_chars_buf.extend(fill.chars());
|
||||
|
||||
if length < graphemes_buf.len() {
|
||||
builder.append_value(graphemes_buf[..length].concat());
|
||||
} else if fill_chars_buf.is_empty() {
|
||||
builder.append_value(string);
|
||||
} else {
|
||||
for l in 0..length - graphemes_buf.len() {
|
||||
let c =
|
||||
*fill_chars_buf.get(l % fill_chars_buf.len()).unwrap();
|
||||
builder.write_char(c)?;
|
||||
}
|
||||
builder.append_value(string);
|
||||
}
|
||||
builder.write_str(string)?;
|
||||
builder.append_value("");
|
||||
}
|
||||
} else {
|
||||
builder.append_null();
|
||||
@@ -266,17 +292,30 @@ where
|
||||
continue;
|
||||
}
|
||||
|
||||
// Reuse buffer by clearing and refilling
|
||||
graphemes_buf.clear();
|
||||
graphemes_buf.extend(string.graphemes(true));
|
||||
|
||||
if length < graphemes_buf.len() {
|
||||
builder.append_value(graphemes_buf[..length].concat());
|
||||
if string.is_ascii() {
|
||||
// ASCII fast path: byte length == character length
|
||||
let str_len = string.len();
|
||||
if length < str_len {
|
||||
builder.append_value(&string[..length]);
|
||||
} else {
|
||||
for _ in 0..(length - str_len) {
|
||||
builder.write_str(" ")?;
|
||||
}
|
||||
builder.append_value(string);
|
||||
}
|
||||
} else {
|
||||
builder
|
||||
.write_str(" ".repeat(length - graphemes_buf.len()).as_str())?;
|
||||
builder.write_str(string)?;
|
||||
builder.append_value("");
|
||||
// Reuse buffer by clearing and refilling
|
||||
graphemes_buf.clear();
|
||||
graphemes_buf.extend(string.graphemes(true));
|
||||
|
||||
if length < graphemes_buf.len() {
|
||||
builder.append_value(graphemes_buf[..length].concat());
|
||||
} else {
|
||||
for _ in 0..(length - graphemes_buf.len()) {
|
||||
builder.write_str(" ")?;
|
||||
}
|
||||
builder.append_value(string);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.append_null();
|
||||
@@ -523,6 +562,17 @@ mod tests {
|
||||
None,
|
||||
Ok(None)
|
||||
);
|
||||
test_lpad!(
|
||||
Some("hello".into()),
|
||||
ScalarValue::Int64(Some(2i64)),
|
||||
Ok(Some("he"))
|
||||
);
|
||||
test_lpad!(
|
||||
Some("hi".into()),
|
||||
ScalarValue::Int64(Some(6i64)),
|
||||
Some("xy".into()),
|
||||
Ok(Some("xyxyhi"))
|
||||
);
|
||||
test_lpad!(
|
||||
Some("josé".into()),
|
||||
ScalarValue::Int64(Some(10i64)),
|
||||
|
||||
@@ -48,7 +48,10 @@ use unicode_segmentation::UnicodeSegmentation;
|
||||
+-----------------------------------------------+
|
||||
```"#,
|
||||
standard_argument(name = "str", prefix = "String"),
|
||||
argument(name = "n", description = "String length to pad to."),
|
||||
argument(
|
||||
name = "n",
|
||||
description = "String length to pad to. If the input string is longer than this length, it is truncated."
|
||||
),
|
||||
argument(
|
||||
name = "padding_str",
|
||||
description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._"
|
||||
@@ -203,7 +206,8 @@ fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
|
||||
/// Extends the string to length 'length' by appending the characters fill (a space by default).
|
||||
/// If the string is already longer than length then it is truncated (on the right).
|
||||
/// rpad('hi', 5, 'xy') = 'hixyx'
|
||||
fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>(
|
||||
string_array: &StringArrType,
|
||||
@@ -234,6 +238,18 @@ where
|
||||
let length = if length < 0 { 0 } else { length as usize };
|
||||
if length == 0 {
|
||||
builder.append_value("");
|
||||
} else if string.is_ascii() {
|
||||
// ASCII fast path: byte length == character length
|
||||
let str_len = string.len();
|
||||
if length < str_len {
|
||||
builder.append_value(&string[..length]);
|
||||
} else {
|
||||
builder.write_str(string)?;
|
||||
for _ in 0..(length - str_len) {
|
||||
builder.write_str(" ")?;
|
||||
}
|
||||
builder.append_value("");
|
||||
}
|
||||
} else {
|
||||
// Reuse buffer by clearing and refilling
|
||||
graphemes_buf.clear();
|
||||
@@ -244,9 +260,9 @@ where
|
||||
.append_value(graphemes_buf[..length].concat());
|
||||
} else {
|
||||
builder.write_str(string)?;
|
||||
builder.write_str(
|
||||
&" ".repeat(length - graphemes_buf.len()),
|
||||
)?;
|
||||
for _ in 0..(length - graphemes_buf.len()) {
|
||||
builder.write_str(" ")?;
|
||||
}
|
||||
builder.append_value("");
|
||||
}
|
||||
}
|
||||
@@ -273,27 +289,52 @@ where
|
||||
);
|
||||
}
|
||||
let length = if length < 0 { 0 } else { length as usize };
|
||||
// Reuse buffer by clearing and refilling
|
||||
graphemes_buf.clear();
|
||||
graphemes_buf.extend(string.graphemes(true));
|
||||
|
||||
if length < graphemes_buf.len() {
|
||||
builder
|
||||
.append_value(graphemes_buf[..length].concat());
|
||||
} else if fill.is_empty() {
|
||||
builder.append_value(string);
|
||||
} else {
|
||||
builder.write_str(string)?;
|
||||
// Reuse fill_chars_buf by clearing and refilling
|
||||
fill_chars_buf.clear();
|
||||
fill_chars_buf.extend(fill.chars());
|
||||
for l in 0..length - graphemes_buf.len() {
|
||||
let c = *fill_chars_buf
|
||||
.get(l % fill_chars_buf.len())
|
||||
.unwrap();
|
||||
builder.write_char(c)?;
|
||||
if string.is_ascii() && fill.is_ascii() {
|
||||
// ASCII fast path: byte length == character length,
|
||||
// so we skip expensive grapheme segmentation.
|
||||
let str_len = string.len();
|
||||
if length < str_len {
|
||||
builder.append_value(&string[..length]);
|
||||
} else if fill.is_empty() {
|
||||
builder.append_value(string);
|
||||
} else {
|
||||
let pad_len = length - str_len;
|
||||
let fill_len = fill.len();
|
||||
let full_reps = pad_len / fill_len;
|
||||
let remainder = pad_len % fill_len;
|
||||
builder.write_str(string)?;
|
||||
for _ in 0..full_reps {
|
||||
builder.write_str(fill)?;
|
||||
}
|
||||
if remainder > 0 {
|
||||
builder.write_str(&fill[..remainder])?;
|
||||
}
|
||||
builder.append_value("");
|
||||
}
|
||||
} else {
|
||||
// Reuse buffer by clearing and refilling
|
||||
graphemes_buf.clear();
|
||||
graphemes_buf.extend(string.graphemes(true));
|
||||
|
||||
if length < graphemes_buf.len() {
|
||||
builder.append_value(
|
||||
graphemes_buf[..length].concat(),
|
||||
);
|
||||
} else if fill.is_empty() {
|
||||
builder.append_value(string);
|
||||
} else {
|
||||
builder.write_str(string)?;
|
||||
// Reuse fill_chars_buf by clearing and refilling
|
||||
fill_chars_buf.clear();
|
||||
fill_chars_buf.extend(fill.chars());
|
||||
for l in 0..length - graphemes_buf.len() {
|
||||
let c = *fill_chars_buf
|
||||
.get(l % fill_chars_buf.len())
|
||||
.unwrap();
|
||||
builder.write_char(c)?;
|
||||
}
|
||||
builder.append_value("");
|
||||
}
|
||||
builder.append_value("");
|
||||
}
|
||||
}
|
||||
_ => builder.append_null(),
|
||||
@@ -459,6 +500,29 @@ mod tests {
|
||||
Utf8,
|
||||
StringArray
|
||||
);
|
||||
test_function!(
|
||||
RPadFunc::new(),
|
||||
vec![
|
||||
ColumnarValue::Scalar(ScalarValue::from("hello")),
|
||||
ColumnarValue::Scalar(ScalarValue::from(2i64)),
|
||||
],
|
||||
Ok(Some("he")),
|
||||
&str,
|
||||
Utf8,
|
||||
StringArray
|
||||
);
|
||||
test_function!(
|
||||
RPadFunc::new(),
|
||||
vec![
|
||||
ColumnarValue::Scalar(ScalarValue::from("hi")),
|
||||
ColumnarValue::Scalar(ScalarValue::from(6i64)),
|
||||
ColumnarValue::Scalar(ScalarValue::from("xy")),
|
||||
],
|
||||
Ok(Some("hixyxy")),
|
||||
&str,
|
||||
Utf8,
|
||||
StringArray
|
||||
);
|
||||
test_function!(
|
||||
RPadFunc::new(),
|
||||
vec![
|
||||
|
||||
@@ -1592,7 +1592,7 @@ lpad(str, n[, padding_str])
|
||||
#### Arguments
|
||||
|
||||
- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
|
||||
- **n**: String length to pad to.
|
||||
- **n**: String length to pad to. If the input string is longer than this length, it is truncated (on the right).
|
||||
- **padding_str**: Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._
|
||||
|
||||
#### Example
|
||||
@@ -1820,7 +1820,7 @@ rpad(str, n[, padding_str])
|
||||
#### Arguments
|
||||
|
||||
- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
|
||||
- **n**: String length to pad to.
|
||||
- **n**: String length to pad to. If the input string is longer than this length, it is truncated.
|
||||
- **padding_str**: String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._
|
||||
|
||||
#### Example
|
||||
|
||||
Reference in New Issue
Block a user