LexRequirement as a struct, instead of a type (#12583)

* Converted LexRequirement into a struct.

* Adjusted the wrapping to return the correct type, since the LexRequirement was not being converted after the merge.

---------

Co-authored-by: nglime <git@ngli.me>
This commit is contained in:
ngli-me
2024-09-23 12:10:19 -05:00
committed by GitHub
parent 04895c451f
commit bd8960eabc
12 changed files with 134 additions and 57 deletions
@@ -53,6 +53,7 @@ use datafusion_physical_expr::{
use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use futures::{future, stream, StreamExt, TryStreamExt};
use itertools::Itertools;
use object_store::ObjectStore;
@@ -987,12 +988,12 @@ impl TableProvider for ListingTable {
))?
.clone();
// Converts Vec<Vec<SortExpr>> into type required by execution plan to specify its required input ordering
Some(
Some(LexRequirement::new(
ordering
.into_iter()
.map(PhysicalSortRequirement::from)
.collect::<Vec<_>>(),
)
))
} else {
None
};
@@ -1272,7 +1272,7 @@ fn ensure_distribution(
// Make sure to satisfy ordering requirement:
child = add_sort_above_with_check(
child,
required_input_ordering.to_vec(),
required_input_ordering.clone(),
None,
);
}
@@ -54,6 +54,7 @@ use datafusion_physical_expr::{
use datafusion_physical_plan::streaming::StreamingTableExec;
use datafusion_physical_plan::union::UnionExec;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use itertools::Itertools;
@@ -334,10 +335,10 @@ fn try_swapping_with_output_req(
return Ok(None);
}
let mut updated_sort_reqs = vec![];
let mut updated_sort_reqs = LexRequirement::new(vec![]);
// None or empty_vec can be treated in the same way.
if let Some(reqs) = &output_req.required_input_ordering()[0] {
for req in reqs {
for req in &reqs.inner {
let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? else {
return Ok(None);
};
@@ -1995,7 +1996,7 @@ mod tests {
let csv = create_simple_csv_exec();
let sort_req: Arc<dyn ExecutionPlan> = Arc::new(OutputRequirementExec::new(
csv.clone(),
Some(vec![
Some(LexRequirement::new(vec![
PhysicalSortRequirement {
expr: Arc::new(Column::new("b", 1)),
options: Some(SortOptions::default()),
@@ -2008,7 +2009,7 @@ mod tests {
)),
options: Some(SortOptions::default()),
},
]),
])),
Distribution::HashPartitioned(vec![
Arc::new(Column::new("a", 0)),
Arc::new(Column::new("b", 1)),
@@ -2041,7 +2042,7 @@ mod tests {
];
assert_eq!(get_plan_string(&after_optimize), expected);
let expected_reqs = vec![
let expected_reqs = LexRequirement::new(vec![
PhysicalSortRequirement {
expr: Arc::new(Column::new("b", 2)),
options: Some(SortOptions::default()),
@@ -2054,7 +2055,7 @@ mod tests {
)),
options: Some(SortOptions::default()),
},
];
]);
assert_eq!(
after_optimize
.as_any()
@@ -173,7 +173,8 @@ fn pushdown_requirement_to_children(
let child_plan = plan.children().swap_remove(0);
match determine_children_requirement(parent_required, request_child, child_plan) {
RequirementsCompatibility::Satisfy => {
let req = (!request_child.is_empty()).then(|| request_child.to_vec());
let req = (!request_child.is_empty())
.then(|| LexRequirement::new(request_child.to_vec()));
Ok(Some(vec![req]))
}
RequirementsCompatibility::Compatible(adjusted) => Ok(Some(vec![adjusted])),
@@ -189,7 +190,9 @@ fn pushdown_requirement_to_children(
.requirements_compatible(parent_required, &sort_req)
{
debug_assert!(!parent_required.is_empty());
Ok(Some(vec![Some(parent_required.to_vec())]))
Ok(Some(vec![Some(LexRequirement::new(
parent_required.to_vec(),
))]))
} else {
Ok(None)
}
@@ -211,7 +214,8 @@ fn pushdown_requirement_to_children(
.eq_properties
.requirements_compatible(parent_required, &output_req)
{
let req = (!parent_required.is_empty()).then(|| parent_required.to_vec());
let req = (!parent_required.is_empty())
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req]))
} else {
Ok(None)
@@ -219,7 +223,8 @@ fn pushdown_requirement_to_children(
} else if is_union(plan) {
// UnionExec does not have real sort requirements for its input. Here we change the adjusted_request_ordering to UnionExec's output ordering and
// propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec
let req = (!parent_required.is_empty()).then(|| parent_required.to_vec());
let req = (!parent_required.is_empty())
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req; plan.children().len()]))
} else if let Some(smj) = plan.as_any().downcast_ref::<SortMergeJoinExec>() {
// If the current plan is SortMergeJoinExec
@@ -277,7 +282,8 @@ fn pushdown_requirement_to_children(
} else {
// Can push-down through SortPreservingMergeExec, because parent requirement is finer
// than SortPreservingMergeExec output ordering.
let req = (!parent_required.is_empty()).then(|| parent_required.to_vec());
let req = (!parent_required.is_empty())
.then(|| LexRequirement::new(parent_required.to_vec()));
Ok(Some(vec![req]))
}
} else {
@@ -331,7 +337,8 @@ fn determine_children_requirement(
{
// Parent requirements are more specific, adjust child's requirements
// and push down the new requirements:
let adjusted = (!parent_required.is_empty()).then(|| parent_required.to_vec());
let adjusted = (!parent_required.is_empty())
.then(|| LexRequirement::new(parent_required.to_vec()));
RequirementsCompatibility::Compatible(adjusted)
} else {
RequirementsCompatibility::NonCompatible
@@ -471,7 +478,7 @@ fn shift_right_required(
})
.collect::<Vec<_>>();
if new_right_required.len() == parent_required.len() {
Ok(new_right_required)
Ok(LexRequirement::new(new_right_required))
} else {
plan_err!(
"Expect to shift all the parent required column indexes for SortMergeJoin"
@@ -574,7 +581,7 @@ fn handle_custom_pushdown(
.iter()
.map(|&maintains_order| {
if maintains_order {
Some(updated_parent_req.clone())
Some(LexRequirement::new(updated_parent_req.clone()))
} else {
None
}
@@ -19,6 +19,7 @@
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;
use crate::physical_expr::PhysicalExpr;
@@ -296,11 +297,13 @@ impl PhysicalSortRequirement {
pub fn from_sort_exprs<'a>(
ordering: impl IntoIterator<Item = &'a PhysicalSortExpr>,
) -> LexRequirement {
ordering
.into_iter()
.cloned()
.map(PhysicalSortRequirement::from)
.collect()
LexRequirement::new(
ordering
.into_iter()
.cloned()
.map(PhysicalSortRequirement::from)
.collect(),
)
}
/// Converts an iterator of [`PhysicalSortRequirement`] into a Vec
@@ -338,9 +341,55 @@ pub type LexOrdering = Vec<PhysicalSortExpr>;
/// a reference to a lexicographical ordering.
pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr];
///`LexRequirement` is an alias for the type `Vec<PhysicalSortRequirement>`, which
///`LexRequirement` is an struct containing a `Vec<PhysicalSortRequirement>`, which
/// represents a lexicographical ordering requirement.
pub type LexRequirement = Vec<PhysicalSortRequirement>;
#[derive(Debug, Default, Clone, PartialEq)]
pub struct LexRequirement {
pub inner: Vec<PhysicalSortRequirement>,
}
impl LexRequirement {
pub fn new(inner: Vec<PhysicalSortRequirement>) -> Self {
Self { inner }
}
pub fn iter(&self) -> impl Iterator<Item = &PhysicalSortRequirement> {
self.inner.iter()
}
pub fn push(&mut self, physical_sort_requirement: PhysicalSortRequirement) {
self.inner.push(physical_sort_requirement)
}
}
impl Deref for LexRequirement {
type Target = [PhysicalSortRequirement];
fn deref(&self) -> &Self::Target {
self.inner.as_slice()
}
}
impl FromIterator<PhysicalSortRequirement> for LexRequirement {
fn from_iter<T: IntoIterator<Item = PhysicalSortRequirement>>(iter: T) -> Self {
let mut lex_requirement = LexRequirement::new(vec![]);
for i in iter {
lex_requirement.inner.push(i);
}
lex_requirement
}
}
impl IntoIterator for LexRequirement {
type Item = PhysicalSortRequirement;
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
self.inner.into_iter()
}
}
///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which
/// represents a reference to a lexicographical ordering requirement.
@@ -418,7 +418,7 @@ impl EquivalenceGroup {
// Normalize the requirements:
let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs);
// Convert sort requirements back to sort expressions:
PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs)
PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs.inner)
}
/// This function applies the `normalize_sort_requirement` function for all
@@ -428,12 +428,12 @@ impl EquivalenceGroup {
&self,
sort_reqs: LexRequirementRef,
) -> LexRequirement {
collapse_lex_req(
collapse_lex_req(LexRequirement::new(
sort_reqs
.iter()
.map(|sort_req| self.normalize_sort_requirement(sort_req.clone()))
.collect(),
)
))
}
/// Projects `expr` according to the given projection mapping.
@@ -48,7 +48,7 @@ pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement {
output.push(item);
}
}
output
LexRequirement::new(output)
}
/// Adds the `offset` value to `Column` indices inside `expr`. This function is
@@ -515,8 +515,9 @@ impl EquivalenceProperties {
) -> Option<LexRequirement> {
let mut lhs = self.normalize_sort_requirements(req1);
let mut rhs = self.normalize_sort_requirements(req2);
lhs.iter_mut()
.zip(rhs.iter_mut())
lhs.inner
.iter_mut()
.zip(rhs.inner.iter_mut())
.all(|(lhs, rhs)| {
lhs.expr.eq(&rhs.expr)
&& match (lhs.options, rhs.options) {
+10 -8
View File
@@ -370,13 +370,15 @@ impl AggregateExec {
// prefix requirements with this section. In this case, aggregation will
// work more efficiently.
let indices = get_ordered_partition_by_indices(&groupby_exprs, &input);
let mut new_requirement = indices
.iter()
.map(|&idx| PhysicalSortRequirement {
expr: Arc::clone(&groupby_exprs[idx]),
options: None,
})
.collect::<Vec<_>>();
let mut new_requirement = LexRequirement::new(
indices
.iter()
.map(|&idx| PhysicalSortRequirement {
expr: Arc::clone(&groupby_exprs[idx]),
options: None,
})
.collect::<Vec<_>>(),
);
let req = get_finer_aggregate_exprs_requirement(
&mut aggr_expr,
@@ -384,7 +386,7 @@ impl AggregateExec {
input_eq_properties,
&mode,
)?;
new_requirement.extend(req);
new_requirement.inner.extend(req);
new_requirement = collapse_lex_req(new_requirement);
// If our aggregation has grouping sets then our base grouping exprs will
+6 -2
View File
@@ -792,7 +792,9 @@ impl SortExec {
) -> PlanProperties {
// Determine execution mode:
let sort_satisfied = input.equivalence_properties().ordering_satisfy_requirement(
PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()).as_slice(),
PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter())
.inner
.as_slice(),
);
let mode = match input.execution_mode() {
ExecutionMode::Unbounded if sort_satisfied => ExecutionMode::Unbounded,
@@ -895,7 +897,9 @@ impl ExecutionPlan for SortExec {
.input
.equivalence_properties()
.ordering_satisfy_requirement(
PhysicalSortRequirement::from_sort_exprs(self.expr.iter()).as_slice(),
PhysicalSortRequirement::from_sort_exprs(self.expr.iter())
.inner
.as_slice(),
);
match (sort_satisfied, self.fetch.as_ref()) {
+20 -10
View File
@@ -399,12 +399,14 @@ pub(crate) fn calc_requirements<
partition_by_exprs: impl IntoIterator<Item = T>,
orderby_sort_exprs: impl IntoIterator<Item = S>,
) -> Option<LexRequirement> {
let mut sort_reqs = partition_by_exprs
.into_iter()
.map(|partition_by| {
PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None)
})
.collect::<Vec<_>>();
let mut sort_reqs = LexRequirement::new(
partition_by_exprs
.into_iter()
.map(|partition_by| {
PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None)
})
.collect::<Vec<_>>(),
);
for element in orderby_sort_exprs.into_iter() {
let PhysicalSortExpr { expr, options } = element.borrow();
if !sort_reqs.iter().any(|e| e.expr.eq(expr)) {
@@ -568,12 +570,18 @@ pub fn get_window_mode(
input: &Arc<dyn ExecutionPlan>,
) -> Option<(bool, InputOrderMode)> {
let input_eqs = input.equivalence_properties().clone();
let mut partition_by_reqs: LexRequirement = vec![];
let mut partition_by_reqs: LexRequirement = LexRequirement::new(vec![]);
let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs);
partition_by_reqs.extend(indices.iter().map(|&idx| PhysicalSortRequirement {
vec![].extend(indices.iter().map(|&idx| PhysicalSortRequirement {
expr: Arc::clone(&partitionby_exprs[idx]),
options: None,
}));
partition_by_reqs
.inner
.extend(indices.iter().map(|&idx| PhysicalSortRequirement {
expr: Arc::clone(&partitionby_exprs[idx]),
options: None,
}));
// Treat partition by exprs as constant. During analysis of requirements are satisfied.
let const_exprs = partitionby_exprs.iter().map(ConstExpr::from);
let partition_by_eqs = input_eqs.add_constants(const_exprs);
@@ -583,7 +591,9 @@ pub fn get_window_mode(
for (should_swap, order_by_reqs) in
[(false, order_by_reqs), (true, reverse_order_by_reqs)]
{
let req = [partition_by_reqs.clone(), order_by_reqs].concat();
let req = LexRequirement::new(
[partition_by_reqs.inner.clone(), order_by_reqs.inner].concat(),
);
let req = collapse_lex_req(req);
if partition_by_eqs.ordering_satisfy_requirement(&req) {
// Window can be run with existing ordering
@@ -736,7 +746,7 @@ mod tests {
if let Some(expected) = &mut expected {
expected.push(res);
} else {
expected = Some(vec![res]);
expected = Some(LexRequirement::new(vec![res]));
}
}
assert_eq!(calc_requirements(partitionbys, orderbys), expected);
@@ -50,7 +50,9 @@ use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
use datafusion::physical_expr::expressions::Literal;
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
use datafusion::physical_expr::{
LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr,
};
use datafusion::physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
@@ -1148,13 +1150,13 @@ fn roundtrip_json_sink() -> Result<()> {
file_sink_config,
JsonWriterOptions::new(CompressionTypeVariant::UNCOMPRESSED),
));
let sort_order = vec![PhysicalSortRequirement::new(
let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new(
Arc::new(Column::new("plan_type", 0)),
Some(SortOptions {
descending: true,
nulls_first: false,
}),
)];
)]);
roundtrip_test(Arc::new(DataSinkExec::new(
input,
@@ -1184,13 +1186,13 @@ fn roundtrip_csv_sink() -> Result<()> {
file_sink_config,
CsvWriterOptions::new(WriterBuilder::default(), CompressionTypeVariant::ZSTD),
));
let sort_order = vec![PhysicalSortRequirement::new(
let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new(
Arc::new(Column::new("plan_type", 0)),
Some(SortOptions {
descending: true,
nulls_first: false,
}),
)];
)]);
let ctx = SessionContext::new();
let codec = DefaultPhysicalExtensionCodec {};
@@ -1243,13 +1245,13 @@ fn roundtrip_parquet_sink() -> Result<()> {
file_sink_config,
TableParquetOptions::default(),
));
let sort_order = vec![PhysicalSortRequirement::new(
let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new(
Arc::new(Column::new("plan_type", 0)),
Some(SortOptions {
descending: true,
nulls_first: false,
}),
)];
)]);
roundtrip_test(Arc::new(DataSinkExec::new(
input,