Skip to content

Commit 5b65cbd

Browse files
ceekay47facebook-github-bot
authored andcommitted
feat(planner): Support GROUP BY and ORDER BY ordinals in MV query rewriting (prestodb#27422)
Summary: Queries using GROUP BY/ORDER BY ordinals (e.g. `GROUP BY 1`) silently fell back to the base table because the MV optimizer runs before the analyzer resolves ordinals to column references. Fix by resolving ordinals to SELECT expressions during MV validation and passing them through unchanged during rewriting. ``` == RELEASE NOTES == General Changes * Add support for ``GROUP BY`` and ``ORDER BY`` ordinal references in materialized view query rewriting. Previously, queries like ``SELECT a, SUM(b) FROM t GROUP BY 1`` would silently skip materialized view optimization. ``` Differential Revision: D97920227
1 parent f4b55fe commit 5b65cbd

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import com.facebook.presto.sql.tree.Join;
5757
import com.facebook.presto.sql.tree.Lateral;
5858
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
59+
import com.facebook.presto.sql.tree.LongLiteral;
5960
import com.facebook.presto.sql.tree.Node;
6061
import com.facebook.presto.sql.tree.OrderBy;
6162
import com.facebook.presto.sql.tree.QualifiedName;
@@ -107,6 +108,7 @@
107108
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED;
108109
import static com.facebook.presto.sql.relational.Expressions.call;
109110
import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions;
111+
import static java.lang.Math.toIntExact;
110112
import static java.lang.String.format;
111113
import static java.util.Objects.requireNonNull;
112114

@@ -463,18 +465,35 @@ protected Node visitQuerySpecification(QuerySpecification node, Void context)
463465
removablePrefix = Optional.of(new Identifier(baseTable.getName().toString()));
464466
}
465467
if (node.getGroupBy().isPresent()) {
468+
List<SelectItem> selectItems = node.getSelect().getSelectItems();
466469
ImmutableSet.Builder<Expression> expressionsInGroupByBuilder = ImmutableSet.builder();
467470
for (GroupingElement element : node.getGroupBy().get().getGroupingElements()) {
468471
element = removeGroupingElementPrefix(element, removablePrefix);
469472
Optional<Set<Expression>> groupByOfMaterializedView = materializedViewInfo.getGroupBy();
470473
if (groupByOfMaterializedView.isPresent()) {
471474
for (Expression expression : element.getExpressions()) {
472-
if (!groupByOfMaterializedView.get().contains(expression) || !materializedViewInfo.getBaseToViewColumnMap().containsKey(expression)) {
475+
// Resolve ordinal references (e.g. GROUP BY 1) to the corresponding SELECT expression
476+
Expression resolved = expression;
477+
if (expression instanceof LongLiteral) {
478+
int ordinal = toIntExact(((LongLiteral) expression).getValue());
479+
SelectItem selectItem = selectItems.get(ordinal - 1);
480+
if (selectItem instanceof SingleColumn) {
481+
resolved = removeExpressionPrefix(((SingleColumn) selectItem).getExpression(), removablePrefix);
482+
}
483+
else {
484+
throw new IllegalStateException("GROUP BY ordinal references non-single-column select item");
485+
}
486+
}
487+
if (!groupByOfMaterializedView.get().contains(resolved) || !materializedViewInfo.getBaseToViewColumnMap().containsKey(resolved)) {
473488
throw new IllegalStateException(format("Grouping element %s is not present in materialized view groupBy field", element));
474489
}
490+
// Store the resolved expression so visitSingleColumn can match against it
491+
expressionsInGroupByBuilder.add(resolved);
475492
}
476493
}
477-
expressionsInGroupByBuilder.addAll(element.getExpressions());
494+
else {
495+
expressionsInGroupByBuilder.addAll(element.getExpressions());
496+
}
478497
}
479498
expressionsInGroupBy = Optional.of(expressionsInGroupByBuilder.build());
480499
}
@@ -683,8 +702,11 @@ protected Node visitOrderBy(OrderBy node, Void context)
683702
ImmutableList.Builder<SortItem> rewrittenOrderBy = ImmutableList.builder();
684703
for (SortItem sortItem : node.getSortItems()) {
685704
sortItem = removeSortItemPrefix(sortItem, removablePrefix);
686-
if (!materializedViewInfo.getBaseToViewColumnMap().containsKey(sortItem.getSortKey())) {
687-
throw new IllegalStateException(format("Sort key %s is not present in materialized view select fields", sortItem.getSortKey()));
705+
// Ordinal references (e.g. ORDER BY 3) refer to SELECT items which are already validated
706+
Expression sortKey = sortItem.getSortKey();
707+
if (!(sortKey instanceof LongLiteral)
708+
&& !materializedViewInfo.getBaseToViewColumnMap().containsKey(sortKey)) {
709+
throw new IllegalStateException(format("Sort key %s is not present in materialized view select fields", sortKey));
688710
}
689711
rewrittenOrderBy.add((SortItem) process(sortItem, context));
690712
}
@@ -694,15 +716,26 @@ protected Node visitOrderBy(OrderBy node, Void context)
694716
@Override
695717
protected Node visitSortItem(SortItem node, Void context)
696718
{
697-
return new SortItem((Expression) process(node.getSortKey(), context), node.getOrdering(), node.getNullOrdering());
719+
Expression sortKey = node.getSortKey();
720+
// Ordinal references (e.g. ORDER BY 1) refer to SELECT positions which are already rewritten; pass through unchanged
721+
if (sortKey instanceof LongLiteral) {
722+
return node;
723+
}
724+
return new SortItem((Expression) process(sortKey, context), node.getOrdering(), node.getNullOrdering());
698725
}
699726

700727
@Override
701728
protected Node visitSimpleGroupBy(SimpleGroupBy node, Void context)
702729
{
703730
ImmutableList.Builder<Expression> rewrittenSimpleGroupBy = ImmutableList.builder();
704731
for (Expression column : node.getExpressions()) {
705-
rewrittenSimpleGroupBy.add((Expression) process(removeExpressionPrefix(column, removablePrefix), context));
732+
// Ordinal references (e.g. GROUP BY 1) refer to SELECT positions which are already rewritten; pass through unchanged
733+
if (column instanceof LongLiteral) {
734+
rewrittenSimpleGroupBy.add(column);
735+
}
736+
else {
737+
rewrittenSimpleGroupBy.add((Expression) process(removeExpressionPrefix(column, removablePrefix), context));
738+
}
706739
}
707740
return new SimpleGroupBy(rewrittenSimpleGroupBy.build());
708741
}

presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,36 @@ public void testWithOrderBy()
255255
assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1);
256256
}
257257

258+
@Test
259+
public void testWithGroupByOrdinals()
260+
{
261+
String originalViewSql = format("SELECT a as mv_a, b, c as mv_c FROM %s", BASE_TABLE_1);
262+
String baseQuerySql = format("SELECT SUM(a * b), MAX(a + b), c FROM %s GROUP BY 3", BASE_TABLE_1);
263+
String expectedRewrittenSql = format("SELECT SUM(mv_a * b), MAX(mv_a + b), mv_c as c FROM %s GROUP BY 3", VIEW_1);
264+
265+
assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1);
266+
}
267+
268+
@Test
269+
public void testWithOrderByOrdinals()
270+
{
271+
String originalViewSql = format("SELECT a as mv_a, b, c as mv_c FROM %s", BASE_TABLE_1);
272+
String baseQuerySql = format("SELECT a, b, c FROM %s ORDER BY 3 ASC, 2 DESC, 1", BASE_TABLE_1);
273+
String expectedRewrittenSql = format("SELECT mv_a as a, b, mv_c as c FROM %s ORDER BY 3 ASC, 2 DESC, 1", VIEW_1);
274+
275+
assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1);
276+
}
277+
278+
@Test
279+
public void testWithGroupByAndOrderByOrdinals()
280+
{
281+
String originalViewSql = format("SELECT MAX(a) as mv_max_a, b FROM %s GROUP BY b", BASE_TABLE_1);
282+
String baseQuerySql = format("SELECT MAX(a), b FROM %s GROUP BY 2 ORDER BY 1 DESC, 2 ASC", BASE_TABLE_1);
283+
String expectedRewrittenSql = format("SELECT MAX(mv_max_a), b FROM %s GROUP BY 2 ORDER BY 1 DESC, 2 ASC", VIEW_1);
284+
285+
assertOptimizedQuery(baseQuerySql, expectedRewrittenSql, originalViewSql, BASE_TABLE_1, VIEW_1);
286+
}
287+
258288
@Test
259289
public void testWithNoMatchingBaseTable()
260290
{

0 commit comments

Comments
 (0)