mirror of
https://github.com/langchain-ai/delta-rs.git
synced 2026-07-01 20:34:35 -04:00
refactor: use BatchAdapterFactory for scan adaptation (#4195)
# Description
Follow-up to the DataFusion (DF) 52 upgrade. Replaces the custom scan batch casting logic with DataFusion's `BatchAdapterFactory` via `datafusion-physical-expr-adapter`. Adds hardening tests for schema evolution and deletion vector (DV) scan behavior. Keeps the DF default behavior, with no custom compatibility mode.
# Related Issue(s)
<!--- For example: - closes #106 --->
# Documentation
<!--- Share links to useful documentation --->
---------
Signed-off-by: Ethan Urbanski <ethan@urbanskitech.com>
This commit is contained in:
@@ -41,6 +41,7 @@ parquet = { version = "57" }
|
||||
# datafusion 52.1
|
||||
datafusion = { version = "52.1.0" }
|
||||
datafusion-datasource = { version = "52.1.0" }
|
||||
datafusion-physical-expr-adapter = { version = "52.1.0" }
|
||||
datafusion-ffi = { version = "52.1.0" }
|
||||
datafusion-proto = { version = "52.1.0" }
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ object_store = { workspace = true }
|
||||
# datafusion
|
||||
datafusion = { workspace = true, optional = true }
|
||||
datafusion-datasource = { workspace = true, optional = true }
|
||||
datafusion-physical-expr-adapter = { workspace = true, optional = true }
|
||||
datafusion-proto = { workspace = true, optional = true }
|
||||
|
||||
# serde
|
||||
@@ -108,6 +109,7 @@ default = ["rustls"]
|
||||
datafusion = [
|
||||
"dep:datafusion",
|
||||
"datafusion-datasource",
|
||||
"datafusion-physical-expr-adapter",
|
||||
"datafusion-proto",
|
||||
]
|
||||
datafusion-ext = ["datafusion"]
|
||||
|
||||
@@ -18,9 +18,9 @@ use arrow_array::{
|
||||
};
|
||||
use arrow_schema::{DataType, FieldRef, Fields, Schema};
|
||||
use dashmap::DashMap;
|
||||
use datafusion::common::HashMap;
|
||||
use datafusion::common::config::ConfigOptions;
|
||||
use datafusion::common::error::{DataFusionError, Result};
|
||||
use datafusion::common::{HashMap, internal_datafusion_err};
|
||||
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
|
||||
use datafusion::physical_expr::EquivalenceProperties;
|
||||
use datafusion::physical_plan::execution_plan::{
|
||||
@@ -283,12 +283,16 @@ impl DeltaScanMetaStream {
|
||||
)?;
|
||||
|
||||
let batch = if let Some(selection) = self.selection_vectors.get(&file_id) {
|
||||
let missing = batch.num_rows() - selection.len();
|
||||
let filter = if missing > 0 {
|
||||
BooleanArray::from_iter(selection.iter().chain(std::iter::repeat_n(&true, missing)))
|
||||
} else {
|
||||
BooleanArray::from_iter(selection.iter())
|
||||
};
|
||||
if selection.len() != batch.num_rows() {
|
||||
return Err(internal_datafusion_err!(
|
||||
"Selection vector length ({}) does not match row count ({}) for file '{}'. \
|
||||
This indicates a bug in deletion vector processing.",
|
||||
selection.len(),
|
||||
batch.num_rows(),
|
||||
file_id
|
||||
));
|
||||
}
|
||||
let filter = BooleanArray::from_iter(selection.iter());
|
||||
filter_record_batch(&batch, &filter)?
|
||||
} else {
|
||||
batch
|
||||
|
||||
@@ -16,10 +16,9 @@
|
||||
|
||||
use std::{collections::VecDeque, pin::Pin, sync::Arc};
|
||||
|
||||
use arrow::array::AsArray;
|
||||
use arrow_array::{ArrayRef, RecordBatch, StructArray};
|
||||
use arrow_array::{ArrayRef, RecordBatch};
|
||||
use arrow_cast::{CastOptions, cast_with_options};
|
||||
use arrow_schema::{DataType, FieldRef, Schema, SchemaBuilder, SchemaRef};
|
||||
use arrow_schema::{FieldRef, Schema, SchemaBuilder, SchemaRef};
|
||||
use chrono::{TimeZone as _, Utc};
|
||||
use dashmap::DashMap;
|
||||
use datafusion::{
|
||||
@@ -43,6 +42,7 @@ use datafusion_datasource::{
|
||||
PartitionedFile, TableSchema, compute_all_files_statistics, file_groups::FileGroup,
|
||||
file_scan_config::FileScanConfigBuilder, source::DataSourceExec,
|
||||
};
|
||||
use datafusion_physical_expr_adapter::BatchAdapterFactory;
|
||||
use delta_kernel::{
|
||||
Engine, Expression, expressions::StructData, scan::ScanMetadata, table_features::TableFeature,
|
||||
};
|
||||
@@ -505,36 +505,25 @@ fn finalize_transformed_batch(
|
||||
}
|
||||
|
||||
fn cast_record_batch(batch: RecordBatch, target_schema: &SchemaRef) -> Result<RecordBatch> {
|
||||
if batch.num_columns() == 0 {
|
||||
if !target_schema.fields().is_empty() {
|
||||
return plan_err!(
|
||||
"Cannot cast empty RecordBatch to non-empty schema: {:?}",
|
||||
target_schema
|
||||
);
|
||||
}
|
||||
if batch.schema_ref().eq(target_schema) {
|
||||
return Ok(batch);
|
||||
}
|
||||
|
||||
let options = CastOptions {
|
||||
safe: true,
|
||||
..Default::default()
|
||||
};
|
||||
Ok(cast_with_options(
|
||||
&StructArray::from(batch),
|
||||
&DataType::Struct(target_schema.fields().clone()),
|
||||
&options,
|
||||
)?
|
||||
.as_struct()
|
||||
.into())
|
||||
let adapter_factory = BatchAdapterFactory::new(Arc::clone(target_schema));
|
||||
let adapter = adapter_factory.make_adapter(batch.schema())?;
|
||||
adapter.adapt_batch(&batch)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow_array::Array;
|
||||
use arrow_array::{
|
||||
BinaryArray, BinaryViewArray, Int32Array, RecordBatch, StringArray, StructArray,
|
||||
BinaryArray, BinaryViewArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions,
|
||||
StringArray, StructArray,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Fields, Schema};
|
||||
use arrow_schema::{ArrowError, DataType, Field, Fields, Schema};
|
||||
use datafusion::{
|
||||
error::DataFusionError,
|
||||
physical_plan::collect,
|
||||
prelude::{col, lit},
|
||||
};
|
||||
@@ -562,6 +551,105 @@ mod tests {
|
||||
assert_eq!(groups[1].len(), 1);
|
||||
}
|
||||
|
||||
/// Checks that `cast_record_batch` accepts a zero-column input batch and, for a
/// target schema with one nullable `Int32` column, returns a batch with that
/// column present, the original row count, and every value null.
#[test]
|
||||
fn test_cast_record_batch_empty_input_synthesizes_nullable_columns() {
|
||||
// Source: a batch with no columns but an explicit row count of 2.
let source_schema = Arc::new(Schema::new(Fields::empty()));
|
||||
let source = RecordBatch::try_new_with_options(
|
||||
source_schema,
|
||||
vec![],
|
||||
&RecordBatchOptions::new().with_row_count(Some(2)),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Target: a single nullable Int32 column named "id" that the source lacks.
let target_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));
|
||||
let adapted = cast_record_batch(source, &target_schema).unwrap();
|
||||
|
||||
// The output must carry the target schema and keep both rows.
assert_eq!(adapted.schema().as_ref(), target_schema.as_ref());
|
||||
assert_eq!(adapted.num_rows(), 2);
|
||||
let id = adapted
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
// Both rows of the added column are expected to be null.
assert_eq!(id.null_count(), 2);
|
||||
}
|
||||
|
||||
/// Checks that `cast_record_batch` returns an error, rather than a batch, when
/// the target schema requires a non-nullable column that the zero-column input
/// does not provide.
#[test]
|
||||
fn test_cast_record_batch_empty_input_missing_non_nullable_column_errors() {
|
||||
// Source: a batch with no columns and an explicit row count of 1.
let source_schema = Arc::new(Schema::new(Fields::empty()));
|
||||
let source = RecordBatch::try_new_with_options(
|
||||
source_schema,
|
||||
vec![],
|
||||
&RecordBatchOptions::new().with_row_count(Some(1)),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Target: "id" is declared non-nullable (third argument `false`).
let target_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
|
||||
let err = cast_record_batch(source, &target_schema)
|
||||
.expect_err("missing non-nullable columns should error");
|
||||
// The failure must be an `Execution` error whose message names the column
// and states that it is missing from the physical schema.
match err {
|
||||
DataFusionError::Execution(msg) => {
|
||||
assert!(
|
||||
msg.contains("Non-nullable column 'id'"),
|
||||
"expected non-nullable missing-column error, got: {msg}"
|
||||
);
|
||||
assert!(
|
||||
msg.contains("missing from the physical schema"),
|
||||
"expected missing physical schema detail, got: {msg}"
|
||||
);
|
||||
}
|
||||
// Any other error variant fails the test.
other => {
|
||||
panic!("expected execution error for missing non-nullable column, got: {other}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks that `cast_record_batch` reports an Arrow cast error when a `Utf8`
/// value that is not a number has to be cast to `Int32`.
#[test]
|
||||
fn test_cast_record_batch_invalid_scalar_cast_errors() {
|
||||
// Source: one Utf8 column "id" holding a value that cannot be parsed as an integer.
let source_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, true)]));
|
||||
let source = RecordBatch::try_new(
|
||||
source_schema,
|
||||
vec![Arc::new(StringArray::from(vec![Some("not-an-int")]))],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Target: the same column name, typed as Int32.
let target_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));
|
||||
let err = cast_record_batch(source, &target_schema)
|
||||
.expect_err("invalid value cast should fail under DataFusion default cast semantics");
|
||||
// The error must be a DataFusion `ArrowError` wrapping `ArrowError::CastError`.
match err {
|
||||
DataFusionError::ArrowError(inner, _) => {
|
||||
assert!(
|
||||
matches!(inner.as_ref(), ArrowError::CastError(_)),
|
||||
"expected arrow cast error, got: {inner}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected arrow cast error for invalid scalar cast, got: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks that `cast_record_batch` reports an Arrow cast error when an `Int64`
/// value one above `i32::MAX` has to be cast to `Int32`.
#[test]
|
||||
fn test_cast_record_batch_overflow_cast_errors() {
|
||||
// Source: one Int64 column "id" holding `i32::MAX + 1`, which does not fit in Int32.
let source_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, true)]));
|
||||
let source = RecordBatch::try_new(
|
||||
source_schema,
|
||||
vec![Arc::new(Int64Array::from(vec![i64::from(i32::MAX) + 1]))],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Target: the same column name, narrowed to Int32.
let target_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));
|
||||
let err = cast_record_batch(source, &target_schema)
|
||||
.expect_err("overflow cast should fail under DataFusion default cast semantics");
|
||||
// The error must be a DataFusion `ArrowError` wrapping `ArrowError::CastError`.
match err {
|
||||
DataFusionError::ArrowError(inner, _) => {
|
||||
assert!(
|
||||
matches!(inner.as_ref(), ArrowError::CastError(_)),
|
||||
"expected arrow cast error, got: {inner}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected arrow cast error for overflow cast, got: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parquet_plan() -> TestResult {
|
||||
let store = Arc::new(InMemory::new());
|
||||
|
||||
@@ -3,8 +3,10 @@ use std::sync::Arc;
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
use datafusion::assert_batches_sorted_eq;
|
||||
use datafusion::catalog::TableProvider;
|
||||
use datafusion::physical_plan::{ExecutionPlan, collect_partitioned};
|
||||
use datafusion::prelude::{SessionContext, col, lit};
|
||||
use datafusion::prelude::{SessionConfig, SessionContext, col, lit};
|
||||
use deltalake_core::delta_datafusion::DeltaScanConfig;
|
||||
use deltalake_core::delta_datafusion::DeltaScanNext;
|
||||
use deltalake_core::delta_datafusion::create_session;
|
||||
use deltalake_core::delta_datafusion::engine::DataFusionEngine;
|
||||
@@ -31,6 +33,23 @@ async fn scan_dat(case: &str) -> TestResult<(Snapshot, SessionContext)> {
|
||||
Ok((snapshot, session))
|
||||
}
|
||||
|
||||
/// Loads the snapshot of a DAT reader test case using a caller-supplied session.
///
/// `case` is the directory name under `dat/v0.0.3/reader_tests/generated/`,
/// resolved relative to this crate's manifest directory. The engine is built
/// from `session`'s state, so the session configuration chosen by the caller
/// is the one the engine is created from.
///
/// Returns the loaded `Snapshot`. Path canonicalization, case loading, or
/// snapshot creation failures are propagated through `?`.
async fn scan_dat_with_session(case: &str, session: &SessionContext) -> TestResult<Snapshot> {
|
||||
let root_dir = format!(
|
||||
"{}/../../dat/v0.0.3/reader_tests/generated/{}/",
|
||||
env!["CARGO_MANIFEST_DIR"],
|
||||
case
|
||||
);
|
||||
// Resolve the `..` segments to an absolute path; errors if the directory does not exist.
let root_dir = std::fs::canonicalize(root_dir)?;
|
||||
// Shadows the `case` name with the loaded test-case description.
let case = read_dat_case(root_dir)?;
|
||||
|
||||
// Build the engine from the provided session rather than from a fresh one.
let engine = DataFusionEngine::new_from_session(&session.state());
|
||||
let snapshot =
|
||||
Snapshot::try_new_with_engine(engine.clone(), case.table_root()?, Default::default(), None)
|
||||
.await?;
|
||||
|
||||
Ok(snapshot)
|
||||
}
|
||||
|
||||
async fn collect_plan(
|
||||
plan: Arc<dyn ExecutionPlan>,
|
||||
session: &SessionContext,
|
||||
@@ -113,6 +132,72 @@ async fn test_all_primitive_types() -> TestResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks that, with `schema_force_view_types` enabled, the scan plan exposes
/// view-typed columns and that a filter plus projection over the `utf8` column
/// still executes and returns the expected row.
#[tokio::test]
|
||||
async fn test_view_types_filter_exec_compatibility() -> TestResult<()> {
|
||||
use arrow_schema::DataType;
|
||||
|
||||
// Enable forced view types for parquet reads in this session.
let config =
|
||||
SessionConfig::new().set_bool("datafusion.execution.parquet.schema_force_view_types", true);
|
||||
let session = SessionContext::new_with_config(config);
|
||||
let snapshot = scan_dat_with_session("all_primitive_types", &session).await?;
|
||||
let provider: Arc<dyn TableProvider> = Arc::new(DeltaScanNext::new(
|
||||
snapshot,
|
||||
DeltaScanConfig::new_from_session(&session.state()),
|
||||
)?);
|
||||
|
||||
// Plan a full scan (no projection, filters, or limit) and inspect its schema.
let plan = provider.scan(&session.state(), None, &[], None).await?;
|
||||
let has_view_types = plan
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.any(|field| matches!(field.data_type(), DataType::Utf8View | DataType::BinaryView));
|
||||
assert!(
|
||||
has_view_types,
|
||||
"view types should be present when configured"
|
||||
);
|
||||
|
||||
// Run `utf8 = '1'` through the DataFrame API and compare the rendered result.
let filter = col("utf8").eq(lit("1"));
|
||||
let batches = session
|
||||
.read_table(provider.clone())?
|
||||
.filter(filter)?
|
||||
.select(vec![col("utf8")])?
|
||||
.collect()
|
||||
.await?;
|
||||
// NOTE(review): the cell padding in the expected table looks collapsed in this
// rendering of the diff — confirm the literal against the upstream source.
let expected = vec!["+------+", "| utf8 |", "+------+", "| 1 |", "+------+"];
|
||||
assert_batches_sorted_eq!(&expected, &batches);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks that, with `schema_force_view_types` disabled, the scan plan schema
/// contains no `Utf8View` or `BinaryView` columns.
#[tokio::test]
|
||||
async fn test_view_types_disabled() -> TestResult<()> {
|
||||
use arrow_schema::DataType;
|
||||
|
||||
// Explicitly turn off forced view types for parquet reads in this session.
let config = SessionConfig::new().set_bool(
|
||||
"datafusion.execution.parquet.schema_force_view_types",
|
||||
false,
|
||||
);
|
||||
let session = SessionContext::new_with_config(config);
|
||||
let snapshot = scan_dat_with_session("all_primitive_types", &session).await?;
|
||||
let provider: Arc<dyn TableProvider> = Arc::new(DeltaScanNext::new(
|
||||
snapshot,
|
||||
DeltaScanConfig::new_from_session(&session.state()),
|
||||
)?);
|
||||
|
||||
// Plan a full scan and verify that none of its fields use a view type.
let plan = provider.scan(&session.state(), None, &[], None).await?;
|
||||
let has_view_types = plan
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.any(|field| matches!(field.data_type(), DataType::Utf8View | DataType::BinaryView));
|
||||
assert!(
|
||||
!has_view_types,
|
||||
"view types should be disabled when configured"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_partitioned() -> TestResult<()> {
|
||||
let (snapshot, session) = scan_dat("multi_partitioned").await?;
|
||||
@@ -234,3 +319,27 @@ async fn test_deletion_vectors() -> TestResult<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Scans the `deletion_vectors` DAT case with the session batch size set to 1
/// and checks that the collected output is the single expected row.
#[tokio::test]
|
||||
async fn test_deletion_vectors_multi_batch() -> TestResult<()> {
|
||||
// A batch size of 1 is configured so the scan runs with the smallest batch setting.
let config = SessionConfig::new().with_batch_size(1);
|
||||
let session = SessionContext::new_with_config(config);
|
||||
let snapshot = scan_dat_with_session("deletion_vectors", &session).await?;
|
||||
let provider: Arc<dyn TableProvider> = Arc::new(DeltaScanNext::new(
|
||||
snapshot,
|
||||
DeltaScanConfig::new_from_session(&session.state()),
|
||||
)?);
|
||||
|
||||
// Plan a full scan (no projection, filters, or limit) and execute it.
let plan = provider.scan(&session.state(), None, &[], None).await?;
|
||||
let batches: Vec<_> = collect_plan(plan, &session).await?;
|
||||
// NOTE(review): the column padding in the expected table looks collapsed in this
// rendering of the diff — confirm the literals against the upstream source.
let expected = vec![
|
||||
"+--------+-----+------------+",
|
||||
"| letter | int | date |",
|
||||
"+--------+-----+------------+",
|
||||
"| b | 228 | 1978-12-01 |",
|
||||
"+--------+-----+------------+",
|
||||
];
|
||||
assert_batches_sorted_eq!(&expected, &batches);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user