mirror of
https://github.com/langchain-ai/datafusion.git
synced 2026-07-01 21:24:06 -04:00
fix: HashJoin panic with dictionary-encoded columns in multi-key joins (#20441)
## Which issue does this PR close? - Closes #20437 ## Rationale for this change `flatten_dictionary_array` returned only the unique values rather then the full expanded array when being called on a `DictionaryArray`. When building a `StructArray` this caused a length mismatch panic. ## What changes are included in this PR? Replaced `array.values()` with `arrow::compute::cast(array, value_type)` in `flatten_dictionary_array`, which properly expands the dictionary into a full length array matching the row count. ## Are these changes tested? Yes, both a new unit test aswell as a regression test were added. ## Are there any user-facing changes? Nope --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
This commit is contained in:
@@ -20,8 +20,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::array::{ArrayRef, StructArray};
|
||||
use arrow::compute::cast;
|
||||
use arrow::datatypes::{Field, FieldRef, Fields};
|
||||
use arrow::downcast_dictionary_array;
|
||||
use arrow_schema::DataType;
|
||||
use datafusion_common::Result;
|
||||
|
||||
@@ -33,15 +33,16 @@ pub(super) fn build_struct_fields(data_types: &[DataType]) -> Result<Fields> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Flattens dictionary-encoded arrays to their underlying value arrays.
|
||||
/// Casts dictionary-encoded arrays to their underlying value type, preserving row count.
|
||||
/// Non-dictionary arrays are returned as-is.
|
||||
fn flatten_dictionary_array(array: &ArrayRef) -> ArrayRef {
|
||||
downcast_dictionary_array! {
|
||||
array => {
|
||||
fn flatten_dictionary_array(array: &ArrayRef) -> Result<ArrayRef> {
|
||||
match array.data_type() {
|
||||
DataType::Dictionary(_, value_type) => {
|
||||
let casted = cast(array, value_type)?;
|
||||
// Recursively flatten in case of nested dictionaries
|
||||
flatten_dictionary_array(array.values())
|
||||
flatten_dictionary_array(&casted)
|
||||
}
|
||||
_ => Arc::clone(array)
|
||||
_ => Ok(Arc::clone(array)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +69,7 @@ pub(super) fn build_struct_inlist_values(
|
||||
let flattened_arrays: Vec<ArrayRef> = join_key_arrays
|
||||
.iter()
|
||||
.map(flatten_dictionary_array)
|
||||
.collect();
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
// Build the source array/struct
|
||||
let source_array: ArrayRef = if flattened_arrays.len() == 1 {
|
||||
@@ -99,7 +100,9 @@ pub(super) fn build_struct_inlist_values(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use arrow::array::{Int32Array, StringArray};
|
||||
use arrow::array::{
|
||||
DictionaryArray, Int8Array, Int32Array, StringArray, StringDictionaryBuilder,
|
||||
};
|
||||
use arrow_schema::DataType;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -130,4 +133,41 @@ mod tests {
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_multi_column_inlist_with_dictionary() {
|
||||
let mut builder = StringDictionaryBuilder::<arrow::datatypes::Int8Type>::new();
|
||||
builder.append_value("foo");
|
||||
builder.append_value("foo");
|
||||
builder.append_value("foo");
|
||||
let dict_array = Arc::new(builder.finish()) as ArrayRef;
|
||||
|
||||
let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
|
||||
|
||||
let result = build_struct_inlist_values(&[dict_array, int_array])
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(
|
||||
*result.data_type(),
|
||||
DataType::Struct(
|
||||
build_struct_fields(&[DataType::Utf8, DataType::Int32]).unwrap()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_single_column_dictionary_inlist() {
|
||||
let keys = Int8Array::from(vec![0i8, 0, 0]);
|
||||
let values = Arc::new(StringArray::from(vec!["foo"]));
|
||||
let dict_array = Arc::new(DictionaryArray::new(keys, values)) as ArrayRef;
|
||||
|
||||
let result = build_struct_inlist_values(std::slice::from_ref(&dict_array))
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(*result.data_type(), DataType::Utf8);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5290,3 +5290,31 @@ DROP TABLE empty_proj_left;
|
||||
|
||||
statement count 0
|
||||
DROP TABLE empty_proj_right;
|
||||
|
||||
# Issue #20437: HashJoin panic with dictionary-encoded columns in multi-key joins
|
||||
# https://github.com/apache/datafusion/issues/20437
|
||||
|
||||
statement ok
|
||||
CREATE TABLE issue_20437_small AS
|
||||
SELECT id, arrow_cast(region, 'Dictionary(Int32, Utf8)') AS region
|
||||
FROM (VALUES (1, 'west'), (2, 'west')) AS t(id, region);
|
||||
|
||||
statement ok
|
||||
CREATE TABLE issue_20437_large AS
|
||||
SELECT id, region, value
|
||||
FROM (VALUES (1, 'west', 100), (2, 'west', 200), (3, 'east', 300)) AS t(id, region, value);
|
||||
|
||||
query ITI
|
||||
SELECT s.id, s.region, l.value
|
||||
FROM issue_20437_small s
|
||||
JOIN issue_20437_large l ON s.id = l.id AND s.region = l.region
|
||||
ORDER BY s.id;
|
||||
----
|
||||
1 west 100
|
||||
2 west 200
|
||||
|
||||
statement count 0
|
||||
DROP TABLE issue_20437_small;
|
||||
|
||||
statement count 0
|
||||
DROP TABLE issue_20437_large;
|
||||
|
||||
Reference in New Issue
Block a user