Skip to content

Commit 5c10bd0

Browse files
committed
feat(optimizer): Rewrite ROW constructor IN to disjunction for partition pruning
Add a new iterative optimizer rule RewriteRowConstructorInToDisjunction that rewrites predicates of the form: ROW(pk1, pk2) IN (ROW('a', 1), ROW('b', 2)) into: (pk1 = 'a' AND pk2 = 1) OR (pk1 = 'b' AND pk2 = 2) This transformation fires only when ALL fields of the left-side ROW constructor are partition key columns of the underlying table. The rewrite enables PickTableLayout's RowExpressionDomainTranslator to extract per-column TupleDomain constraints for partition pruning, which is impossible when the predicate uses ROW-level IN comparisons. Without this rewrite, the domain translator sees TupleDomain{ALL} (no constraints, full table scan). After the rewrite, it extracts per-column domains like {pk1 -> {'a','b'}, pk2 -> {1,2}}, enabling Hive partition pruning via HivePartitionManager. The rule is gated behind a session property rewrite_row_constructor_in_to_disjunction (default: disabled) and runs before the first PickTableLayout invocation in PlanOptimizers.
1 parent 52ad58a commit 5c10bd0

File tree

6 files changed

+948
-0
lines changed

6 files changed

+948
-0
lines changed

presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ public final class SystemSessionProperties
393393
public static final String NATIVE_ENFORCE_JOIN_BUILD_INPUT_PARTITION = "native_enforce_join_build_input_partition";
394394
public static final String NATIVE_EXECUTION_SCALE_WRITER_THREADS_ENABLED = "native_execution_scale_writer_threads_enabled";
395395
public static final String TRY_FUNCTION_CATCHABLE_ERRORS = "try_function_catchable_errors";
396+
public static final String REWRITE_ROW_CONSTRUCTOR_IN_TO_DISJUNCTION = "rewrite_row_constructor_in_to_disjunction";
396397

397398
private final List<PropertyMetadata<?>> sessionProperties;
398399

@@ -2213,6 +2214,10 @@ public SystemSessionProperties(
22132214
TRY_FUNCTION_CATCHABLE_ERRORS,
22142215
"Comma-separated list of error code names that TRY function should catch (such as 'GENERIC_INTERNAL_ERROR,INVALID_ARGUMENTS')",
22152216
featuresConfig.getTryFunctionCatchableErrors(),
2217+
false),
2218+
booleanProperty(REWRITE_ROW_CONSTRUCTOR_IN_TO_DISJUNCTION,
2219+
"Rewrite ROW(...) IN (ROW(...), ...) into OR of ANDs for partition pruning",
2220+
featuresConfig.isRewriteRowConstructorInToDisjunction(),
22162221
false));
22172222
}
22182223

@@ -3772,4 +3777,9 @@ public static String getTryFunctionCatchableErrors(Session session)
37723777
{
37733778
return session.getSystemProperty(TRY_FUNCTION_CATCHABLE_ERRORS, String.class);
37743779
}
3780+
3781+
public static boolean isRewriteRowConstructorInToDisjunction(Session session)
3782+
{
3783+
return session.getSystemProperty(REWRITE_ROW_CONSTRUCTOR_IN_TO_DISJUNCTION, Boolean.class);
3784+
}
37753785
}

presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ public class FeaturesConfig
299299
private boolean pullUpExpressionFromLambda;
300300
private boolean rewriteConstantArrayContainsToIn;
301301
private boolean rewriteExpressionWithConstantVariable = true;
302+
private boolean rewriteRowConstructorInToDisjunction;
302303
private boolean optimizeConditionalApproxDistinct = true;
303304

304305
private boolean preProcessMetadataCalls;
@@ -3159,6 +3160,19 @@ public FeaturesConfig setRewriteExpressionWithConstantVariable(boolean rewriteEx
31593160
return this;
31603161
}
31613162

3163+
public boolean isRewriteRowConstructorInToDisjunction()
3164+
{
3165+
return this.rewriteRowConstructorInToDisjunction;
3166+
}
3167+
3168+
@Config("optimizer.rewrite-row-constructor-in-to-disjunction")
3169+
@ConfigDescription("Rewrite ROW(...) IN (ROW(...), ...) into OR of ANDs for partition pruning")
3170+
public FeaturesConfig setRewriteRowConstructorInToDisjunction(boolean rewriteRowConstructorInToDisjunction)
3171+
{
3172+
this.rewriteRowConstructorInToDisjunction = rewriteRowConstructorInToDisjunction;
3173+
return this;
3174+
}
3175+
31623176
public boolean isOptimizeConditionalApproxDistinct()
31633177
{
31643178
return this.optimizeConditionalApproxDistinct;

presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@
146146
import com.facebook.presto.sql.planner.iterative.rule.RewriteConstantArrayContainsToInExpression;
147147
import com.facebook.presto.sql.planner.iterative.rule.RewriteExcludeColumnsFunctionToProjection;
148148
import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject;
149+
import com.facebook.presto.sql.planner.iterative.rule.RewriteRowConstructorInToDisjunction;
149150
import com.facebook.presto.sql.planner.iterative.rule.RewriteRowExpressions;
150151
import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation;
151152
import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides;
@@ -685,6 +686,12 @@ public PlanOptimizers(
685686
estimatedExchangesCostCalculator,
686687
ImmutableSet.of(new AddDistinctForSemiJoinBuild())),
687688
new KeyBasedSampler(metadata),
689+
new IterativeOptimizer(
690+
metadata,
691+
ruleStats,
692+
statsCalculator,
693+
estimatedExchangesCostCalculator,
694+
ImmutableSet.of(new RewriteRowConstructorInToDisjunction(metadata))),
688695
new IterativeOptimizer(
689696
metadata,
690697
ruleStats,
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.sql.planner.iterative.rule;
15+
16+
import com.facebook.presto.matching.Capture;
17+
import com.facebook.presto.matching.Captures;
18+
import com.facebook.presto.matching.Pattern;
19+
import com.facebook.presto.metadata.Metadata;
20+
import com.facebook.presto.spi.ColumnHandle;
21+
import com.facebook.presto.spi.ConnectorTableMetadata;
22+
import com.facebook.presto.spi.TableHandle;
23+
import com.facebook.presto.spi.plan.FilterNode;
24+
import com.facebook.presto.spi.plan.TableScanNode;
25+
import com.facebook.presto.spi.relation.CallExpression;
26+
import com.facebook.presto.spi.relation.RowExpression;
27+
import com.facebook.presto.spi.relation.SpecialFormExpression;
28+
import com.facebook.presto.spi.relation.VariableReferenceExpression;
29+
import com.facebook.presto.sql.planner.iterative.Rule;
30+
import com.facebook.presto.sql.relational.FunctionResolution;
31+
import com.google.common.collect.ImmutableList;
32+
import com.google.common.collect.ImmutableSet;
33+
34+
import java.util.ArrayList;
35+
import java.util.HashSet;
36+
import java.util.List;
37+
import java.util.Map;
38+
import java.util.Set;
39+
40+
import static com.facebook.presto.SystemSessionProperties.isRewriteRowConstructorInToDisjunction;
41+
import static com.facebook.presto.common.function.OperatorType.EQUAL;
42+
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
43+
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
44+
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
45+
import static com.facebook.presto.matching.Capture.newCapture;
46+
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN;
47+
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.ROW_CONSTRUCTOR;
48+
import static com.facebook.presto.sql.planner.plan.Patterns.filter;
49+
import static com.facebook.presto.sql.planner.plan.Patterns.source;
50+
import static com.facebook.presto.sql.planner.plan.Patterns.tableScan;
51+
import static java.util.Objects.requireNonNull;
52+
53+
/**
54+
* Rewrites predicates of the form:
55+
* <pre>
56+
* ROW(partition_key1, partition_key2) IN (ROW('a', 1), ROW('b', 2), ...)
57+
* </pre>
58+
* into:
59+
* <pre>
60+
* (partition_key1 = 'a' AND partition_key2 = 1)
61+
* OR (partition_key1 = 'b' AND partition_key2 = 2)
62+
* OR ...
63+
* </pre>
64+
*
65+
* This transformation only fires when ALL fields of the left-side ROW constructor
66+
* are partition key columns of the underlying table. The rewrite enables
67+
* {@code PickTableLayout} to extract per-column domains for partition pruning,
68+
* which is impossible when the predicate uses ROW-level IN comparisons.
69+
*/
70+
public class RewriteRowConstructorInToDisjunction
71+
implements Rule<FilterNode>
72+
{
73+
private static final Capture<TableScanNode> TABLE_SCAN = newCapture();
74+
private static final Pattern<FilterNode> PATTERN = filter().with(source().matching(
75+
tableScan().capturedAs(TABLE_SCAN)));
76+
private static final String PARTITIONED_BY_PROPERTY = "partitioned_by";
77+
78+
private final Metadata metadata;
79+
private final FunctionResolution functionResolution;
80+
81+
public RewriteRowConstructorInToDisjunction(Metadata metadata)
82+
{
83+
this.metadata = requireNonNull(metadata, "metadata is null");
84+
this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
85+
}
86+
87+
@Override
88+
public Pattern<FilterNode> getPattern()
89+
{
90+
return PATTERN;
91+
}
92+
93+
@Override
94+
public Result apply(FilterNode filterNode, Captures captures, Context context)
95+
{
96+
if (!isRewriteRowConstructorInToDisjunction(context.getSession())) {
97+
return Result.empty();
98+
}
99+
100+
TableScanNode tableScan = captures.get(TABLE_SCAN);
101+
Set<VariableReferenceExpression> partitionVars = resolvePartitionVariables(
102+
context.getSession(), tableScan);
103+
104+
if (partitionVars.isEmpty()) {
105+
return Result.empty();
106+
}
107+
108+
RowExpression predicate = filterNode.getPredicate();
109+
RowExpression rewritten = rewritePredicate(predicate, partitionVars);
110+
111+
if (predicate.equals(rewritten)) {
112+
return Result.empty();
113+
}
114+
115+
return Result.ofPlanNode(new FilterNode(
116+
filterNode.getSourceLocation(),
117+
filterNode.getId(),
118+
filterNode.getSource(),
119+
rewritten));
120+
}
121+
122+
private Set<VariableReferenceExpression> resolvePartitionVariables(
123+
com.facebook.presto.Session session,
124+
TableScanNode tableScan)
125+
{
126+
TableHandle tableHandle = tableScan.getTable();
127+
ConnectorTableMetadata tableMetadata;
128+
try {
129+
tableMetadata = metadata.getTableMetadata(session, tableHandle).getMetadata();
130+
}
131+
catch (RuntimeException e) {
132+
return ImmutableSet.of();
133+
}
134+
135+
Object partitionedByObj = tableMetadata.getProperties().get(PARTITIONED_BY_PROPERTY);
136+
if (!(partitionedByObj instanceof List)) {
137+
return ImmutableSet.of();
138+
}
139+
140+
@SuppressWarnings("unchecked")
141+
List<String> partitionColumnNames = (List<String>) partitionedByObj;
142+
if (partitionColumnNames.isEmpty()) {
143+
return ImmutableSet.of();
144+
}
145+
146+
Map<String, ColumnHandle> columnHandles = metadata.getColumnHandles(session, tableHandle);
147+
Set<ColumnHandle> partitionHandles = new HashSet<>();
148+
for (String name : partitionColumnNames) {
149+
ColumnHandle handle = columnHandles.get(name);
150+
if (handle != null) {
151+
partitionHandles.add(handle);
152+
}
153+
}
154+
155+
ImmutableSet.Builder<VariableReferenceExpression> result = ImmutableSet.builder();
156+
for (Map.Entry<VariableReferenceExpression, ColumnHandle> entry : tableScan.getAssignments().entrySet()) {
157+
if (partitionHandles.contains(entry.getValue())) {
158+
result.add(entry.getKey());
159+
}
160+
}
161+
return result.build();
162+
}
163+
164+
/**
165+
* Walks the predicate tree looking for rewritable ROW IN expressions.
166+
* Handles AND conjuncts at the top level and rewrites each eligible IN independently.
167+
*/
168+
private RowExpression rewritePredicate(RowExpression predicate, Set<VariableReferenceExpression> partitionVars)
169+
{
170+
if (predicate instanceof SpecialFormExpression) {
171+
SpecialFormExpression specialForm = (SpecialFormExpression) predicate;
172+
173+
if (specialForm.getForm() == IN) {
174+
RowExpression rewritten = tryRewriteRowIn(specialForm, partitionVars);
175+
if (rewritten != null) {
176+
return rewritten;
177+
}
178+
}
179+
180+
if (specialForm.getForm() == SpecialFormExpression.Form.AND) {
181+
List<RowExpression> args = specialForm.getArguments();
182+
boolean anyChanged = false;
183+
ImmutableList.Builder<RowExpression> newArgs = ImmutableList.builder();
184+
for (RowExpression arg : args) {
185+
RowExpression rewritten = rewritePredicate(arg, partitionVars);
186+
if (!rewritten.equals(arg)) {
187+
anyChanged = true;
188+
}
189+
newArgs.add(rewritten);
190+
}
191+
if (anyChanged) {
192+
return and(newArgs.build());
193+
}
194+
}
195+
}
196+
return predicate;
197+
}
198+
199+
/**
200+
* Attempts to rewrite a single SpecialFormExpression(IN, ...) where the first argument
201+
* is a ROW_CONSTRUCTOR of partition key variables and all candidates are ROW_CONSTRUCTORs
202+
* of matching arity.
203+
*
204+
* Returns the rewritten expression, or null if the pattern does not match.
205+
*/
206+
private RowExpression tryRewriteRowIn(SpecialFormExpression inExpr, Set<VariableReferenceExpression> partitionVars)
207+
{
208+
List<RowExpression> args = inExpr.getArguments();
209+
if (args.size() < 2) {
210+
return null;
211+
}
212+
213+
RowExpression target = args.get(0);
214+
if (!(target instanceof SpecialFormExpression)) {
215+
return null;
216+
}
217+
218+
SpecialFormExpression targetRow = (SpecialFormExpression) target;
219+
if (targetRow.getForm() != ROW_CONSTRUCTOR) {
220+
return null;
221+
}
222+
223+
List<RowExpression> rowFields = targetRow.getArguments();
224+
if (rowFields.isEmpty()) {
225+
return null;
226+
}
227+
228+
// All fields of the left-side ROW must be partition key VariableReferenceExpressions
229+
List<VariableReferenceExpression> fieldVars = new ArrayList<>(rowFields.size());
230+
for (RowExpression field : rowFields) {
231+
if (!(field instanceof VariableReferenceExpression)) {
232+
return null;
233+
}
234+
VariableReferenceExpression varRef = (VariableReferenceExpression) field;
235+
if (!partitionVars.contains(varRef)) {
236+
return null;
237+
}
238+
fieldVars.add(varRef);
239+
}
240+
241+
// All candidate values must be ROW_CONSTRUCTORs with matching arity
242+
int arity = rowFields.size();
243+
List<SpecialFormExpression> candidateRows = new ArrayList<>(args.size() - 1);
244+
for (int i = 1; i < args.size(); i++) {
245+
if (!(args.get(i) instanceof SpecialFormExpression)) {
246+
return null;
247+
}
248+
SpecialFormExpression candidate = (SpecialFormExpression) args.get(i);
249+
if (candidate.getForm() != ROW_CONSTRUCTOR || candidate.getArguments().size() != arity) {
250+
return null;
251+
}
252+
candidateRows.add(candidate);
253+
}
254+
255+
// Build: (pk1 = v1_1 AND pk2 = v1_2) OR (pk1 = v2_1 AND pk2 = v2_2) OR ...
256+
ImmutableList.Builder<RowExpression> disjuncts = ImmutableList.builder();
257+
for (SpecialFormExpression candidate : candidateRows) {
258+
ImmutableList.Builder<RowExpression> conjuncts = ImmutableList.builder();
259+
for (int fieldIdx = 0; fieldIdx < arity; fieldIdx++) {
260+
VariableReferenceExpression leftVar = fieldVars.get(fieldIdx);
261+
RowExpression rightVal = candidate.getArguments().get(fieldIdx);
262+
263+
conjuncts.add(new CallExpression(
264+
EQUAL.name(),
265+
functionResolution.comparisonFunction(EQUAL, leftVar.getType(), rightVal.getType()),
266+
BOOLEAN,
267+
ImmutableList.of(leftVar, rightVal)));
268+
}
269+
disjuncts.add(and(conjuncts.build()));
270+
}
271+
272+
return or(disjuncts.build());
273+
}
274+
}

presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ public void testDefaults()
264264
.setCteFilterAndProjectionPushdownEnabled(true)
265265
.setGenerateDomainFilters(false)
266266
.setRewriteExpressionWithConstantVariable(true)
267+
.setRewriteRowConstructorInToDisjunction(false)
267268
.setOptimizeConditionalApproxDistinct(true)
268269
.setDefaultWriterReplicationCoefficient(3.0)
269270
.setDefaultViewSecurityMode(DEFINER)
@@ -510,6 +511,7 @@ public void testExplicitPropertyMappings()
510511
.put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true")
511512
.put("optimizer.generate-domain-filters", "true")
512513
.put("optimizer.rewrite-expression-with-constant-variable", "false")
514+
.put("optimizer.rewrite-row-constructor-in-to-disjunction", "true")
513515
.put("optimizer.optimize-constant-approx-distinct", "false")
514516
.put("optimizer.default-writer-replication-coefficient", "5.0")
515517
.put("default-view-security-mode", INVOKER.name())
@@ -752,6 +754,7 @@ public void testExplicitPropertyMappings()
752754
.setCteFilterAndProjectionPushdownEnabled(false)
753755
.setGenerateDomainFilters(true)
754756
.setRewriteExpressionWithConstantVariable(false)
757+
.setRewriteRowConstructorInToDisjunction(true)
755758
.setOptimizeConditionalApproxDistinct(false)
756759
.setDefaultWriterReplicationCoefficient(5.0)
757760
.setDefaultViewSecurityMode(INVOKER)

0 commit comments

Comments
 (0)