mirror of
https://github.com/langchain-ai/datafusion.git
synced 2026-07-01 21:24:06 -04:00
fix(substrait): Correctly parse field references in subqueries (#20439)
## Which issue does this PR close? - Closes #20438. ## Rationale for this change The substrait consumer parsed field references in correlated subqueries incorrectly. Field references were always resolved relative to the schema of the current (innermost) subquery, leading to incorrect results. ## What changes are included in this PR? We now maintain a stack of outer query schemas, and pushes/pops elements from it as we traverse subqueries. When resolving field references, we now use `FieldReference.root_type` to detect outer query field references and resolve them against the appropriate schema. This commit updates the expected results for parsing TPC-H queries, because several of them were parsed incorrectly (the misparsing was probably not detected because the incorrect parse didn't result in any illegal queries, by sheer luck). This also means we can enable Q17, which failed to parse before. ## Are these changes tested? Yes. Test results updated to reflect new, correct behavior, and new unit tests added. ## Are there any user-facing changes? The behavior of the substrait consumer has changed, although the previous behavior was wrong and it seems a bit unlikely anyone would have dependend on it. The `DefaultSubstraitConsumer` API is slightly changed (new private field).
This commit is contained in:
@@ -16,34 +16,48 @@
|
||||
// under the License.
|
||||
|
||||
use crate::logical_plan::consumer::SubstraitConsumer;
|
||||
use datafusion::common::{Column, DFSchema, not_impl_err};
|
||||
use datafusion::common::{Column, DFSchema, not_impl_err, substrait_err};
|
||||
use datafusion::logical_expr::Expr;
|
||||
use std::sync::Arc;
|
||||
use substrait::proto::expression::FieldReference;
|
||||
use substrait::proto::expression::field_reference::ReferenceType::DirectReference;
|
||||
use substrait::proto::expression::field_reference::RootType;
|
||||
use substrait::proto::expression::reference_segment::ReferenceType::StructField;
|
||||
|
||||
pub async fn from_field_reference(
|
||||
_consumer: &impl SubstraitConsumer,
|
||||
consumer: &impl SubstraitConsumer,
|
||||
field_ref: &FieldReference,
|
||||
input_schema: &DFSchema,
|
||||
) -> datafusion::common::Result<Expr> {
|
||||
from_substrait_field_reference(field_ref, input_schema)
|
||||
from_substrait_field_reference(consumer, field_ref, input_schema)
|
||||
}
|
||||
|
||||
pub(crate) fn from_substrait_field_reference(
|
||||
consumer: &impl SubstraitConsumer,
|
||||
field_ref: &FieldReference,
|
||||
input_schema: &DFSchema,
|
||||
) -> datafusion::common::Result<Expr> {
|
||||
match &field_ref.reference_type {
|
||||
Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
|
||||
Some(StructField(x)) => match &x.child.as_ref() {
|
||||
Some(_) => not_impl_err!(
|
||||
"Direct reference StructField with child is not supported"
|
||||
),
|
||||
None => Ok(Expr::Column(Column::from(
|
||||
input_schema.qualified_field(x.field as usize),
|
||||
))),
|
||||
},
|
||||
Some(StructField(struct_field)) => {
|
||||
if struct_field.child.is_some() {
|
||||
return not_impl_err!(
|
||||
"Direct reference StructField with child is not supported"
|
||||
);
|
||||
}
|
||||
let field_idx = struct_field.field as usize;
|
||||
match &field_ref.root_type {
|
||||
Some(RootType::RootReference(_)) | None => Ok(Expr::Column(
|
||||
Column::from(input_schema.qualified_field(field_idx)),
|
||||
)),
|
||||
Some(RootType::OuterReference(outer_ref)) => {
|
||||
resolve_outer_reference(consumer, outer_ref, field_idx)
|
||||
}
|
||||
Some(RootType::Expression(_)) => not_impl_err!(
|
||||
"Expression root type in field reference is not supported"
|
||||
),
|
||||
}
|
||||
}
|
||||
_ => not_impl_err!(
|
||||
"Direct reference with types other than StructField is not supported"
|
||||
),
|
||||
@@ -51,3 +65,20 @@ pub(crate) fn from_substrait_field_reference(
|
||||
_ => not_impl_err!("unsupported field ref type"),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_outer_reference(
|
||||
consumer: &impl SubstraitConsumer,
|
||||
outer_ref: &substrait::proto::expression::field_reference::OuterReference,
|
||||
field_idx: usize,
|
||||
) -> datafusion::common::Result<Expr> {
|
||||
let steps_out = outer_ref.steps_out as usize;
|
||||
let Some(outer_schema) = consumer.get_outer_schema(steps_out) else {
|
||||
return substrait_err!(
|
||||
"OuterReference with steps_out={steps_out} \
|
||||
but no outer schema is available"
|
||||
);
|
||||
};
|
||||
let (qualifier, field) = outer_schema.qualified_field(field_idx);
|
||||
let col = Column::from((qualifier, field));
|
||||
Ok(Expr::OuterReferenceColumn(Arc::clone(field), col))
|
||||
}
|
||||
|
||||
@@ -117,10 +117,7 @@ pub async fn from_substrait_extended_expr(
|
||||
return not_impl_err!("Type variation extensions are not supported");
|
||||
}
|
||||
|
||||
let consumer = DefaultSubstraitConsumer {
|
||||
extensions: &extensions,
|
||||
state,
|
||||
};
|
||||
let consumer = DefaultSubstraitConsumer::new(&extensions, state);
|
||||
|
||||
let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
|
||||
Some(base_schema) => from_substrait_named_struct(&consumer, base_schema),
|
||||
|
||||
@@ -18,13 +18,31 @@
|
||||
use crate::logical_plan::consumer::SubstraitConsumer;
|
||||
use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err};
|
||||
use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier};
|
||||
use datafusion::logical_expr::{Expr, Operator, Subquery};
|
||||
use datafusion::logical_expr::{Expr, LogicalPlan, Operator, Subquery};
|
||||
use std::sync::Arc;
|
||||
use substrait::proto::Rel;
|
||||
use substrait::proto::expression as substrait_expression;
|
||||
use substrait::proto::expression::subquery::SubqueryType;
|
||||
use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp};
|
||||
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
|
||||
|
||||
/// Consume a subquery relation, making the enclosing query's schema
|
||||
/// available for resolving correlated column references.
|
||||
///
|
||||
/// Substrait represents correlated references using `OuterReference`
|
||||
/// field references with a `steps_out` depth. To resolve these,
|
||||
/// the consumer maintains a stack of outer schemas.
|
||||
async fn consume_subquery_rel(
|
||||
consumer: &impl SubstraitConsumer,
|
||||
rel: &Rel,
|
||||
outer_schema: &DFSchema,
|
||||
) -> datafusion::common::Result<LogicalPlan> {
|
||||
consumer.push_outer_schema(Arc::new(outer_schema.clone()));
|
||||
let result = consumer.consume_rel(rel).await;
|
||||
consumer.pop_outer_schema();
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn from_subquery(
|
||||
consumer: &impl SubstraitConsumer,
|
||||
subquery: &substrait_expression::Subquery,
|
||||
@@ -41,7 +59,9 @@ pub async fn from_subquery(
|
||||
let needle_expr = &in_predicate.needles[0];
|
||||
let haystack_expr = &in_predicate.haystack;
|
||||
if let Some(haystack_expr) = haystack_expr {
|
||||
let haystack_expr = consumer.consume_rel(haystack_expr).await?;
|
||||
let haystack_expr =
|
||||
consume_subquery_rel(consumer, haystack_expr, input_schema)
|
||||
.await?;
|
||||
let outer_refs = haystack_expr.all_out_ref_exprs();
|
||||
Ok(Expr::InSubquery(InSubquery {
|
||||
expr: Box::new(
|
||||
@@ -64,9 +84,12 @@ pub async fn from_subquery(
|
||||
}
|
||||
}
|
||||
SubqueryType::Scalar(query) => {
|
||||
let plan = consumer
|
||||
.consume_rel(&(query.input.clone()).unwrap_or_default())
|
||||
.await?;
|
||||
let plan = consume_subquery_rel(
|
||||
consumer,
|
||||
&(query.input.clone()).unwrap_or_default(),
|
||||
input_schema,
|
||||
)
|
||||
.await?;
|
||||
let outer_ref_columns = plan.all_out_ref_exprs();
|
||||
Ok(Expr::ScalarSubquery(Subquery {
|
||||
subquery: Arc::new(plan),
|
||||
@@ -79,9 +102,12 @@ pub async fn from_subquery(
|
||||
// exist
|
||||
PredicateOp::Exists => {
|
||||
let relation = &predicate.tuples;
|
||||
let plan = consumer
|
||||
.consume_rel(&relation.clone().unwrap_or_default())
|
||||
.await?;
|
||||
let plan = consume_subquery_rel(
|
||||
consumer,
|
||||
&relation.clone().unwrap_or_default(),
|
||||
input_schema,
|
||||
)
|
||||
.await?;
|
||||
let outer_ref_columns = plan.all_out_ref_exprs();
|
||||
Ok(Expr::Exists(Exists::new(
|
||||
Subquery {
|
||||
@@ -131,7 +157,7 @@ pub async fn from_subquery(
|
||||
};
|
||||
|
||||
let left_expr = consumer.consume_expression(left, input_schema).await?;
|
||||
let plan = consumer.consume_rel(right).await?;
|
||||
let plan = consume_subquery_rel(consumer, right, input_schema).await?;
|
||||
let outer_ref_columns = plan.all_out_ref_exprs();
|
||||
|
||||
Ok(Expr::SetComparison(SetComparison::new(
|
||||
|
||||
@@ -35,10 +35,7 @@ pub async fn from_substrait_plan(
|
||||
return not_impl_err!("Type variation extensions are not supported");
|
||||
}
|
||||
|
||||
let consumer = DefaultSubstraitConsumer {
|
||||
extensions: &extensions,
|
||||
state,
|
||||
};
|
||||
let consumer = DefaultSubstraitConsumer::new(&extensions, state);
|
||||
from_substrait_plan_with_consumer(&consumer, plan).await
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,8 @@ pub async fn from_exchange_rel(
|
||||
let mut partition_columns = vec![];
|
||||
let input_schema = input.schema();
|
||||
for field_ref in &scatter_fields.fields {
|
||||
let column = from_substrait_field_reference(field_ref, input_schema)?;
|
||||
let column =
|
||||
from_substrait_field_reference(consumer, field_ref, input_schema)?;
|
||||
partition_columns.push(column);
|
||||
}
|
||||
Partitioning::Hash(partition_columns, exchange.partition_count as usize)
|
||||
|
||||
@@ -31,7 +31,7 @@ use datafusion::common::{
|
||||
};
|
||||
use datafusion::execution::{FunctionRegistry, SessionState};
|
||||
use datafusion::logical_expr::{Expr, Extension, LogicalPlan};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use substrait::proto;
|
||||
use substrait::proto::expression as substrait_expression;
|
||||
use substrait::proto::expression::{
|
||||
@@ -364,6 +364,26 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
|
||||
not_impl_err!("Dynamic Parameter expression not supported")
|
||||
}
|
||||
|
||||
// Outer Schema Stack
|
||||
// These methods manage a stack of outer schemas for correlated subquery support.
|
||||
// When entering a subquery, the enclosing query's schema is pushed onto the stack.
|
||||
// Field references with OuterReference root_type use these to resolve columns.
|
||||
|
||||
/// Push an outer schema onto the stack when entering a subquery.
|
||||
fn push_outer_schema(&self, _schema: Arc<DFSchema>) {}
|
||||
|
||||
/// Pop an outer schema from the stack when leaving a subquery.
|
||||
fn pop_outer_schema(&self) {}
|
||||
|
||||
/// Get the outer schema at the given nesting depth.
|
||||
/// `steps_out = 1` is the immediately enclosing query, `steps_out = 2`
|
||||
/// is two levels out, etc. Returns `None` if `steps_out` is 0 or
|
||||
/// exceeds the current nesting depth (the caller should treat this as
|
||||
/// an error in the Substrait plan).
|
||||
fn get_outer_schema(&self, _steps_out: usize) -> Option<Arc<DFSchema>> {
|
||||
None
|
||||
}
|
||||
|
||||
// User-Defined Functionality
|
||||
|
||||
// The details of extension relations, and how to handle them, are fully up to users to specify.
|
||||
@@ -437,11 +457,16 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
|
||||
pub struct DefaultSubstraitConsumer<'a> {
|
||||
pub(super) extensions: &'a Extensions,
|
||||
pub(super) state: &'a SessionState,
|
||||
outer_schemas: RwLock<Vec<Arc<DFSchema>>>,
|
||||
}
|
||||
|
||||
impl<'a> DefaultSubstraitConsumer<'a> {
|
||||
pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self {
|
||||
DefaultSubstraitConsumer { extensions, state }
|
||||
DefaultSubstraitConsumer {
|
||||
extensions,
|
||||
state,
|
||||
outer_schemas: RwLock::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -465,6 +490,24 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
|
||||
self.state
|
||||
}
|
||||
|
||||
fn push_outer_schema(&self, schema: Arc<DFSchema>) {
|
||||
self.outer_schemas.write().unwrap().push(schema);
|
||||
}
|
||||
|
||||
fn pop_outer_schema(&self) {
|
||||
self.outer_schemas.write().unwrap().pop();
|
||||
}
|
||||
|
||||
fn get_outer_schema(&self, steps_out: usize) -> Option<Arc<DFSchema>> {
|
||||
let schemas = self.outer_schemas.read().unwrap();
|
||||
// steps_out=1 → last element, steps_out=2 → second-to-last, etc.
|
||||
// Returns None for steps_out=0 or steps_out > stack depth.
|
||||
schemas
|
||||
.len()
|
||||
.checked_sub(steps_out)
|
||||
.and_then(|idx| schemas.get(idx).cloned())
|
||||
}
|
||||
|
||||
async fn consume_extension_leaf(
|
||||
&self,
|
||||
rel: &ExtensionLeafRel,
|
||||
@@ -520,3 +563,79 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
|
||||
Ok(LogicalPlan::Extension(Extension { node: plan }))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::logical_plan::consumer::utils::tests::test_consumer;
|
||||
use datafusion::arrow::datatypes::{DataType, Field, Schema};
|
||||
|
||||
fn make_schema(fields: &[(&str, DataType)]) -> Arc<DFSchema> {
|
||||
let arrow_fields: Vec<Field> = fields
|
||||
.iter()
|
||||
.map(|(name, dt)| Field::new(*name, dt.clone(), true))
|
||||
.collect();
|
||||
Arc::new(
|
||||
DFSchema::try_from(Schema::new(arrow_fields))
|
||||
.expect("failed to create schema"),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_outer_schema_empty_stack() {
|
||||
let consumer = test_consumer();
|
||||
|
||||
// No schemas pushed — any steps_out should return None
|
||||
assert!(consumer.get_outer_schema(0).is_none());
|
||||
assert!(consumer.get_outer_schema(1).is_none());
|
||||
assert!(consumer.get_outer_schema(2).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_outer_schema_single_level() {
|
||||
let consumer = test_consumer();
|
||||
|
||||
let schema_a = make_schema(&[("a", DataType::Int64)]);
|
||||
consumer.push_outer_schema(Arc::clone(&schema_a));
|
||||
|
||||
// steps_out=1 returns the one pushed schema
|
||||
let result = consumer.get_outer_schema(1).unwrap();
|
||||
assert_eq!(result.fields().len(), 1);
|
||||
assert_eq!(result.fields()[0].name(), "a");
|
||||
|
||||
// steps_out=0 and steps_out=2 are out of range
|
||||
assert!(consumer.get_outer_schema(0).is_none());
|
||||
assert!(consumer.get_outer_schema(2).is_none());
|
||||
|
||||
consumer.pop_outer_schema();
|
||||
assert!(consumer.get_outer_schema(1).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_outer_schema_nested() {
|
||||
let consumer = test_consumer();
|
||||
|
||||
let schema_a = make_schema(&[("a", DataType::Int64)]);
|
||||
let schema_b = make_schema(&[("b", DataType::Utf8)]);
|
||||
|
||||
consumer.push_outer_schema(Arc::clone(&schema_a));
|
||||
consumer.push_outer_schema(Arc::clone(&schema_b));
|
||||
|
||||
// steps_out=1 returns the most recent (schema_b)
|
||||
let result = consumer.get_outer_schema(1).unwrap();
|
||||
assert_eq!(result.fields()[0].name(), "b");
|
||||
|
||||
// steps_out=2 returns the grandparent (schema_a)
|
||||
let result = consumer.get_outer_schema(2).unwrap();
|
||||
assert_eq!(result.fields()[0].name(), "a");
|
||||
|
||||
// steps_out=3 exceeds depth
|
||||
assert!(consumer.get_outer_schema(3).is_none());
|
||||
|
||||
// Pop one level — now steps_out=1 returns schema_a
|
||||
consumer.pop_outer_schema();
|
||||
let result = consumer.get_outer_schema(1).unwrap();
|
||||
assert_eq!(result.fields()[0].name(), "a");
|
||||
assert!(consumer.get_outer_schema(2).is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ mod tests {
|
||||
Subquery:
|
||||
Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]]
|
||||
Projection: PARTSUPP.PS_SUPPLYCOST
|
||||
Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE")
|
||||
Filter: outer_ref(PART.P_PARTKEY) = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8("EUROPE")
|
||||
Cross Join:
|
||||
Cross Join:
|
||||
Cross Join:
|
||||
@@ -134,7 +134,7 @@ mod tests {
|
||||
Projection: ORDERS.O_ORDERPRIORITY
|
||||
Filter: ORDERS.O_ORDERDATE >= CAST(Utf8("1993-07-01") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8("1993-10-01") AS Date32) AND EXISTS (<subquery>)
|
||||
Subquery:
|
||||
Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE
|
||||
Filter: LINEITEM.L_ORDERKEY = outer_ref(ORDERS.O_ORDERKEY) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE
|
||||
TableScan: LINEITEM
|
||||
TableScan: ORDERS
|
||||
"#
|
||||
@@ -353,11 +353,27 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[ignore]
|
||||
#[tokio::test]
|
||||
async fn tpch_test_17() -> Result<()> {
|
||||
let plan_str = tpch_plan_to_string(17).await?;
|
||||
assert_snapshot!(plan_str, "panics due to out of bounds field access");
|
||||
assert_snapshot!(
|
||||
plan_str,
|
||||
@r#"
|
||||
Projection: sum(LINEITEM.L_EXTENDEDPRICE) / Decimal128(Some(70),2,1) AS AVG_YEARLY
|
||||
Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE)]]
|
||||
Projection: LINEITEM.L_EXTENDEDPRICE
|
||||
Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8("Brand#23") AND PART.P_CONTAINER = Utf8("MED BOX") AND LINEITEM.L_QUANTITY < (<subquery>)
|
||||
Subquery:
|
||||
Projection: Decimal128(Some(2),2,1) * avg(LINEITEM.L_QUANTITY)
|
||||
Aggregate: groupBy=[[]], aggr=[[avg(LINEITEM.L_QUANTITY)]]
|
||||
Projection: LINEITEM.L_QUANTITY
|
||||
Filter: LINEITEM.L_PARTKEY = outer_ref(PART.P_PARTKEY)
|
||||
TableScan: LINEITEM
|
||||
Cross Join:
|
||||
TableScan: LINEITEM
|
||||
TableScan: PART
|
||||
"#
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -425,7 +441,7 @@ mod tests {
|
||||
Projection: Decimal128(Some(5),2,1) * sum(LINEITEM.L_QUANTITY)
|
||||
Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]]
|
||||
Projection: LINEITEM.L_QUANTITY
|
||||
Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32)
|
||||
Filter: LINEITEM.L_PARTKEY = outer_ref(PARTSUPP.PS_PARTKEY) AND LINEITEM.L_SUPPKEY = outer_ref(PARTSUPP.PS_SUPPKEY) AND LINEITEM.L_SHIPDATE >= CAST(Utf8("1994-01-01") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8("1995-01-01") AS Date32)
|
||||
TableScan: LINEITEM
|
||||
TableScan: PARTSUPP
|
||||
Cross Join:
|
||||
@@ -449,10 +465,10 @@ mod tests {
|
||||
Projection: SUPPLIER.S_NAME
|
||||
Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8("F") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS (<subquery>) AND NOT EXISTS (<subquery>) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8("SAUDI ARABIA")
|
||||
Subquery:
|
||||
Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS
|
||||
Filter: LINEITEM.L_ORDERKEY = outer_ref(LINEITEM.L_ORDERKEY) AND LINEITEM.L_SUPPKEY != outer_ref(LINEITEM.L_SUPPKEY)
|
||||
TableScan: LINEITEM
|
||||
Subquery:
|
||||
Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE
|
||||
Filter: LINEITEM.L_ORDERKEY = outer_ref(LINEITEM.L_ORDERKEY) AND LINEITEM.L_SUPPKEY != outer_ref(LINEITEM.L_SUPPKEY) AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE
|
||||
TableScan: LINEITEM
|
||||
Cross Join:
|
||||
Cross Join:
|
||||
@@ -483,7 +499,7 @@ mod tests {
|
||||
Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("13") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("31") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("23") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("29") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("30") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("18") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8("17") AS Utf8))
|
||||
TableScan: CUSTOMER
|
||||
Subquery:
|
||||
Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY
|
||||
Filter: ORDERS.O_CUSTKEY = outer_ref(CUSTOMER.C_CUSTKEY)
|
||||
TableScan: ORDERS
|
||||
TableScan: CUSTOMER
|
||||
"#
|
||||
@@ -491,6 +507,52 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Tests nested correlated subqueries where the innermost subquery
|
||||
/// references the outermost query (steps_out=2).
|
||||
///
|
||||
/// This tests the outer schema stack with depth > 1.
|
||||
/// The plan represents:
|
||||
/// ```sql
|
||||
/// SELECT * FROM A
|
||||
/// WHERE EXISTS (
|
||||
/// SELECT * FROM B
|
||||
/// WHERE B.b1 = A.a1 -- steps_out=1 (references immediate parent)
|
||||
/// AND EXISTS (
|
||||
/// SELECT * FROM C
|
||||
/// WHERE C.c1 = A.a1 -- steps_out=2 (references grandparent)
|
||||
/// AND C.c2 = B.b2 -- steps_out=1 (references immediate parent)
|
||||
/// )
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
#[tokio::test]
|
||||
async fn test_nested_correlated_subquery() -> Result<()> {
|
||||
let path = "tests/testdata/test_plans/nested_correlated_subquery.substrait.json";
|
||||
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
|
||||
File::open(path).expect("file not found"),
|
||||
))
|
||||
.expect("failed to parse json");
|
||||
|
||||
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
|
||||
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
|
||||
let plan_str = format!("{plan}");
|
||||
|
||||
assert_snapshot!(
|
||||
plan_str,
|
||||
@r#"
|
||||
Filter: EXISTS (<subquery>)
|
||||
Subquery:
|
||||
Filter: B.b1 = outer_ref(A.a1) AND EXISTS (<subquery>)
|
||||
Subquery:
|
||||
Filter: C.c1 = outer_ref(A.a1) AND C.c2 = outer_ref(B.b2)
|
||||
TableScan: C
|
||||
TableScan: B
|
||||
TableScan: A
|
||||
"#
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_plan_to_string(name: &str) -> Result<String> {
|
||||
let path = format!("tests/testdata/test_plans/{name}");
|
||||
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
|
||||
|
||||
+265
@@ -0,0 +1,265 @@
|
||||
{
|
||||
"extensionUris": [{
|
||||
"extensionUriAnchor": 1,
|
||||
"uri": "/functions_boolean.yaml"
|
||||
}, {
|
||||
"extensionUriAnchor": 2,
|
||||
"uri": "/functions_comparison.yaml"
|
||||
}],
|
||||
"extensions": [{
|
||||
"extensionFunction": {
|
||||
"extensionUriReference": 1,
|
||||
"name": "and:bool"
|
||||
}
|
||||
}, {
|
||||
"extensionFunction": {
|
||||
"extensionUriReference": 2,
|
||||
"functionAnchor": 1,
|
||||
"name": "equal:any_any"
|
||||
}
|
||||
}],
|
||||
"relations": [{
|
||||
"root": {
|
||||
"input": {
|
||||
"filter": {
|
||||
"common": {
|
||||
"direct": {}
|
||||
},
|
||||
"input": {
|
||||
"read": {
|
||||
"common": {
|
||||
"direct": {}
|
||||
},
|
||||
"baseSchema": {
|
||||
"names": ["a1", "a2"],
|
||||
"struct": {
|
||||
"types": [{
|
||||
"i64": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
}, {
|
||||
"i64": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
}],
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"namedTable": {
|
||||
"names": ["A"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"condition": {
|
||||
"subquery": {
|
||||
"setPredicate": {
|
||||
"predicateOp": "PREDICATE_OP_EXISTS",
|
||||
"tuples": {
|
||||
"filter": {
|
||||
"common": {
|
||||
"direct": {}
|
||||
},
|
||||
"input": {
|
||||
"read": {
|
||||
"common": {
|
||||
"direct": {}
|
||||
},
|
||||
"baseSchema": {
|
||||
"names": ["b1", "b2"],
|
||||
"struct": {
|
||||
"types": [{
|
||||
"i64": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
}, {
|
||||
"i64": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
}],
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"namedTable": {
|
||||
"names": ["B"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"condition": {
|
||||
"scalarFunction": {
|
||||
"outputType": {
|
||||
"bool": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"arguments": [{
|
||||
"value": {
|
||||
"scalarFunction": {
|
||||
"functionReference": 1,
|
||||
"outputType": {
|
||||
"bool": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"arguments": [{
|
||||
"value": {
|
||||
"selection": {
|
||||
"directReference": {
|
||||
"structField": {
|
||||
"field": 0
|
||||
}
|
||||
},
|
||||
"rootReference": {}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"value": {
|
||||
"selection": {
|
||||
"directReference": {
|
||||
"structField": {
|
||||
"field": 0
|
||||
}
|
||||
},
|
||||
"outerReference": {
|
||||
"stepsOut": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"value": {
|
||||
"subquery": {
|
||||
"setPredicate": {
|
||||
"predicateOp": "PREDICATE_OP_EXISTS",
|
||||
"tuples": {
|
||||
"filter": {
|
||||
"common": {
|
||||
"direct": {}
|
||||
},
|
||||
"input": {
|
||||
"read": {
|
||||
"common": {
|
||||
"direct": {}
|
||||
},
|
||||
"baseSchema": {
|
||||
"names": ["c1", "c2"],
|
||||
"struct": {
|
||||
"types": [{
|
||||
"i64": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
}, {
|
||||
"i64": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
}],
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"namedTable": {
|
||||
"names": ["C"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"condition": {
|
||||
"scalarFunction": {
|
||||
"outputType": {
|
||||
"bool": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"arguments": [{
|
||||
"value": {
|
||||
"scalarFunction": {
|
||||
"functionReference": 1,
|
||||
"outputType": {
|
||||
"bool": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"arguments": [{
|
||||
"value": {
|
||||
"selection": {
|
||||
"directReference": {
|
||||
"structField": {
|
||||
"field": 0
|
||||
}
|
||||
},
|
||||
"rootReference": {}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"value": {
|
||||
"selection": {
|
||||
"directReference": {
|
||||
"structField": {
|
||||
"field": 0
|
||||
}
|
||||
},
|
||||
"outerReference": {
|
||||
"stepsOut": 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"value": {
|
||||
"scalarFunction": {
|
||||
"functionReference": 1,
|
||||
"outputType": {
|
||||
"bool": {
|
||||
"nullability": "NULLABILITY_REQUIRED"
|
||||
}
|
||||
},
|
||||
"arguments": [{
|
||||
"value": {
|
||||
"selection": {
|
||||
"directReference": {
|
||||
"structField": {
|
||||
"field": 1
|
||||
}
|
||||
},
|
||||
"rootReference": {}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
"value": {
|
||||
"selection": {
|
||||
"directReference": {
|
||||
"structField": {
|
||||
"field": 1
|
||||
}
|
||||
},
|
||||
"outerReference": {
|
||||
"stepsOut": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"names": ["a1", "a2"]
|
||||
}
|
||||
}]
|
||||
}
|
||||
Reference in New Issue
Block a user