fix: Fix and Refactor Spark shuffle function (#20484)

## Which issue does this PR close?
- Closes #20483.

## Rationale for this change
Currently, the Spark `shuffle` function returns the following error message
when `seed` is `null`. The message should report `NULL` instead of
`'Int64'`.

**Current:**
```
query error
SELECT shuffle([2, 1], NULL);
----
DataFusion error: Execution error: shuffle seed must be Int64 type, got 'Int64'
```

**New:**
```
query error DataFusion error: Execution error: shuffle seed must be Int64 type but got 'NULL'
SELECT shuffle([1, 2, 3], NULL);
```

In addition to this fix, this PR also introduces the following refactoring
to the `shuffle` function:
- Combining the argument validation checks into a single error message,
- Extending the current error message with the expected data types:
```
Current:
shuffle does not support type '{array_type}'.

New:
shuffle does not support type '{array_type}'; expected types: List, LargeList, FixedSizeList or Null.
```
- Adding new unit-test coverage in both `shuffle.rs` and `shuffle.slt`.
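
The two refactorings above can be illustrated with a minimal, standalone sketch. It assumes plain `String` errors in place of DataFusion's `exec_err!` macro, and the helper names `validate_arg_count` and `unsupported_type_message` are hypothetical (they do not appear in this PR); only the message texts are taken from the change itself.

```rust
/// Hypothetical stand-in for the merged argument-count check.
/// Returns the error text as a `String` instead of a DataFusion error.
fn validate_arg_count(n: usize) -> Result<(), String> {
    // One condition and one message cover both the "too few" and
    // "too many" cases that were previously checked separately.
    if n == 0 || n > 2 {
        return Err("shuffle expects 1 or 2 argument(s)".to_string());
    }
    Ok(())
}

/// Hypothetical helper that builds the extended unsupported-type message.
fn unsupported_type_message(array_type: &str) -> String {
    // The trailing backslash joins the two string fragments and skips
    // the leading whitespace of the continuation line.
    format!(
        "shuffle does not support type '{array_type}'; \
         expected types: List, LargeList, FixedSizeList or Null."
    )
}
```

Under these assumptions, `validate_arg_count(1)` and `validate_arg_count(2)` succeed, while `0` and `3` both yield the same message, and `unsupported_type_message("Int32")` names the rejected type followed by the accepted ones.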

## What changes are included in this PR?


## Are these changes tested?
Yes, new unit-test cases are added.

## Are there any user-facing changes?
Yes, the error messages of the Spark `shuffle` function are updated.
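
As a rough illustration of the seed handling in this commit's diff, the sketch below mirrors the `extract_seed` match arms using a simplified, hypothetical `Scalar` enum in place of DataFusion's `ScalarValue`, and `String` errors in place of `exec_err!`. The `type_name` helper and the `Utf8` variant are assumptions added only to exercise the error branch.

```rust
/// Hypothetical, simplified stand-in for DataFusion's `ScalarValue`.
#[derive(Debug)]
enum Scalar {
    Null,
    Int64(Option<i64>),
    Utf8(Option<String>),
}

impl Scalar {
    /// Illustrative type name used when formatting the error message.
    fn type_name(&self) -> &'static str {
        match self {
            Scalar::Null => "Null",
            Scalar::Int64(_) => "Int64",
            Scalar::Utf8(_) => "Utf8",
        }
    }
}

/// Mirrors the match arms in the diff: a concrete Int64 becomes the seed,
/// an untyped NULL or a typed Int64 null means "no seed", and any other
/// type is rejected with the reworded message.
fn extract_seed(scalar: &Scalar) -> Result<Option<u64>, String> {
    match scalar {
        Scalar::Int64(Some(v)) => Ok(Some(*v as u64)),
        Scalar::Null | Scalar::Int64(None) => Ok(None),
        other => Err(format!(
            "shuffle seed must be Int64 type but got '{}'",
            other.type_name()
        )),
    }
}
```

In this sketch, `Scalar::Null` and `Scalar::Int64(None)` both map to `Ok(None)`, so a typed null no longer falls through to the error arm.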
Eren Avsarogullari
2026-02-27 16:05:26 -08:00
committed by GitHub
parent e567cb91f4
commit 5d8249ff16
2 changed files with 18 additions and 8 deletions
@@ -105,11 +105,8 @@ impl ScalarUDFImpl for SparkShuffle {
         &self,
         args: datafusion_expr::ScalarFunctionArgs,
     ) -> Result<ColumnarValue> {
-        if args.args.is_empty() {
-            return exec_err!("shuffle expects at least 1 argument");
-        }
-        if args.args.len() > 2 {
-            return exec_err!("shuffle expects at most 2 arguments");
+        if args.args.is_empty() || args.args.len() > 2 {
+            return exec_err!("shuffle expects 1 or 2 argument(s)");
         }

         // Extract seed from second argument if present
@@ -131,10 +128,10 @@ fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
         ColumnarValue::Scalar(scalar) => {
             let seed = match scalar {
                 ScalarValue::Int64(Some(v)) => Some(*v as u64),
-                ScalarValue::Null => None,
+                ScalarValue::Null | ScalarValue::Int64(None) => None,
                 _ => {
                     return exec_err!(
-                        "shuffle seed must be Int64 type, got '{}'",
+                        "shuffle seed must be Int64 type but got '{}'",
                         scalar.data_type()
                     );
                 }
@@ -164,7 +161,10 @@ fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayR
             fixed_size_array_shuffle(array, field, seed)
         }
         Null => Ok(Arc::clone(input_array)),
-        array_type => exec_err!("shuffle does not support type '{array_type}'."),
+        array_type => exec_err!(
+            "shuffle does not support type '{array_type}'; \
+             expected types: List, LargeList, FixedSizeList or Null."
+        ),
     }
 }
@@ -107,6 +107,16 @@ SELECT shuffle([1, 2, 3, 4], CAST('2' AS INT));
 ----
 [1, 4, 2, 3]

+query ?
+SELECT shuffle(['ab'], NULL);
+----
+[ab]
+
+query ?
+SELECT shuffle(shuffle([3, 3], NULL), NULL);
+----
+[3, 3]
+
 # Clean up
 statement ok
 DROP TABLE test_shuffle_list_types;