feat(function): Add array_union_sum aggregation function #26842
DHRUV6029 wants to merge 2 commits into prestodb:master
Reviewer's Guide
Implements a new array_union_sum SQL aggregation function for non-decimal numeric arrays, including execution engine support, state management, serialization, registration in the built-in function set, tests, and documentation/release notes updates.
Sequence diagram for array_union_sum aggregation execution
sequenceDiagram
actor User
participant Parser as SqlParser
participant Planner as QueryPlanner
participant Engine as ExecutionEngine
participant Acc as ArrayUnionSumAccumulator
participant State as ArrayUnionSumState
participant Result as ArrayUnionSumResult
User->>Parser: Submit SELECT array_union_sum(arr) ...
Parser->>Planner: ParsedQuery(plan with array_union_sum)
Planner->>Engine: PhysicalPlan(using ArrayUnionSumAggregation)
loop For each input page
Engine->>Acc: addInput(page)
loop For each row with array value
Acc->>+State: input(elementType, state, arrayBlock)
alt state.get() is null (first array)
State-->>State: create ArrayUnionSumResult.create(elementType, adder, arrayBlock)
State-->>Acc: updated state with SingleArrayBlock
else subsequent arrays
State->>Result: get()
Result->>Result: unionSum(arrayBlock)
Result-->>State: new AccumulatedValues
end
end
end
note over Engine,Acc: Partial aggregation results may be combined across threads/nodes
Engine->>Acc: combine(partialState1, partialState2)
Acc->>State: combine(state, otherState)
State->>Result: get()
Result->>Result: unionSum(otherState.get())
Result-->>State: merged AccumulatedValues
Engine->>Acc: evaluateFinal()
Acc->>State: output(state, blockBuilder)
alt state.get() is null
State-->>Engine: appendNull()
else non null
State->>Result: get()
Result->>Result: serialize(outBlockBuilder)
Result-->>Engine: final array block
end
Engine-->>User: QueryResult(with aggregated array_union_sum column)
Class diagram for array_union_sum aggregation components
classDiagram
class ArrayUnionSumAggregation {
+String NAME
+ArrayUnionSumAggregation()
+String getDescription()
+BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
-static BuiltInAggregationFunctionImplementation generateAggregation(Type elementType, ArrayType outputType)
-static List~ParameterMetadata~ createInputParameterMetadata(Type inputType)
+static void input(Type elementType, ArrayUnionSumState state, Block arrayBlock)
+static void combine(ArrayUnionSumState state, ArrayUnionSumState otherState)
+static void output(ArrayUnionSumState state, BlockBuilder out)
-static MethodHandle INPUT_FUNCTION
-static MethodHandle COMBINE_FUNCTION
-static MethodHandle OUTPUT_FUNCTION
}
class ArrayUnionSumResult {
-Type elementType
-Adder adder
+ArrayUnionSumResult(Type elementType, Adder adder)
+static ArrayUnionSumResult create(Type elementType, Adder adder, Block arrayBlock)
+Type getElementType()
+void serialize(BlockBuilder out)
+ArrayUnionSumResult unionSum(ArrayUnionSumResult other)
+ArrayUnionSumResult unionSum(Block arrayBlock)
+long getRetainedSizeInBytes()
+int size()
+void appendValue(int i, BlockBuilder blockBuilder)
+boolean isValueNull(int i)
+Block getValueBlock()
+int getValueBlockIndex(int i)
+static void appendValue(Type elementType, Block block, int position, BlockBuilder blockBuilder)
}
class ArrayUnionSumResult_SingleArrayBlock {
-Block arrayBlock
+SingleArrayBlock(Type elementType, Adder adder, Block arrayBlock)
+int size()
+void appendValue(int i, BlockBuilder blockBuilder)
+boolean isValueNull(int i)
+long getRetainedSizeInBytes()
+Block getValueBlock()
+int getValueBlockIndex(int i)
}
class ArrayUnionSumResult_AccumulatedValues {
-Block valueBlock
+AccumulatedValues(Type elementType, Adder adder, Block valueBlock)
+int size()
+void appendValue(int i, BlockBuilder blockBuilder)
+boolean isValueNull(int i)
+long getRetainedSizeInBytes()
+Block getValueBlock()
+int getValueBlockIndex(int i)
}
class ArrayUnionSumState {
<<interface>>
+ArrayUnionSumResult get()
+void set(ArrayUnionSumResult value)
+void addMemoryUsage(long memory)
+Type getElementType()
+Adder getAdder()
}
class ArrayUnionSumStateFactory {
+ArrayUnionSumStateFactory(Type elementType)
+ArrayUnionSumState createSingleState()
+Class~? extends ArrayUnionSumState~ getSingleStateClass()
+ArrayUnionSumState createGroupedState()
+Class~? extends ArrayUnionSumState~ getGroupedStateClass()
-Type elementType
-Adder adder
-static Adder LONG_ADDER
-static Adder DOUBLE_ADDER
-static Adder FLOAT_ADDER
-static Adder getAdder(Type type)
}
class ArrayUnionSumStateFactory_GroupedState {
-Type elementType
-Adder adder
-ObjectBigArray~ArrayUnionSumResult~ results
-long size
+GroupedState(Type elementType, Adder adder)
+void ensureCapacity(long size)
+ArrayUnionSumResult get()
+void set(ArrayUnionSumResult value)
+void addMemoryUsage(long memory)
+Type getElementType()
+long getEstimatedSize()
+Adder getAdder()
}
class ArrayUnionSumStateFactory_SingleState {
-Type elementType
-Adder adder
-ArrayUnionSumResult result
+SingleState(Type elementType, Adder adder)
+ArrayUnionSumResult get()
+void set(ArrayUnionSumResult value)
+void addMemoryUsage(long memory)
+Type getElementType()
+long getEstimatedSize()
+Adder getAdder()
}
class ArrayUnionSumStateSerializer {
-ArrayType arrayType
+ArrayUnionSumStateSerializer(ArrayType arrayType)
+Type getSerializedType()
+void serialize(ArrayUnionSumState state, BlockBuilder out)
+void deserialize(Block block, int index, ArrayUnionSumState state)
}
class Adder {
<<interface>>
+void writeSum(Type type, Block block1, int position1, Block block2, int position2, BlockBuilder blockBuilder)
}
class SqlAggregationFunction
class AccumulatorStateFactory
class AccumulatorStateSerializer
class AccumulatorState
class AbstractGroupedAccumulatorState
class ArrayType
class Type
class Block
class BlockBuilder
ArrayUnionSumAggregation --> ArrayUnionSumState : uses
ArrayUnionSumAggregation --> ArrayUnionSumStateFactory : creates
ArrayUnionSumAggregation --> ArrayUnionSumStateSerializer : uses
ArrayUnionSumAggregation --> ArrayUnionSumResult : uses
ArrayUnionSumAggregation --|> SqlAggregationFunction
ArrayUnionSumResult <|-- ArrayUnionSumResult_SingleArrayBlock
ArrayUnionSumResult <|-- ArrayUnionSumResult_AccumulatedValues
ArrayUnionSumResult --> Adder : uses
ArrayUnionSumResult --> Type : elementType
ArrayUnionSumResult --> Block : holds
ArrayUnionSumState <|.. ArrayUnionSumStateFactory_GroupedState
ArrayUnionSumState <|.. ArrayUnionSumStateFactory_SingleState
ArrayUnionSumStateFactory --|> AccumulatorStateFactory
ArrayUnionSumStateFactory --> ArrayUnionSumStateFactory_GroupedState : creates
ArrayUnionSumStateFactory --> ArrayUnionSumStateFactory_SingleState : creates
ArrayUnionSumStateFactory --> Adder : configures
ArrayUnionSumStateFactory_GroupedState --|> AbstractGroupedAccumulatorState
ArrayUnionSumStateFactory_GroupedState --> ArrayUnionSumResult : stores
ArrayUnionSumStateFactory_SingleState --> ArrayUnionSumResult : stores
ArrayUnionSumStateSerializer --|> AccumulatorStateSerializer
ArrayUnionSumStateSerializer --> ArrayUnionSumResult : reconstructs
ArrayUnionSumStateSerializer --> ArrayType : uses
ArrayUnionSumState --|> AccumulatorState
ArrayType --> Type
Block --> Type
BlockBuilder --> Type
|
Hey - I've found 4 issues, and left some high level feedback:
- In `ArrayUnionSumAggregation.combine`, calling `state.get().unionSum(otherState.get())` when `otherState.get()` is null will throw; add an early return when `otherState.get() == null` to mirror typical combine semantics and avoid NPEs.
- The `input` function currently treats every row as having a non-null array block; if the contract is that NULL arrays should be skipped (as described in the PR), add an explicit check to ignore null input values rather than creating an `ArrayUnionSumResult` for them.
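The two null guards requested above can be sketched in isolation. This is a minimal illustration, not the PR's code: `NullSafeUnionSumSketch` is a hypothetical stand-in that keeps a plain `long[]` where the real `ArrayUnionSumState` holds an `ArrayUnionSumResult`, and it assumes that NULL arrays are meant to be skipped.

```java
import java.util.Arrays;

/** Hypothetical simplified state; the PR's state holds an ArrayUnionSumResult instead. */
final class NullSafeUnionSumSketch
{
    private long[] sums; // stays null until the first non-null array arrives

    long[] get()
    {
        return sums;
    }

    void input(long[] array)
    {
        if (array == null) {
            return; // NULL array row: leave the state untouched
        }
        sums = (sums == null) ? array.clone() : add(sums, array);
    }

    void combine(NullSafeUnionSumSketch other)
    {
        if (other.sums == null) {
            return; // the other partial state saw no input, nothing to merge
        }
        sums = (sums == null) ? other.sums.clone() : add(sums, other.sums);
    }

    // Element-wise sum; the result takes the length of the longer operand.
    private static long[] add(long[] left, long[] right)
    {
        long[] result = Arrays.copyOf(left, Math.max(left.length, right.length));
        for (int i = 0; i < right.length; i++) {
            result[i] += right[i];
        }
        return result;
    }
}
```

With both guards in place, combining two empty partial states leaves the target empty, which matches the `appendNull()` branch in the sequence diagram.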
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `ArrayUnionSumAggregation.combine`, calling `state.get().unionSum(otherState.get())` when `otherState.get()` is null will throw; add an early return when `otherState.get() == null` to mirror typical combine semantics and avoid NPEs.
- The `input` function currently treats every row as having a non-null array block; if the contract is that NULL arrays should be skipped (as described in the PR), add an explicit check to ignore null input values rather than creating an `ArrayUnionSumResult` for them.
## Individual Comments
### Comment 1
<location> `presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/ArrayUnionSumAggregation.java:130-135` </location>
<code_context>
+ arrayUnionSumResult = ArrayUnionSumResult.create(elementType, state.getAdder(), arrayBlock);
+ state.set(arrayUnionSumResult);
+ }
+ else {
+ startSize = arrayUnionSumResult.getRetainedSizeInBytes();
+ state.set(state.get().unionSum(arrayBlock));
+ }
+
+ state.addMemoryUsage(arrayUnionSumResult.getRetainedSizeInBytes() - startSize);
+ }
+
</code_context>
<issue_to_address>
**issue (bug_risk):** Memory accounting in input() uses the old result size instead of the new one
In the `else` branch, `arrayUnionSumResult` still points to the pre-union result. After `state.set(state.get().unionSum(arrayBlock));`, the new size is only available via `state.get()`, so using `arrayUnionSumResult.getRetainedSizeInBytes()` keeps using the old size and the delta is always 0.
You could fix this by capturing the current state and its size, creating the new result, then updating memory based on the new result, e.g.:
```java
else {
    ArrayUnionSumResult current = state.get();
    startSize = current.getRetainedSizeInBytes();
    ArrayUnionSumResult newResult = current.unionSum(arrayBlock);
    state.set(newResult);
    state.addMemoryUsage(newResult.getRetainedSizeInBytes() - startSize);
}
```
or by recomputing the size from `state.get()` after the `set`.
</issue_to_address>
### Comment 2
<location> `presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/ArrayUnionSumAggregation.java:140-142` </location>
<code_context>
+
+ public static void combine(ArrayUnionSumState state, ArrayUnionSumState otherState)
+ {
+ if (state.get() == null) {
+ state.set(otherState.get());
+ return;
+ }
+
</code_context>
<issue_to_address>
**issue (bug_risk):** combine() does not update memory usage when the target state is initially null
When `state.get() == null`, you assign `otherState.get()` and return without updating memory usage, so combining into an empty state under-reports the memory of the incoming result. You should add the retained size of `otherState.get()` in this branch (as you do in the non-null case) to keep memory accounting consistent.
</issue_to_address>
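A minimal sketch of the accounting described in this comment, under stated assumptions: `Result` and `State` below are hypothetical stand-ins for `ArrayUnionSumResult` and `ArrayUnionSumState`, and the retained size is approximated as 8 bytes per element rather than taken from a real `Block`.

```java
import java.util.Arrays;

/** Hypothetical stand-ins for the PR's result and state types, reduced to what memory accounting needs. */
final class CombineMemorySketch
{
    private CombineMemorySketch() {}

    static final class Result
    {
        final long[] values;

        Result(long[] values)
        {
            this.values = values;
        }

        // Assumption: 8 bytes per element; the real class reports its Block's retained size.
        long getRetainedSizeInBytes()
        {
            return (long) values.length * Long.BYTES;
        }

        Result unionSum(Result other)
        {
            long[] merged = Arrays.copyOf(values, Math.max(values.length, other.values.length));
            for (int i = 0; i < other.values.length; i++) {
                merged[i] += other.values[i];
            }
            return new Result(merged);
        }
    }

    static final class State
    {
        private Result result;
        private long memoryUsage;

        Result get()
        {
            return result;
        }

        void set(Result value)
        {
            result = value;
        }

        void addMemoryUsage(long bytes)
        {
            memoryUsage += bytes;
        }

        long getEstimatedSize()
        {
            return memoryUsage;
        }
    }

    static void combine(State state, State otherState)
    {
        Result other = otherState.get();
        if (other == null) {
            return; // nothing to merge, nothing to account for
        }
        Result current = state.get();
        long startSize = (current == null) ? 0 : current.getRetainedSizeInBytes();
        Result merged = (current == null) ? other : current.unionSum(other);
        state.set(merged);
        // The same delta is recorded in both branches, so an initially empty
        // target state is charged for the result it adopts.
        state.addMemoryUsage(merged.getRetainedSizeInBytes() - startSize);
    }
}
```

Computing the delta from the new result in one place avoids both the under-reporting described here and the stale-reference problem raised in Comment 1.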
### Comment 3
<location> `presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java:6339` </location>
<code_context>
+ ImmutableList.of(),
+ parseTypeSignature("array<T>"),
+ ImmutableList.of(parseTypeSignature("array<T>")));
+ }
+
+ @Override
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding query tests for NULL arrays and empty arrays to cover documented behavior
`testArrayUnionSum` currently only covers arrays with elements and element-level NULLs. Please also add tests that:
- Mix `CAST(NULL AS array<bigint>)` with non-null arrays and assert that NULL-array rows are skipped, producing the same result as aggregating only non-null arrays.
- Mix `CAST(array[] AS array<bigint>)` with non-empty arrays and assert that empty arrays don’t affect the result, and that queries with only empty arrays succeed without errors.
This will exercise the documented semantics for NULL and empty arrays.
</issue_to_address>
### Comment 4
<location> `presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/TestArrayUnionSumResult.java:40` </location>
<code_context>
+ private static final ArrayType ARRAY_DOUBLE = new ArrayType(DOUBLE);
+ private static final ArrayType ARRAY_REAL = new ArrayType(REAL);
+
+ @Test
+ public void testBasicUnionSum()
+ {
</code_context>
<issue_to_address>
**suggestion (testing):** Add a unit test for handling empty input arrays in ArrayUnionSumResult
Current tests don’t cover the case where `positionCount = 0`. Please add a test that builds an empty `Block` (BIGINT, and optionally other types), constructs an `ArrayUnionSumResult` from it, asserts `size() == 0`, verifies serialization produces an array block with 0 elements, and (optionally) unions this empty result/block with a non-empty one to confirm the non-empty result is preserved. This guards against implementations that assume `size() > 0` and validates empty-result serialization behavior.
Suggested implementation:
```java
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
```
```java
public class TestArrayUnionSumResult
{
    private static final ArrayType ARRAY_BIGINT = new ArrayType(BIGINT);
    private static final ArrayType ARRAY_DOUBLE = new ArrayType(DOUBLE);
    private static final ArrayType ARRAY_REAL = new ArrayType(REAL);

    @Test
    public void testEmptyInputArrays()
    {
        // Build an empty BIGINT array block
        BlockBuilder arrayBlockBuilder = ARRAY_BIGINT.createBlockBuilder(null, 0);
        Block emptyArrayBlock = arrayBlockBuilder.build();

        // Create an ArrayUnionSumState and update it with the empty block
        ArrayUnionSumState state = new ArrayUnionSumStateFactory(BIGINT).createSingleState();
        state.setResult(new ArrayUnionSumResult(ARRAY_BIGINT, BIGINT, emptyArrayBlock));
        ArrayUnionSumResult result = state.getResult();

        // Verify empty result
        assertNotNull(result);
        assertEquals(result.size(), 0);

        // Verify serialization produces an empty array block
        Block serialized = result.serialize();
        assertEquals(serialized.getPositionCount(), 0);

        // Build a non-empty BIGINT array block containing [1, 2]
        BlockBuilder nonEmptyArrayBlockBuilder = ARRAY_BIGINT.createBlockBuilder(null, 1);
        BlockBuilder elementBlockBuilder = BIGINT.createBlockBuilder(null, 2);
        BIGINT.writeLong(elementBlockBuilder, 1L);
        BIGINT.writeLong(elementBlockBuilder, 2L);
        nonEmptyArrayBlockBuilder.writeObject(elementBlockBuilder.build()).closeEntry();
        Block nonEmptyArrayBlock = nonEmptyArrayBlockBuilder.build();
        ArrayUnionSumState nonEmptyState = new ArrayUnionSumStateFactory(BIGINT).createSingleState();
        nonEmptyState.setResult(new ArrayUnionSumResult(ARRAY_BIGINT, BIGINT, nonEmptyArrayBlock));
        ArrayUnionSumResult nonEmptyResult = nonEmptyState.getResult();

        // Union empty with non-empty and verify non-empty is preserved
        ArrayUnionSumResult unionResult = result.union(nonEmptyResult);
        assertEquals(unionResult.size(), nonEmptyResult.size());
        Block unionSerialized = unionResult.serialize();
        assertEquals(unionSerialized.getPositionCount(), nonEmptyArrayBlock.getPositionCount());
    }

    @Test
    public void testBasicUnionSum()
    {
        ArrayUnionSumState state = new ArrayUnionSumStateFactory(BIGINT).createSingleState();
        // Create array [1, 2, 3]
```
The above test assumes the following APIs exist and are accessible in this test:
1. `Block` and `BlockBuilder` are already imported (typically from `com.facebook.presto.spi.block`).
2. `ArrayUnionSumResult` has:
- A constructor `ArrayUnionSumResult(ArrayType arrayType, Type elementType, Block arrayBlock)`.
- Methods `int size()`, `Block serialize()`, and `ArrayUnionSumResult union(ArrayUnionSumResult other)`.
3. `ArrayUnionSumState` has `setResult(ArrayUnionSumResult)` and `getResult()` methods.
4. `ArrayUnionSumStateFactory` has a constructor `ArrayUnionSumStateFactory(Type elementType)` and a `createSingleState()` method.
If the actual APIs differ slightly, adjust the constructor and method calls in `testEmptyInputArrays()` to match the real `ArrayUnionSumResult` / state APIs while preserving the test semantics:
- Construct a state/result from an empty input array.
- Assert `size() == 0`.
- Assert the serialized block has `positionCount == 0`.
- Union the empty result with a non-empty one and assert the non-empty union behavior.
</issue_to_address>
|
Hi @DHRUV6029: Thanks for your code contribution. In recent years, Presto Native Worker has become the default execution engine for Presto, so we are adding new functions to the C++ engine rather than Java. Have you considered using it? Do you have plans to add a native function implementation as well? Could you give more details about your use case for the Presto Java engine so that we can decide how to proceed with this code? |
|
hello @aditi-pandit I'm flexible — happy to provide:
Whatever works best for the project. |
steveburnett left a comment
Thanks for the doc in presto-docs/src/main/sphinx/functions/aggregate.rst!
Please see my comment about the 0.296 release notes, and let me know.
|
@abhinavmuk04 I CC'd you because I thought you were looking at the native implementation of this function. |
@kaikalur I am also working on the native implementation of this function, haha |
|
@kaikalur Are you working on the C++/Velox native version of this function? If not, I can take that up.
Adds a new SQL aggregation function that combines arrays by summing values at corresponding indices.

Features:
- Result array length is the maximum of all input arrays
- Missing elements treated as 0
- Null values coalesced to 0
- Supports BIGINT, INTEGER, SMALLINT, TINYINT, DOUBLE, and REAL

Files:
- ArrayUnionSumAggregation.java - Main aggregation logic
- ArrayUnionSumResult.java - Result container with union-sum logic
- ArrayUnionSumState*.java - State management for aggregation
- TestArrayUnionSumResult.java - Unit tests (14 test cases)
- Updated docs and integration tests
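
The semantics listed in this commit message can be illustrated with a small stand-alone sketch. It is not the Presto implementation: `ArrayUnionSumSketch` and its `unionSum` helper are hypothetical, operate on boxed `Long[]` arrays instead of Presto blocks, assume NULL arrays are skipped as discussed in the review, and do not model overflow handling.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical stand-alone sketch of the array_union_sum semantics; not Presto code. */
final class ArrayUnionSumSketch
{
    private ArrayUnionSumSketch() {}

    /**
     * Sums arrays position by position. A null list entry models a NULL array
     * and is skipped; a null element models a NULL value and counts as 0.
     * Returns null when no non-null array was seen.
     */
    static long[] unionSum(List<Long[]> arrays)
    {
        long[] result = null;
        for (Long[] array : arrays) {
            if (array == null) {
                continue; // NULL array: skipped entirely
            }
            if (result == null) {
                result = new long[array.length];
            }
            else if (array.length > result.length) {
                // Grow to the longer length; the new slots start at 0.
                result = Arrays.copyOf(result, array.length);
            }
            for (int i = 0; i < array.length; i++) {
                if (array[i] != null) {
                    result[i] += array[i]; // plain long addition, overflow not checked
                }
            }
        }
        return result;
    }

    public static void main(String[] args)
    {
        long[] sums = unionSum(Arrays.asList(
                new Long[] {1L, 2L, 3L},
                new Long[] {10L, null},
                null,
                new Long[] {100L, 200L, 300L, 400L}));
        System.out.println(Arrays.toString(sums)); // prints [111, 202, 303, 400]
    }
}
```

The shorter array `[10, NULL]` contributes only to the first position, and the four-element array extends the result to length 4.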
1af6be6 to 9f10110
|
@steveburnett @abhinavmuk04 I addressed your comments on the PR, please take a look. |
steveburnett left a comment
LGTM! (docs)
Pull updated branch, new local doc build, looks good. Thanks!
|
@abhinavmuk04, when you have time would you please take a look at @DHRUV6029's addressing of your comments on this PR? |
|
Hello @abhinavmuk04, could you please review this PR? Thanks.
|
@abhinavmuk04, would you please re-review this PR and see if your comments have been addressed? Thank you! |
Hi @DHRUV6029, could you please take a look at [https://github.com/facebookincubator/velox/pull/16498](https://github.com/facebookincubator/velox/pull/15973) (facebookincubator/velox#15973)? It is the native implementation, and it could be helpful to base this on.
|
@abhinavmuk04 Thanks for sharing the native implementation (facebookincubator/velox#15972). I've reviewed it in detail and confirmed the Java and C++ implementations are aligned on all user-facing semantics:
One difference I noticed: the C++ implementation has a zero-skipping optimization ( The variadic form ( Also addressed your earlier review comments:
|
abhinavmuk04 left a comment
@DHRUV6029 Thanks for the changes and this PR, LGTM cc @feilong-liu
Description
Adds a new SQL aggregation function `array_union_sum` that combines multiple arrays by summing values at corresponding indices. This function is analogous to `map_union_sum` but operates on arrays instead of maps.
Example usage:
Implementation details:
Motivation and Context
Users often need to aggregate arrays element-wise, such as:
Currently, this requires complex workarounds using unnest, zip, and manual reconstruction. The array_union_sum function provides a clean, efficient solution similar to the existing map_union_sum function.
Impact
Public API changes:
Behavior:
Supported types:
Performance:
Test Plan
Contributor checklist
Release Notes
== RELEASE NOTES ==
General
`array_union_sum` aggregation function that combines arrays by summing values at corresponding indices. Supports all non-decimal numeric types (BIGINT, INTEGER, DOUBLE, REAL). The result array length is the maximum of all input arrays, with missing elements treated as 0 and NULL values coalesced to 0.