From bcd42b090a1cb9ca89f0b479a24937ff0b13b230 Mon Sep 17 00:00:00 2001 From: Emily Matheys <55631053+EmilyMatt@users.noreply.github.com> Date: Thu, 26 Feb 2026 10:22:25 +0200 Subject: [PATCH] fix: Unaccounted spill sort in row_hash (#20314) ## Which issue does this PR close? - Closes #20313 . ## Rationale for this change We must not use that much memory without reserving it. ## What changes are included in this PR? Added a reservation before the sort, made a shrink call for the group values after the emit and updated the reservation so the reservation will be possible. Moved the sort to use sort_chunked so we can immediately drop the original batch and shrink the reservation to the used sizes, added a new spill method for iterators, so we can use an accurate memory accounting. If said reservation did not succeed, fallback to an incrementing sort method which holds the original batch the whole time, and outputs one batch at the time, this requires a much smaller reservation. Made the reservation much more robust(otherwise the fuzz tests were failing now that we actually reserve the memory in the sort) ## Are these changes tested? Current tests should still function, but memory should be reserved. Added test that specifically verifies that we error on this when we shouldn't do the sort. Modified the tests that used to test the splitting function in the spill to test the new iter spilling function ## Are there any user-facing changes? No --- .../physical-plan/src/aggregates/mod.rs | 136 +++++++++++++++ .../physical-plan/src/aggregates/row_hash.rs | 72 ++++++-- datafusion/physical-plan/src/sorts/mod.rs | 2 + datafusion/physical-plan/src/sorts/sort.rs | 47 +---- datafusion/physical-plan/src/sorts/stream.rs | 164 +++++++++++++++++- datafusion/physical-plan/src/spill/mod.rs | 14 +- .../physical-plan/src/spill/spill_manager.rs | 59 +++---- 7 files changed, 401 insertions(+), 93 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4b3ac1955..859999385 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -3837,6 +3837,142 @@ mod tests { Ok(()) } + /// Tests that when the memory pool is too small to accommodate the sort + /// reservation during spill, the error is properly propagated as + /// ResourcesExhausted rather than silently exceeding memory limits. + #[tokio::test] + async fn test_sort_reservation_fails_during_spill() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int64, false), + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Float64, false), + Field::new("e", DataType::Float64, false), + ])); + + let batches = vec![vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![1])), + Arc::new(Float64Array::from(vec![10.0])), + Arc::new(Float64Array::from(vec![20.0])), + Arc::new(Float64Array::from(vec![30.0])), + Arc::new(Float64Array::from(vec![40.0])), + Arc::new(Float64Array::from(vec![50.0])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![2])), + Arc::new(Float64Array::from(vec![11.0])), + Arc::new(Float64Array::from(vec![21.0])), + Arc::new(Float64Array::from(vec![31.0])), + Arc::new(Float64Array::from(vec![41.0])), + Arc::new(Float64Array::from(vec![51.0])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![3])), + Arc::new(Float64Array::from(vec![12.0])), + Arc::new(Float64Array::from(vec![22.0])), + Arc::new(Float64Array::from(vec![32.0])), + Arc::new(Float64Array::from(vec![42.0])), + Arc::new(Float64Array::from(vec![52.0])), + ], + )?, + ]]; + + let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(col("g", schema.as_ref())?, "g".to_string())], + vec![], + vec![vec![false]], + false, + ), + vec![ + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("a", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("b", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("c", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(c)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("d", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(d)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("e", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(e)") + .build()?, + ), + ], + vec![None, None, None, None, None], + Arc::new(scan) as Arc, + Arc::clone(&schema), + )?); + + // Pool must be large enough for accumulation to start but too small for + // sort_memory after clearing. + let task_ctx = new_spill_ctx(1, 500); + let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await; + + match &result { + Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"), + Err(e) => { + let root = e.find_root(); + assert!( + matches!(root, DataFusionError::ResourcesExhausted(_)), + "Expected ResourcesExhausted, got: {root}", + ); + let msg = root.to_string(); + assert!( + msg.contains("Failed to reserve memory for sort during spill"), + "Expected sort reservation error, got: {msg}", + ); + } + } + + Ok(()) + } + /// Tests that PartialReduce mode: /// 1. Accepts state as input (like Final) /// 2. Produces state as output (like Partial) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 35f32ac7a..4a1b0e5c8 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -30,9 +30,8 @@ use crate::aggregates::{ create_schema, evaluate_group_by, evaluate_many, evaluate_optional, }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; -use crate::sorts::sort::sort_batch; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; -use crate::spill::spill_manager::SpillManager; +use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; use crate::{PhysicalExpr, aggregates, metrics}; use crate::{RecordBatchStream, SendableRecordBatchStream}; @@ -40,7 +39,7 @@ use arrow::array::*; use arrow::datatypes::SchemaRef; use datafusion_common::{ DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err, - internal_err, + internal_err, resources_datafusion_err, }; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::proxy::VecAllocExt; @@ -51,7 +50,9 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use crate::sorts::IncrementalSortIterator; use datafusion_common::instant::Instant; +use datafusion_common::utils::memory::get_record_batch_memory_size; use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; @@ -1060,10 +1061,27 @@ impl GroupedHashAggregateStream { fn update_memory_reservation(&mut self) -> Result<()> { let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - let new_size = acc + let groups_and_acc_size = acc + self.group_values.size() + self.group_ordering.size() + self.current_group_indices.allocated_size(); + + // Reserve extra headroom for sorting during potential spill. + // When OOM triggers, group_aggregate_batch has already processed the + // latest input batch, so the internal state may have grown well beyond + // the last successful reservation. The emit batch reflects this larger + // actual state, and the sort needs memory proportional to it. + // By reserving headroom equal to the data size, we trigger OOM earlier + // (before too much data accumulates), ensuring the freed reservation + // after clear_shrink is sufficient to cover the sort memory. + let sort_headroom = + if self.oom_mode == OutOfMemoryMode::Spill && !self.group_values.is_empty() { + acc + self.group_values.size() + } else { + 0 + }; + + let new_size = groups_and_acc_size + sort_headroom; let reservation_result = self.reservation.try_resize(new_size); if reservation_result.is_ok() { @@ -1122,17 +1140,47 @@ impl GroupedHashAggregateStream { let Some(emit) = self.emit(EmitTo::All, true)? else { return Ok(()); }; - let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; - // Spill sorted state to disk + // Free accumulated state now that data has been emitted into `emit`. + // This must happen before reserving sort memory so the pool has room. + // Use 0 to minimize allocated capacity and maximize memory available for sorting. + self.clear_shrink(0); + self.update_memory_reservation()?; + + let batch_size_ratio = self.batch_size as f32 / emit.num_rows() as f32; + let batch_memory = get_record_batch_memory_size(&emit); + // The maximum worst case for a sort is 2X the original underlying buffers(regardless of slicing) + // First we get the underlying buffers' size, then we get the sliced("actual") size of the batch, + // and multiply it by the ratio of batch_size to actual size to get the estimated memory needed for sorting the batch. + // If something goes wrong in get_sliced_size()(double counting or something), + // we fall back to the worst case. + let sort_memory = (batch_memory + + (emit.get_sliced_size()? as f32 * batch_size_ratio) as usize) + .min(batch_memory * 2); + + // If we can't grow even that, we have no choice but to return an error since we can't spill to disk without sorting the data first. + self.reservation.try_grow(sort_memory).map_err(|err| { + resources_datafusion_err!( + "Failed to reserve memory for sort during spill: {err}" + ) + })?; + + let sorted_iter = IncrementalSortIterator::new( + emit, + self.spill_state.spill_expr.clone(), + self.batch_size, + ); let spillfile = self .spill_state .spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &sorted, + .spill_record_batch_iter_and_return_max_batch_memory( + sorted_iter, "HashAggSpill", - self.batch_size, )?; + + // Shrink the memory we allocated for sorting as the sorting is fully done at this point. + self.reservation.shrink(sort_memory); + match spillfile { Some((spillfile, max_record_batch_memory)) => { self.spill_state.spills.push(SortedSpillFile { @@ -1150,14 +1198,14 @@ impl GroupedHashAggregateStream { Ok(()) } - /// Clear memory and shirk capacities to the size of the batch. + /// Clear memory and shrink capacities to the given number of rows. fn clear_shrink(&mut self, num_rows: usize) { self.group_values.clear_shrink(num_rows); self.current_group_indices.clear(); self.current_group_indices.shrink_to(num_rows); } - /// Clear memory and shirk capacities to zero. + /// Clear memory and shrink capacities to zero. fn clear_all(&mut self) { self.clear_shrink(0); } @@ -1196,7 +1244,7 @@ impl GroupedHashAggregateStream { // instead. // Spilling to disk and reading back also ensures batch size is consistent // rather than potentially having one significantly larger last batch. - self.spill()?; // TODO: use sort_batch_chunked instead? + self.spill()?; // Mark that we're switching to stream merging mode. self.spill_state.is_stream_merging = true; diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index 9c72e34fe..a73872a17 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -26,3 +26,5 @@ pub mod sort; pub mod sort_preserving_merge; mod stream; pub mod streaming_merge; + +pub(crate) use stream::IncrementalSortIterator; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index c735963d9..b3ea548d5 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -39,6 +39,7 @@ use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; use crate::projection::{ProjectionExec, make_with_child, update_ordering}; +use crate::sorts::IncrementalSortIterator; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; @@ -728,7 +729,6 @@ impl ExternalSorter { // Sort the batch immediately and get all output batches let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?; - drop(batch); // Free the old reservation and grow it to match the actual sorted output size reservation.free(); @@ -853,11 +853,13 @@ pub(crate) fn get_reserved_bytes_for_record_batch_size( /// Estimate how much memory is needed to sort a `RecordBatch`. /// This will just call `get_reserved_bytes_for_record_batch_size` with the /// memory size of the record batch and its sliced size. -pub(super) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result { - Ok(get_reserved_bytes_for_record_batch_size( - get_record_batch_memory_size(batch), - batch.get_sliced_size()?, - )) +pub(crate) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result { + batch.get_sliced_size().map(|sliced_size| { + get_reserved_bytes_for_record_batch_size( + get_record_batch_memory_size(batch), + sliced_size, + ) + }) } impl Debug for ExternalSorter { @@ -900,38 +902,7 @@ pub fn sort_batch_chunked( expressions: &LexOrdering, batch_size: usize, ) -> Result> { - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(batch)) - .collect::>>()?; - - let indices = lexsort_to_indices(&sort_columns, None)?; - - // Split indices into chunks of batch_size - let num_rows = indices.len(); - let num_chunks = num_rows.div_ceil(batch_size); - - let result_batches = (0..num_chunks) - .map(|chunk_idx| { - let start = chunk_idx * batch_size; - let end = (start + batch_size).min(num_rows); - let chunk_len = end - start; - - // Create a slice of indices for this chunk - let chunk_indices = indices.slice(start, chunk_len); - - // Take the columns using this chunk of indices - let columns = take_arrays(batch.columns(), &chunk_indices, None)?; - - let options = RecordBatchOptions::new().with_row_count(Some(chunk_len)); - let chunk_batch = - RecordBatch::try_new_with_options(batch.schema(), columns, &options)?; - - Ok(chunk_batch) - }) - .collect::>>()?; - - Ok(result_batches) + IncrementalSortIterator::new(batch.clone(), expressions.clone(), batch_size).collect() } /// Sort execution plan. diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index 779511a86..ff7f259dd 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -18,16 +18,20 @@ use crate::SendableRecordBatchStream; use crate::sorts::cursor::{ArrayValues, CursorArray, RowValues}; use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::array::Array; +use arrow::array::{Array, UInt32Array}; +use arrow::compute::take_record_batch; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; +use arrow_ord::sort::lexsort_to_indices; use datafusion_common::{Result, internal_datafusion_err}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::stream::{Fuse, StreamExt}; +use std::iter::FusedIterator; use std::marker::PhantomData; +use std::mem; use std::sync::Arc; use std::task::{Context, Poll, ready}; @@ -103,7 +107,7 @@ impl ReusableRows { self.inner[stream_idx][1] = Some(Arc::clone(rows)); // swap the current with the previous one, so that the next poll can reuse the Rows from the previous poll let [a, b] = &mut self.inner[stream_idx]; - std::mem::swap(a, b); + mem::swap(a, b); } } @@ -276,3 +280,159 @@ impl PartitionedStream for FieldCursorStream { })) } } + +/// A lazy, memory-efficient sort iterator used as a fallback during aggregate +/// spill when there is not enough memory for an eager sort (which requires ~2x +/// peak memory to hold both the unsorted and sorted copies simultaneously). +/// +/// On the first call to `next()`, a sorted index array (`UInt32Array`) is +/// computed via `lexsort_to_indices`. Subsequent calls yield chunks of +/// `batch_size` rows by `take`-ing from the original batch using slices of +/// this index array. Each `take` copies data for the chunk (not zero-copy), +/// but only one chunk is live at a time since the caller consumes it before +/// requesting the next. Once all rows have been yielded, the original batch +/// and index array are dropped to free memory. +/// +/// The caller must reserve `sizeof(batch) + sizeof(one chunk)` for this iterator, +/// and free the reservation once the iterator is depleted. +pub(crate) struct IncrementalSortIterator { + batch: RecordBatch, + expressions: LexOrdering, + batch_size: usize, + indices: Option, + cursor: usize, +} + +impl IncrementalSortIterator { + pub(crate) fn new( + batch: RecordBatch, + expressions: LexOrdering, + batch_size: usize, + ) -> Self { + Self { + batch, + expressions, + batch_size, + cursor: 0, + indices: None, + } + } +} + +impl Iterator for IncrementalSortIterator { + type Item = Result; + + fn next(&mut self) -> Option { + if self.cursor >= self.batch.num_rows() { + return None; + } + + match self.indices.as_ref() { + None => { + let sort_columns = match self + .expressions + .iter() + .map(|expr| expr.evaluate_to_sort_column(&self.batch)) + .collect::>>() + { + Ok(cols) => cols, + Err(e) => return Some(Err(e)), + }; + + let indices = match lexsort_to_indices(&sort_columns, None) { + Ok(indices) => indices, + Err(e) => return Some(Err(e.into())), + }; + self.indices = Some(indices); + + // Call again, this time it will hit the Some(indices) branch and return the first batch + self.next() + } + Some(indices) => { + let batch_size = self.batch_size.min(self.batch.num_rows() - self.cursor); + + // Perform the take to produce the next batch + let new_batch_indices = indices.slice(self.cursor, batch_size); + let new_batch = match take_record_batch(&self.batch, &new_batch_indices) { + Ok(batch) => batch, + Err(e) => return Some(Err(e.into())), + }; + + self.cursor += batch_size; + + // If this is the last batch, we can release the memory + if self.cursor >= self.batch.num_rows() { + let schema = self.batch.schema(); + let _ = mem::replace(&mut self.batch, RecordBatch::new_empty(schema)); + self.indices = None; + } + + // Return the new batch + Some(Ok(new_batch)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let num_rows = self.batch.num_rows(); + let batch_size = self.batch_size; + let num_batches = num_rows.div_ceil(batch_size); + (num_batches, Some(num_batches)) + } +} + +impl FusedIterator for IncrementalSortIterator {} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{AsArray, Int32Array}; + use arrow::datatypes::{DataType, Field, Int32Type}; + use datafusion_common::DataFusionError; + use datafusion_physical_expr::expressions::col; + + /// Verifies that `take_record_batch` in `IncrementalSortIterator` actually + /// copies the data into a new allocation rather than returning a zero-copy + /// slice of the original batch. If the output arrays were slices, their + /// underlying buffer length would match the original array's length; a true + /// copy will have a buffer sized to fit only the chunk. + #[test] + fn incremental_sort_iterator_copies_data() -> Result<()> { + let original_len = 10; + let batch_size = 3; + + // Build a batch with a single Int32 column of descending values + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a: Int32Array = Int32Array::from(vec![0; original_len]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(col_a)])?; + + // Sort ascending on column "a" + let expressions = LexOrdering::new(vec![PhysicalSortExpr::new_default(col( + "a", + &batch.schema(), + )?)]) + .unwrap(); + + let mut total_rows = 0; + IncrementalSortIterator::new(batch.clone(), expressions, batch_size).try_for_each( + |result| { + let chunk = result?; + total_rows += chunk.num_rows(); + + // Every output column must be a fresh allocation whose length + // equals the chunk size, NOT the original array length. + chunk.columns().iter().zip(batch.columns()).for_each(|(arr, original_arr)| { + let (_, scalar_buf, _) = arr.as_primitive::().clone().into_parts(); + let (_, original_scalar_buf, _) = original_arr.as_primitive::().clone().into_parts(); + + assert_ne!(scalar_buf.inner().data_ptr(), original_scalar_buf.inner().data_ptr(), "Expected a copy of the data for each chunk, but got a slice that shares the same buffer as the original array"); + }); + + Result::<_, DataFusionError>::Ok(()) + }, + )?; + + assert_eq!(total_rows, original_len); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 4c93c03b3..f6ce546a4 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -477,11 +477,12 @@ mod tests { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let row_batches: Vec = + (0..batch1.num_rows()).map(|i| batch1.slice(i, 1)).collect(); let (spill_file, max_batch_mem) = spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &batch1, + .spill_record_batch_iter_and_return_max_batch_memory( + row_batches.iter().map(Ok), "Test Spill", - 1, )? .unwrap(); assert!(spill_file.path().exists()); @@ -731,7 +732,7 @@ mod tests { let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?; assert!(completed_file.is_none()); - // Test write empty batch with interface `spill_record_batch_by_size_and_return_max_batch_memory()` + // Test write empty batch with interface `spill_record_batch_iter_and_return_max_batch_memory()` let empty_batch = RecordBatch::try_new( Arc::clone(&schema), vec![ @@ -740,10 +741,9 @@ mod tests { ], )?; let completed_file = spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &empty_batch, + .spill_record_batch_iter_and_return_max_batch_memory( + std::iter::once(Ok(&empty_batch)), "Test", - 1, )?; assert!(completed_file.is_none()); diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 6d931112a..07ba6d398 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,18 +17,19 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. -use arrow::array::StringViewArray; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, config::SpillCompression}; -use datafusion_execution::SendableRecordBatchStream; -use datafusion_execution::disk_manager::RefCountedTempFile; -use datafusion_execution::runtime_env::RuntimeEnv; -use std::sync::Arc; - use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; use crate::coop::cooperative; use crate::{common::spawn_buffered, metrics::SpillMetrics}; +use arrow::array::StringViewArray; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::utils::memory::get_record_batch_memory_size; +use datafusion_common::{DataFusionError, Result, config::SpillCompression}; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::runtime_env::RuntimeEnv; +use std::borrow::Borrow; +use std::sync::Arc; /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. @@ -109,39 +110,29 @@ impl SpillManager { in_progress_file.finish() } - /// Refer to the documentation for [`Self::spill_record_batch_and_finish`]. This method - /// additionally spills the `RecordBatch` into smaller batches, divided by `row_limit`. - /// - /// # Errors - /// - Returns an error if spilling would exceed the disk usage limit configured - /// by `max_temp_directory_size` in `DiskManager` - pub(crate) fn spill_record_batch_by_size_and_return_max_batch_memory( + /// Spill an iterator of `RecordBatch`es to disk and return the spill file and the size of the largest batch in memory + /// Note that this expects the caller to provide *non-sliced* batches, so the memory calculation of each batch is accurate. + pub(crate) fn spill_record_batch_iter_and_return_max_batch_memory( &self, - batch: &RecordBatch, + mut iter: impl Iterator>>, request_description: &str, - row_limit: usize, ) -> Result> { - let total_rows = batch.num_rows(); - let mut batches = Vec::new(); - let mut offset = 0; - - // It's ok to calculate all slices first, because slicing is zero-copy. - while offset < total_rows { - let length = std::cmp::min(total_rows - offset, row_limit); - let sliced_batch = batch.slice(offset, length); - batches.push(sliced_batch); - offset += length; - } - let mut in_progress_file = self.create_in_progress_file(request_description)?; let mut max_record_batch_size = 0; - for batch in batches { - in_progress_file.append_batch(&batch)?; + iter.try_for_each(|batch| { + let batch = batch?; + let borrowed = batch.borrow(); + if borrowed.num_rows() == 0 { + return Ok(()); + } + in_progress_file.append_batch(borrowed)?; - max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); - } + max_record_batch_size = + max_record_batch_size.max(get_record_batch_memory_size(borrowed)); + Result::<_, DataFusionError>::Ok(()) + })?; let file = in_progress_file.finish()?;