mirror of
https://github.com/langchain-ai/arrow-rs.git
synced 2026-07-01 21:34:01 -04:00
arrow-row: Add support for REE (#7649)
# Which issue does this PR close? Part 2 of https://github.com/apache/datafusion/issues/16011 # Are there any user-facing changes? No user facing changes, just extending functionality of existing APIs to support extracting rows from REE arrays. @alamb
This commit is contained in:
committed by
GitHub
parent
e6c93c02fd
commit
3837ac01dc
@@ -145,9 +145,11 @@ use variable::{decode_binary_view, decode_string_view};
|
||||
|
||||
use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive};
|
||||
use crate::variable::{decode_binary, decode_string};
|
||||
use arrow_array::types::{Int16Type, Int32Type, Int64Type};
|
||||
|
||||
mod fixed;
|
||||
mod list;
|
||||
mod run;
|
||||
mod variable;
|
||||
|
||||
/// Converts [`ArrayRef`] columns into a [row-oriented](self) format.
|
||||
@@ -381,6 +383,8 @@ enum Codec {
|
||||
Struct(RowConverter, OwnedRow),
|
||||
/// A row converter for the child field
|
||||
List(RowConverter),
|
||||
/// A row converter for the values array of a run-end encoded array
|
||||
RunEndEncoded(RowConverter),
|
||||
}
|
||||
|
||||
impl Codec {
|
||||
@@ -400,6 +404,17 @@ impl Codec {
|
||||
};
|
||||
Ok(Self::Dictionary(converter, owned))
|
||||
}
|
||||
DataType::RunEndEncoded(_, values) => {
|
||||
// Similar to List implementation
|
||||
let options = SortOptions {
|
||||
descending: false,
|
||||
nulls_first: sort_field.options.nulls_first != sort_field.options.descending,
|
||||
};
|
||||
|
||||
let field = SortField::new_with_options(values.data_type().clone(), options);
|
||||
let converter = RowConverter::new(vec![field])?;
|
||||
Ok(Self::RunEndEncoded(converter))
|
||||
}
|
||||
d if !d.is_nested() => Ok(Self::Stateless),
|
||||
DataType::List(f) | DataType::LargeList(f) => {
|
||||
// The encoded contents will be inverted if descending is set to true
|
||||
@@ -460,6 +475,19 @@ impl Codec {
|
||||
let rows = converter.convert_columns(&[values.clone()])?;
|
||||
Ok(Encoder::List(rows))
|
||||
}
|
||||
Codec::RunEndEncoded(converter) => {
|
||||
let values = match array.data_type() {
|
||||
DataType::RunEndEncoded(r, _) => match r.data_type() {
|
||||
DataType::Int16 => array.as_run::<Int16Type>().values(),
|
||||
DataType::Int32 => array.as_run::<Int32Type>().values(),
|
||||
DataType::Int64 => array.as_run::<Int64Type>().values(),
|
||||
_ => unreachable!("Unsupported run end index type: {r:?}"),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let rows = converter.convert_columns(&[values.clone()])?;
|
||||
Ok(Encoder::RunEndEncoded(rows))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -469,6 +497,7 @@ impl Codec {
|
||||
Codec::Dictionary(converter, nulls) => converter.size() + nulls.data.len(),
|
||||
Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(),
|
||||
Codec::List(converter) => converter.size(),
|
||||
Codec::RunEndEncoded(converter) => converter.size(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -487,6 +516,8 @@ enum Encoder<'a> {
|
||||
Struct(Rows, Row<'a>),
|
||||
/// The row encoding of the child array
|
||||
List(Rows),
|
||||
/// The row encoding of the values array
|
||||
RunEndEncoded(Rows),
|
||||
}
|
||||
|
||||
/// Configure the data type and sort order for a given column
|
||||
@@ -545,6 +576,7 @@ impl RowConverter {
|
||||
Self::supports_datatype(f.data_type())
|
||||
}
|
||||
DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())),
|
||||
DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@@ -1331,6 +1363,27 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker {
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
Encoder::RunEndEncoded(rows) => match array.data_type() {
|
||||
DataType::RunEndEncoded(r, _) => match r.data_type() {
|
||||
DataType::Int16 => run::compute_lengths(
|
||||
tracker.materialized(),
|
||||
rows,
|
||||
array.as_run::<Int16Type>(),
|
||||
),
|
||||
DataType::Int32 => run::compute_lengths(
|
||||
tracker.materialized(),
|
||||
rows,
|
||||
array.as_run::<Int32Type>(),
|
||||
),
|
||||
DataType::Int64 => run::compute_lengths(
|
||||
tracker.materialized(),
|
||||
rows,
|
||||
array.as_run::<Int64Type>(),
|
||||
),
|
||||
_ => unreachable!("Unsupported run end index type: {r:?}"),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1427,6 +1480,21 @@ fn encode_column(
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
Encoder::RunEndEncoded(rows) => match column.data_type() {
|
||||
DataType::RunEndEncoded(r, _) => match r.data_type() {
|
||||
DataType::Int16 => {
|
||||
run::encode(data, offsets, rows, opts, column.as_run::<Int16Type>())
|
||||
}
|
||||
DataType::Int32 => {
|
||||
run::encode(data, offsets, rows, opts, column.as_run::<Int32Type>())
|
||||
}
|
||||
DataType::Int64 => {
|
||||
run::encode(data, offsets, rows, opts, column.as_run::<Int64Type>())
|
||||
}
|
||||
_ => unreachable!("Unsupported run end index type: {r:?}"),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1512,6 +1580,30 @@ unsafe fn decode_column(
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
Codec::RunEndEncoded(converter) => match &field.data_type {
|
||||
DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
|
||||
DataType::Int16 => Arc::new(run::decode::<Int16Type>(
|
||||
converter,
|
||||
rows,
|
||||
field,
|
||||
validate_utf8,
|
||||
)?),
|
||||
DataType::Int32 => Arc::new(run::decode::<Int32Type>(
|
||||
converter,
|
||||
rows,
|
||||
field,
|
||||
validate_utf8,
|
||||
)?),
|
||||
DataType::Int64 => Arc::new(run::decode::<Int64Type>(
|
||||
converter,
|
||||
rows,
|
||||
field,
|
||||
validate_utf8,
|
||||
)?),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
},
|
||||
};
|
||||
Ok(array)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,695 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
use crate::{variable, RowConverter, Rows, SortField};
|
||||
use arrow_array::types::RunEndIndexType;
|
||||
use arrow_array::{PrimitiveArray, RunArray};
|
||||
use arrow_buffer::{ArrowNativeType, ScalarBuffer};
|
||||
use arrow_schema::{ArrowError, SortOptions};
|
||||
|
||||
/// Computes the lengths of each row for a RunEndEncodedArray
|
||||
pub fn compute_lengths<R: RunEndIndexType>(
|
||||
lengths: &mut [usize],
|
||||
rows: &Rows,
|
||||
array: &RunArray<R>,
|
||||
) {
|
||||
let run_ends = array.run_ends().values();
|
||||
let mut logical_start = 0;
|
||||
|
||||
// Iterate over each run and apply the same length to all logical positions in the run
|
||||
for (physical_idx, &run_end) in run_ends.iter().enumerate() {
|
||||
let logical_end = run_end.as_usize();
|
||||
let row = rows.row(physical_idx);
|
||||
let encoded_len = variable::encoded_len(Some(row.data));
|
||||
|
||||
// Add the same length for all logical positions in this run
|
||||
for length in &mut lengths[logical_start..logical_end] {
|
||||
*length += encoded_len;
|
||||
}
|
||||
|
||||
logical_start = logical_end;
|
||||
}
|
||||
}
|
||||
|
||||
/// Encodes the provided `RunEndEncodedArray` to `out` with the provided `SortOptions`
|
||||
///
|
||||
/// `rows` should contain the encoded values
|
||||
pub fn encode<R: RunEndIndexType>(
|
||||
data: &mut [u8],
|
||||
offsets: &mut [usize],
|
||||
rows: &Rows,
|
||||
opts: SortOptions,
|
||||
array: &RunArray<R>,
|
||||
) {
|
||||
let run_ends = array.run_ends();
|
||||
|
||||
let mut logical_idx = 0;
|
||||
let mut offset_idx = 1; // Skip first offset
|
||||
|
||||
// Iterate over each run
|
||||
for physical_idx in 0..run_ends.values().len() {
|
||||
let run_end = run_ends.values()[physical_idx].as_usize();
|
||||
|
||||
// Process all elements in this run
|
||||
while logical_idx < run_end && offset_idx < offsets.len() {
|
||||
let offset = &mut offsets[offset_idx];
|
||||
let out = &mut data[*offset..];
|
||||
|
||||
// Use variable-length encoding to make the data self-describing
|
||||
let row = rows.row(physical_idx);
|
||||
let bytes_written = variable::encode_one(out, Some(row.data), opts);
|
||||
*offset += bytes_written;
|
||||
|
||||
logical_idx += 1;
|
||||
offset_idx += 1;
|
||||
}
|
||||
|
||||
// Break if we've processed all offsets
|
||||
if offset_idx >= offsets.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decodes a RunEndEncodedArray from `rows` with the provided `options`
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// `rows` must contain valid data for the provided `converter`
|
||||
pub unsafe fn decode<R: RunEndIndexType>(
|
||||
converter: &RowConverter,
|
||||
rows: &mut [&[u8]],
|
||||
field: &SortField,
|
||||
validate_utf8: bool,
|
||||
) -> Result<RunArray<R>, ArrowError> {
|
||||
if rows.is_empty() {
|
||||
let values = converter.convert_raw(&mut [], validate_utf8)?;
|
||||
let run_ends_array = PrimitiveArray::<R>::new(ScalarBuffer::from(vec![]), None);
|
||||
return RunArray::<R>::try_new(&run_ends_array, &values[0]);
|
||||
}
|
||||
|
||||
// Decode each row's REE data and collect the decoded values
|
||||
let mut decoded_values = Vec::new();
|
||||
let mut run_ends = Vec::new();
|
||||
let mut unique_row_indices = Vec::new();
|
||||
|
||||
// Process each row to extract its REE data (following decode_binary pattern)
|
||||
let mut decoded_data = Vec::new();
|
||||
for (idx, row) in rows.iter_mut().enumerate() {
|
||||
decoded_data.clear();
|
||||
// Extract the decoded value data from this row
|
||||
let consumed = variable::decode_blocks(row, field.options, |block| {
|
||||
decoded_data.extend_from_slice(block);
|
||||
});
|
||||
|
||||
// Handle bit inversion for descending sort (following decode_binary pattern)
|
||||
if field.options.descending {
|
||||
decoded_data.iter_mut().for_each(|b| *b = !*b);
|
||||
}
|
||||
|
||||
// Update the row to point past the consumed REE data
|
||||
*row = &row[consumed..];
|
||||
|
||||
// Check if this decoded value is the same as the previous one to identify runs
|
||||
let is_new_run =
|
||||
idx == 0 || decoded_data != decoded_values[*unique_row_indices.last().unwrap()];
|
||||
|
||||
if is_new_run {
|
||||
// This is a new unique value - end the previous run if any
|
||||
if idx > 0 {
|
||||
run_ends.push(R::Native::usize_as(idx));
|
||||
}
|
||||
unique_row_indices.push(decoded_values.len());
|
||||
decoded_values.push(decoded_data.clone());
|
||||
}
|
||||
}
|
||||
// Add the final run end
|
||||
run_ends.push(R::Native::usize_as(rows.len()));
|
||||
|
||||
// Convert the unique decoded values using the row converter
|
||||
let mut unique_rows: Vec<&[u8]> = decoded_values.iter().map(|v| v.as_slice()).collect();
|
||||
let values = if unique_rows.is_empty() {
|
||||
converter.convert_raw(&mut [], validate_utf8)?
|
||||
} else {
|
||||
converter.convert_raw(&mut unique_rows, validate_utf8)?
|
||||
};
|
||||
|
||||
// Create run ends array
|
||||
let run_ends_array = PrimitiveArray::<R>::new(ScalarBuffer::from(run_ends), None);
|
||||
|
||||
// Create the RunEndEncodedArray
|
||||
RunArray::<R>::try_new(&run_ends_array, &values[0])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{RowConverter, SortField};
|
||||
use arrow_array::types::Int32Type;
|
||||
use arrow_array::{Array, Int64Array, RunArray, StringArray};
|
||||
use arrow_schema::{DataType, SortOptions};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_supports_datatype() {
|
||||
// Test that the RowConverter correctly supports run-end encoded arrays
|
||||
assert!(RowConverter::supports_datatype(&DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_round_trip_int64s() {
|
||||
// Test round-trip correctness for RunEndEncodedArray with Int64 values making sure it
|
||||
// doesn't just work with eg. strings (which are all the other tests).
|
||||
|
||||
let values = Int64Array::from(vec![100, 200, 100, 300]);
|
||||
let run_ends = vec![2, 3, 5, 6];
|
||||
let array: RunArray<Int32Type> =
|
||||
RunArray::try_new(&arrow_array::PrimitiveArray::from(run_ends), &values).unwrap();
|
||||
|
||||
let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)),
|
||||
))])
|
||||
.unwrap();
|
||||
|
||||
let rows = converter
|
||||
.convert_columns(&[Arc::new(array.clone())])
|
||||
.unwrap();
|
||||
|
||||
let arrays = converter.convert_rows(&rows).unwrap();
|
||||
let result = arrays[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(array.run_ends().values(), result.run_ends().values());
|
||||
assert_eq!(array.values().as_ref(), result.values().as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_round_trip_strings() {
|
||||
// Test round-trip correctness for RunEndEncodedArray with strings
|
||||
|
||||
let array: RunArray<Int32Type> = vec!["b", "b", "a"].into_iter().collect();
|
||||
|
||||
let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
))])
|
||||
.unwrap();
|
||||
|
||||
let rows = converter
|
||||
.convert_columns(&[Arc::new(array.clone())])
|
||||
.unwrap();
|
||||
|
||||
let arrays = converter.convert_rows(&rows).unwrap();
|
||||
let result = arrays[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(array.run_ends().values(), result.run_ends().values());
|
||||
assert_eq!(array.values().as_ref(), result.values().as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_round_trip_strings_with_nulls() {
|
||||
// Test round-trip correctness for RunEndEncodedArray with nulls
|
||||
|
||||
let array: RunArray<Int32Type> = vec![Some("b"), Some("b"), None, Some("a")]
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
))])
|
||||
.unwrap();
|
||||
|
||||
let rows = converter
|
||||
.convert_columns(&[Arc::new(array.clone())])
|
||||
.unwrap();
|
||||
|
||||
let arrays = converter.convert_rows(&rows).unwrap();
|
||||
let result = arrays[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(array.run_ends().values(), result.run_ends().values());
|
||||
assert_eq!(array.values().as_ref(), result.values().as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_ascending_descending_round_trip() {
|
||||
// Test round-trip correctness for ascending vs descending sort options
|
||||
|
||||
let values_asc =
|
||||
arrow_array::StringArray::from(vec![Some("apple"), Some("banana"), Some("cherry")]);
|
||||
let run_ends_asc = vec![2, 4, 6];
|
||||
let run_array_asc: RunArray<Int32Type> = RunArray::try_new(
|
||||
&arrow_array::PrimitiveArray::from(run_ends_asc),
|
||||
&values_asc,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Test ascending order
|
||||
let converter_asc = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
SortOptions {
|
||||
descending: false,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let rows_asc = converter_asc
|
||||
.convert_columns(&[Arc::new(run_array_asc.clone())])
|
||||
.unwrap();
|
||||
let arrays_asc = converter_asc.convert_rows(&rows_asc).unwrap();
|
||||
let result_asc = arrays_asc[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
// Verify round-trip correctness for ascending
|
||||
assert_eq!(run_array_asc.len(), result_asc.len());
|
||||
for i in 0..run_array_asc.len() {
|
||||
let orig_physical = run_array_asc.get_physical_index(i);
|
||||
let result_physical = result_asc.get_physical_index(i);
|
||||
|
||||
let orig_values = run_array_asc
|
||||
.values()
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::StringArray>()
|
||||
.unwrap();
|
||||
let result_values = result_asc
|
||||
.values()
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::StringArray>()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
orig_values.value(orig_physical),
|
||||
result_values.value(result_physical),
|
||||
"Ascending sort value mismatch at index {}",
|
||||
i
|
||||
);
|
||||
}
|
||||
|
||||
// Test descending order
|
||||
let converter_desc = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
SortOptions {
|
||||
descending: true,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let rows_desc = converter_desc
|
||||
.convert_columns(&[Arc::new(run_array_asc.clone())])
|
||||
.unwrap();
|
||||
let arrays_desc = converter_desc.convert_rows(&rows_desc).unwrap();
|
||||
let result_desc = arrays_desc[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
// Verify round-trip correctness for descending
|
||||
assert_eq!(run_array_asc.len(), result_desc.len());
|
||||
for i in 0..run_array_asc.len() {
|
||||
let orig_physical = run_array_asc.get_physical_index(i);
|
||||
let result_physical = result_desc.get_physical_index(i);
|
||||
|
||||
let orig_values = run_array_asc
|
||||
.values()
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::StringArray>()
|
||||
.unwrap();
|
||||
let result_values = result_desc
|
||||
.values()
|
||||
.as_any()
|
||||
.downcast_ref::<arrow_array::StringArray>()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
orig_values.value(orig_physical),
|
||||
result_values.value(result_physical),
|
||||
"Descending sort value mismatch at index {}",
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_sort_configurations_basic() {
|
||||
// Test that different sort configurations work and can round-trip successfully
|
||||
|
||||
let test_array: RunArray<Int32Type> = vec!["test"].into_iter().collect();
|
||||
|
||||
let converter_asc = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
SortOptions {
|
||||
descending: false,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let converter_desc = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
SortOptions {
|
||||
descending: true,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let rows_test_asc = converter_asc
|
||||
.convert_columns(&[Arc::new(test_array.clone())])
|
||||
.unwrap();
|
||||
let rows_test_desc = converter_desc
|
||||
.convert_columns(&[Arc::new(test_array.clone())])
|
||||
.unwrap();
|
||||
|
||||
// Convert back to verify both configurations work
|
||||
let result_test_asc = converter_asc.convert_rows(&rows_test_asc).unwrap();
|
||||
let result_test_desc = converter_desc.convert_rows(&rows_test_desc).unwrap();
|
||||
|
||||
// Both should successfully reconstruct the original
|
||||
assert_eq!(result_test_asc.len(), 1);
|
||||
assert_eq!(result_test_desc.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_nulls_first_last_configurations() {
|
||||
// Test that nulls_first vs nulls_last configurations work
|
||||
|
||||
let simple_array: RunArray<Int32Type> = vec!["simple"].into_iter().collect();
|
||||
|
||||
let converter_nulls_first = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
SortOptions {
|
||||
descending: false,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let converter_nulls_last = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
SortOptions {
|
||||
descending: false,
|
||||
nulls_first: false,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
// Test that both configurations can handle simple arrays
|
||||
let rows_nulls_first = converter_nulls_first
|
||||
.convert_columns(&[Arc::new(simple_array.clone())])
|
||||
.unwrap();
|
||||
let arrays_nulls_first = converter_nulls_first
|
||||
.convert_rows(&rows_nulls_first)
|
||||
.unwrap();
|
||||
let result_nulls_first = arrays_nulls_first[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
let rows_nulls_last = converter_nulls_last
|
||||
.convert_columns(&[Arc::new(simple_array.clone())])
|
||||
.unwrap();
|
||||
let arrays_nulls_last = converter_nulls_last.convert_rows(&rows_nulls_last).unwrap();
|
||||
let result_nulls_last = arrays_nulls_last[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
// Both should successfully convert the simple array
|
||||
assert_eq!(simple_array.len(), result_nulls_first.len());
|
||||
assert_eq!(simple_array.len(), result_nulls_last.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_row_consumption() {
|
||||
// This test verifies that ALL rows are properly consumed during decoding,
|
||||
// not just the unique values. We test this by ensuring multi-column conversion
|
||||
// works correctly - if rows aren't consumed properly, the second column would fail.
|
||||
|
||||
// Create a REE array with multiple runs
|
||||
let array: RunArray<Int32Type> = vec!["a", "a", "b", "b", "b", "c"].into_iter().collect();
|
||||
let string_array = StringArray::from(vec!["x", "y", "z", "w", "u", "v"]);
|
||||
|
||||
let multi_converter = RowConverter::new(vec![
|
||||
SortField::new(DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
)),
|
||||
SortField::new(DataType::Utf8),
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let multi_rows = multi_converter
|
||||
.convert_columns(&[Arc::new(array.clone()), Arc::new(string_array.clone())])
|
||||
.unwrap();
|
||||
|
||||
// Convert back - this will test that all rows are consumed properly
|
||||
let arrays = multi_converter.convert_rows(&multi_rows).unwrap();
|
||||
|
||||
// Verify both columns round-trip correctly
|
||||
let result_ree = arrays[0]
|
||||
.as_any()
|
||||
.downcast_ref::<RunArray<Int32Type>>()
|
||||
.unwrap();
|
||||
|
||||
let result_string = arrays[1].as_any().downcast_ref::<StringArray>().unwrap();
|
||||
|
||||
// This should pass - both arrays should be identical to originals
|
||||
assert_eq!(result_ree.values().as_ref(), array.values().as_ref());
|
||||
assert_eq!(result_ree.run_ends().values(), array.run_ends().values());
|
||||
assert_eq!(*result_string, string_array);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_sorting_behavior() {
|
||||
// Test that the binary row encoding actually produces the correct sort order
|
||||
|
||||
// Create REE arrays with different values to test sorting
|
||||
let array1: RunArray<Int32Type> = vec!["apple", "apple"].into_iter().collect();
|
||||
let array2: RunArray<Int32Type> = vec!["banana", "banana"].into_iter().collect();
|
||||
let array3: RunArray<Int32Type> = vec!["cherry", "cherry"].into_iter().collect();
|
||||
|
||||
// Test ascending sort
|
||||
let converter_asc = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
))])
|
||||
.unwrap();
|
||||
|
||||
let rows1_asc = converter_asc
|
||||
.convert_columns(&[Arc::new(array1.clone())])
|
||||
.unwrap();
|
||||
let rows2_asc = converter_asc
|
||||
.convert_columns(&[Arc::new(array2.clone())])
|
||||
.unwrap();
|
||||
let rows3_asc = converter_asc
|
||||
.convert_columns(&[Arc::new(array3.clone())])
|
||||
.unwrap();
|
||||
|
||||
// For ascending: apple < banana < cherry
|
||||
// So row bytes should sort: rows1 < rows2 < rows3
|
||||
assert!(
|
||||
rows1_asc.row(0) < rows2_asc.row(0),
|
||||
"apple should come before banana in ascending order"
|
||||
);
|
||||
assert!(
|
||||
rows2_asc.row(0) < rows3_asc.row(0),
|
||||
"banana should come before cherry in ascending order"
|
||||
);
|
||||
assert!(
|
||||
rows1_asc.row(0) < rows3_asc.row(0),
|
||||
"apple should come before cherry in ascending order"
|
||||
);
|
||||
|
||||
// Test descending sort
|
||||
let converter_desc = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
arrow_schema::SortOptions {
|
||||
descending: true,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let rows1_desc = converter_desc
|
||||
.convert_columns(&[Arc::new(array1.clone())])
|
||||
.unwrap();
|
||||
let rows2_desc = converter_desc
|
||||
.convert_columns(&[Arc::new(array2.clone())])
|
||||
.unwrap();
|
||||
let rows3_desc = converter_desc
|
||||
.convert_columns(&[Arc::new(array3.clone())])
|
||||
.unwrap();
|
||||
|
||||
// For descending: cherry > banana > apple
|
||||
// So row bytes should sort: rows3 < rows2 < rows1 (because byte comparison is ascending)
|
||||
assert!(
|
||||
rows3_desc.row(0) < rows2_desc.row(0),
|
||||
"cherry should come before banana in descending order (byte-wise)"
|
||||
);
|
||||
assert!(
|
||||
rows2_desc.row(0) < rows1_desc.row(0),
|
||||
"banana should come before apple in descending order (byte-wise)"
|
||||
);
|
||||
assert!(
|
||||
rows3_desc.row(0) < rows1_desc.row(0),
|
||||
"cherry should come before apple in descending order (byte-wise)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_null_sorting() {
|
||||
// Test null handling in sort order
|
||||
|
||||
let array_with_nulls: RunArray<Int32Type> = vec![None, None].into_iter().collect();
|
||||
let array_with_values: RunArray<Int32Type> = vec!["apple", "apple"].into_iter().collect();
|
||||
|
||||
// Test nulls_first = true
|
||||
let converter_nulls_first = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
arrow_schema::SortOptions {
|
||||
descending: false,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let rows_nulls = converter_nulls_first
|
||||
.convert_columns(&[Arc::new(array_with_nulls.clone())])
|
||||
.unwrap();
|
||||
let rows_values = converter_nulls_first
|
||||
.convert_columns(&[Arc::new(array_with_values.clone())])
|
||||
.unwrap();
|
||||
|
||||
// nulls should come before values when nulls_first = true
|
||||
assert!(
|
||||
rows_nulls.row(0) < rows_values.row(0),
|
||||
"nulls should come before values when nulls_first=true"
|
||||
);
|
||||
|
||||
// Test nulls_first = false
|
||||
let converter_nulls_last = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
arrow_schema::SortOptions {
|
||||
descending: false,
|
||||
nulls_first: false,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let rows_nulls_last = converter_nulls_last
|
||||
.convert_columns(&[Arc::new(array_with_nulls.clone())])
|
||||
.unwrap();
|
||||
let rows_values_last = converter_nulls_last
|
||||
.convert_columns(&[Arc::new(array_with_values.clone())])
|
||||
.unwrap();
|
||||
|
||||
// values should come before nulls when nulls_first = false
|
||||
assert!(
|
||||
rows_values_last.row(0) < rows_nulls_last.row(0),
|
||||
"values should come before nulls when nulls_first=false"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_run_end_encoded_mixed_sorting() {
|
||||
// Test sorting with mixed values and nulls to ensure complex scenarios work
|
||||
|
||||
let array1: RunArray<Int32Type> = vec![Some("apple"), None].into_iter().collect();
|
||||
let array2: RunArray<Int32Type> = vec![None, Some("banana")].into_iter().collect();
|
||||
let array3: RunArray<Int32Type> =
|
||||
vec![Some("cherry"), Some("cherry")].into_iter().collect();
|
||||
|
||||
let converter = RowConverter::new(vec![SortField::new_with_options(
|
||||
DataType::RunEndEncoded(
|
||||
Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)),
|
||||
Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)),
|
||||
),
|
||||
arrow_schema::SortOptions {
|
||||
descending: false,
|
||||
nulls_first: true,
|
||||
},
|
||||
)])
|
||||
.unwrap();
|
||||
|
||||
let rows1 = converter.convert_columns(&[Arc::new(array1)]).unwrap();
|
||||
let rows2 = converter.convert_columns(&[Arc::new(array2)]).unwrap();
|
||||
let rows3 = converter.convert_columns(&[Arc::new(array3)]).unwrap();
|
||||
|
||||
// With nulls_first=true, ascending:
|
||||
// Row 0: array1[0]="apple", array2[0]=null, array3[0]="cherry" -> null < apple < cherry
|
||||
// Row 1: array1[1]=null, array2[1]="banana", array3[1]="cherry" -> null < banana < cherry
|
||||
|
||||
// Compare first rows: null < apple < cherry
|
||||
assert!(rows2.row(0) < rows1.row(0), "null should come before apple");
|
||||
assert!(
|
||||
rows1.row(0) < rows3.row(0),
|
||||
"apple should come before cherry"
|
||||
);
|
||||
|
||||
// Compare second rows: null < banana < cherry
|
||||
assert!(
|
||||
rows1.row(1) < rows2.row(1),
|
||||
"null should come before banana"
|
||||
);
|
||||
assert!(
|
||||
rows2.row(1) < rows3.row(1),
|
||||
"banana should come before cherry"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -592,6 +592,7 @@ impl DataType {
|
||||
use DataType::*;
|
||||
match self {
|
||||
Dictionary(_, v) => DataType::is_nested(v.as_ref()),
|
||||
RunEndEncoded(_, v) => DataType::is_nested(v.data_type()),
|
||||
List(_)
|
||||
| FixedSizeList(_, _)
|
||||
| LargeList(_)
|
||||
|
||||
Reference in New Issue
Block a user