Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@

import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.Varchars.isVarcharType;
import static com.facebook.presto.hive.HiveCommonSessionProperties.isRangeFiltersOnSubscriptsEnabled;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE;
Expand Down Expand Up @@ -96,6 +98,16 @@ private static boolean hasSubscripts(Optional<Subfield> subfield)
return subfield.isPresent() && subfield.get().getPath().stream().anyMatch(Subfield.PathElement::isSubscript);
}

private static boolean hasFloatingPointMapKey(Type type)
{
return type instanceof MapType && isFloatingPointType(((MapType) type).getKeyType());
}

private static boolean isFloatingPointType(Type type)
{
return type.equals(DOUBLE) || type.equals(REAL);
}

public Optional<Subfield> extract(RowExpression expression)
{
return toSubfield(expression, functionResolution, expressionOptimizer, connectorSession);
Expand Down Expand Up @@ -148,6 +160,9 @@ private static Optional<Subfield> toSubfield(
if (indexExpression instanceof ConstantExpression) {
Object index = ((ConstantExpression) indexExpression).getValue();
if (index instanceof Number) {
if (hasFloatingPointMapKey(arguments.get(0).getType())) {
return Optional.empty();
}
elements.add(new Subfield.LongSubscript(((Number) index).longValue()));
expression = arguments.get(0);
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
import static com.facebook.presto.common.predicate.TupleDomain.withColumnDomains;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.hive.HiveTestUtils.mapType;
Expand All @@ -71,13 +73,16 @@
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Float.floatToRawIntBits;
import static org.testng.Assert.assertEquals;

public class TestDomainTranslator
{
private static final VariableReferenceExpression C_BIGINT = new VariableReferenceExpression(Optional.empty(), "c_bigint", BIGINT);
private static final VariableReferenceExpression C_BIGINT_ARRAY = new VariableReferenceExpression(Optional.empty(), "c_bigint_array", new ArrayType(BIGINT));
private static final VariableReferenceExpression C_BIGINT_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_bigint_to_bigint_map", mapType(BIGINT, BIGINT));
private static final VariableReferenceExpression C_DOUBLE_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_double_to_bigint_map", mapType(DOUBLE, BIGINT));
private static final VariableReferenceExpression C_REAL_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_real_to_bigint_map", mapType(REAL, BIGINT));
private static final VariableReferenceExpression C_VARCHAR_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_varchar_to_bigint_map", mapType(VARCHAR, BIGINT));
private static final VariableReferenceExpression C_STRUCT = new VariableReferenceExpression(Optional.empty(), "c_struct", RowType.from(ImmutableList.of(
RowType.field("a", BIGINT),
Expand Down Expand Up @@ -165,6 +170,13 @@ public void testSubfields()
assertPredicateDoesNotTranslate(equal(C_BIGINT_TO_BIGINT_MAP, createConstantExpression(createMapBlock(mapType, ImmutableMap.of(1, 100)), mapType)));
}

@Test
public void testFloatingPointMapKeysDoNotTranslate()
{
assertPredicateDoesNotTranslate(equal(mapSubscript(C_DOUBLE_TO_BIGINT_MAP, constant(0.99, DOUBLE)), bigintLiteral(2L)));
assertPredicateDoesNotTranslate(equal(mapSubscript(C_REAL_TO_BIGINT_MAP, constant((long) floatToRawIntBits(0.99f), REAL)), bigintLiteral(2L)));
}

private RowExpression dereference(RowExpression base, int field)
{
Type fieldType = base.getType().getTypeParameters().get(field);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotEquals;
import static org.testng.Assert.assertNotSame;
import static org.testng.Assert.assertTrue;

Expand Down Expand Up @@ -998,6 +999,8 @@ public void testPushdownFilterOnSubfields()
"id bigint, " +
"a array(bigint), " +
"b map(varchar, bigint), " +
"f map(double, bigint), " +
"g map(real, bigint), " +
"c row(" +
"a bigint, " +
"b row(x bigint), " +
Expand Down Expand Up @@ -1035,6 +1038,9 @@ public void testPushdownFilterOnSubfields()
assertPushdownFilterOnSubfields("SELECT * FROM test_pushdown_filter_on_subfields WHERE c.e['foo'] = 1",
ImmutableMap.of(new Subfield("c.e[\"foo\"]"), singleValue(BIGINT, 1L)));

assertNoPushdownFilterOnSubfields("SELECT * FROM test_pushdown_filter_on_subfields WHERE f[CAST(0.99 AS DOUBLE)] = 1", "f");
assertNoPushdownFilterOnSubfields("SELECT * FROM test_pushdown_filter_on_subfields WHERE g[CAST(0.99 AS REAL)] = 1", "g");

assertPushdownFilterOnSubfields("SELECT * FROM test_pushdown_filter_on_subfields WHERE c.a IS NOT NULL AND c.c IS NOT NULL",
ImmutableMap.of(new Subfield("c.a"), notNull(BIGINT), new Subfield("c.c"), notNull(new ArrayType(BIGINT))));

Expand Down Expand Up @@ -1089,6 +1095,8 @@ public void testPushdownMapSubscripts()
"a map(bigint, bigint), " +
"b map(bigint, map(bigint, varchar)), " +
"c map(varchar, bigint), \n" +
"d map(double, bigint), \n" +
"e map(real, bigint), \n" +
"y map(bigint, row(a bigint, b varchar, c double, d row(d1 bigint, d2 double)))," +
"z map(bigint, map(bigint, row(p bigint, e row(e1 bigint, e2 varchar)))))");

Expand Down Expand Up @@ -1117,6 +1125,12 @@ public void testPushdownMapSubscripts()
assertPushdownSubfields("SELECT mod(c['cat'], 2) FROM test_pushdown_map_subscripts WHERE c['dog'] > 10", "test_pushdown_map_subscripts",
ImmutableMap.of("c", toSubfields("c[\"cat\"]", "c[\"dog\"]")));

assertPushdownSubfields("SELECT d[CAST(0.99 AS DOUBLE)] FROM test_pushdown_map_subscripts", "test_pushdown_map_subscripts",
ImmutableMap.of());

assertPushdownSubfields("SELECT e[CAST(0.99 AS REAL)] FROM test_pushdown_map_subscripts", "test_pushdown_map_subscripts",
ImmutableMap.of());

// No subfield pruning
assertPushdownSubfields("SELECT map_keys(a)[1] FROM test_pushdown_map_subscripts", "test_pushdown_map_subscripts",
ImmutableMap.of());
Expand Down Expand Up @@ -1501,6 +1515,11 @@ public void testPushdownSubfieldsForMapSubset()
ImmutableMap.of("x", toSubfields()));
assertUpdate("DROP TABLE test_pushdown_map_subfields");

assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x map(double, double))");
assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array[CAST(0.99 AS DOUBLE), CAST(1.01 AS DOUBLE)]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields",
ImmutableMap.of());
assertUpdate("DROP TABLE test_pushdown_map_subfields");

assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x array(map(integer, double)))");
assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_subset(mp, array[1, 2, 3])) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields",
ImmutableMap.of("x", toSubfields("x[*][1]", "x[*][2]", "x[*][3]")));
Expand Down Expand Up @@ -1571,6 +1590,11 @@ public void testPushdownSubfieldsForMapFilter()
ImmutableMap.of("x", toSubfields()));
assertUpdate("DROP TABLE test_pushdown_map_subfields");

assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x map(double, double))");
assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k = CAST(0.99 AS DOUBLE)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields",
ImmutableMap.of());
assertUpdate("DROP TABLE test_pushdown_map_subfields");

assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x array(map(integer, double)))");
assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> k in (1, 2, 3))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields",
ImmutableMap.of("x", toSubfields("x[*][1]", "x[*][2]", "x[*][3]")));
Expand Down Expand Up @@ -2592,6 +2616,25 @@ private void assertPushdownFilterOnSubfields(String query, Map<Subfield, Domain>
predicateDomains.keySet().stream().map(Subfield::getRootName).collect(toImmutableSet())));
}

private void assertNoPushdownFilterOnSubfields(String query, String predicateColumnName)
{
String tableName = "test_pushdown_filter_on_subfields";
assertPlan(pushdownFilterAndNestedColumnFilterEnabled(), query,
output(exchange(PlanMatchPattern.tableScan(tableName))),
plan -> {
TableScanNode tableScan = searchFrom(plan.getRoot())
.where(node -> isTableScanNode(node, tableName))
.findOnlyElement();

assertTrue(tableScan.getTable().getLayout().isPresent());
HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) tableScan.getTable().getLayout().get();

assertEquals(layoutHandle.getPredicateColumns().keySet(), ImmutableSet.of(predicateColumnName));
assertEquals(layoutHandle.getDomainPredicate(), TupleDomain.all());
assertNotEquals(layoutHandle.getRemainingPredicate(), TRUE_CONSTANT);
});
}

private void assertParquetDereferencePushDown(String query, String tableName, Map<String, Subfield> expectedDeferencePushDowns)
{
assertParquetDereferencePushDown(withParquetDereferencePushDownEnabled(), query, tableName, expectedDeferencePushDowns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,16 @@
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.facebook.presto.testing.assertions.Assert.assertEquals;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Float.floatToRawIntBits;
import static org.testng.Assert.assertTrue;

public class TestSubfieldExtractor
{
private static final VariableReferenceExpression C_BIGINT = new VariableReferenceExpression(Optional.empty(), "c_bigint", BIGINT);
private static final VariableReferenceExpression C_BIGINT_ARRAY = new VariableReferenceExpression(Optional.empty(), "c_bigint_array", new ArrayType(BIGINT));
private static final VariableReferenceExpression C_BIGINT_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_bigint_to_bigint_map", mapType(BIGINT, BIGINT));
private static final VariableReferenceExpression C_DOUBLE_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_double_to_bigint_map", mapType(DOUBLE, BIGINT));
private static final VariableReferenceExpression C_REAL_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_real_to_bigint_map", mapType(REAL, BIGINT));
private static final VariableReferenceExpression C_VARCHAR_TO_BIGINT_MAP = new VariableReferenceExpression(Optional.empty(), "c_varchar_to_bigint_map", mapType(VARCHAR, BIGINT));
private static final VariableReferenceExpression C_STRUCT = new VariableReferenceExpression(Optional.empty(), "c_struct", RowType.from(ImmutableList.of(
RowType.field("a", BIGINT),
Expand Down Expand Up @@ -115,6 +118,13 @@ public void test()
assertEquals(subfieldExtractor.extract(constant(2L, INTEGER)), Optional.empty());
}

@Test
public void testFloatingPointMapKeys()
{
assertEquals(subfieldExtractor.extract(mapSubscript(C_DOUBLE_TO_BIGINT_MAP, constant(0.99, DOUBLE))), Optional.empty());
assertEquals(subfieldExtractor.extract(mapSubscript(C_REAL_TO_BIGINT_MAP, constant((long) floatToRawIntBits(0.99f), REAL))), Optional.empty());
}

@Test
public void testToRowExpression()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import com.facebook.presto.sql.planner.plan.UpdateNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand All @@ -101,6 +102,8 @@
import static com.facebook.presto.common.Subfield.allSubscripts;
import static com.facebook.presto.common.Subfield.noSubfield;
import static com.facebook.presto.common.Subfield.structureOnly;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.TypeUtils.readNativeValue;
import static com.facebook.presto.common.type.Varchars.isVarcharType;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
Expand Down Expand Up @@ -156,6 +159,24 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider
return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
}

@VisibleForTesting
static Optional<List<Subfield>> toSubfield(
RowExpression expression,
FunctionResolution functionResolution,
ExpressionOptimizer expressionOptimizer,
ConnectorSession connectorSession,
FunctionAndTypeManager functionAndTypeManager,
boolean isPushdownSubfieldsForMapFunctionsEnabled)
{
return Rewriter.toSubfield(
expression,
functionResolution,
expressionOptimizer,
connectorSession,
functionAndTypeManager,
isPushdownSubfieldsForMapFunctionsEnabled);
}

private static class Rewriter
extends SimplePlanRewriter<Rewriter.Context>
{
Expand Down Expand Up @@ -652,7 +673,8 @@ private static String getColumnName(Session session, Metadata metadata, TableHan
return metadata.getColumnMetadata(session, tableHandle, columnHandle).getName();
}

private static Optional<List<Subfield>> toSubfield(
@VisibleForTesting
static Optional<List<Subfield>> toSubfield(
RowExpression expression,
FunctionResolution functionResolution,
ExpressionOptimizer expressionOptimizer,
Expand Down Expand Up @@ -720,6 +742,9 @@ private static Optional<List<Subfield>> toSubfield(
if (((Number) index).longValue() < 0 && arguments.get(0).getType() instanceof ArrayType) {
return Optional.empty();
}
if (hasFloatingPointMapKey(arguments.get(0).getType())) {
return Optional.empty();
}

elements.add(new Subfield.LongSubscript(((Number) index).longValue()));
expression = arguments.get(0);
Expand Down Expand Up @@ -784,6 +809,9 @@ private static Optional<List<Subfield>> extractSubfieldsFromArray(ConstantExpres
{
ImmutableList.Builder<Subfield> arguments = ImmutableList.builder();
checkState(constantArray.getValue() instanceof Block && constantArray.getType() instanceof ArrayType);
if (hasFloatingPointMapKey(mapVariable.getType())) {
return Optional.empty();
}
Block arrayValue = (Block) constantArray.getValue();
Type arrayElementType = ((ArrayType) constantArray.getType()).getElementType();
for (int i = 0; i < arrayValue.getPositionCount(); ++i) {
Expand All @@ -803,6 +831,9 @@ private static Optional<List<Subfield>> extractSubfieldsFromArray(ConstantExpres

private static Optional<Subfield> extractSubfieldsFromSingleValue(ConstantExpression mapKey, VariableReferenceExpression mapVariable)
{
if (hasFloatingPointMapKey(mapVariable.getType())) {
return Optional.empty();
}
Object value = mapKey.getValue();
if (value == null) {
return Optional.empty();
Expand All @@ -816,6 +847,16 @@ private static Optional<Subfield> extractSubfieldsFromSingleValue(ConstantExpres
return Optional.empty();
}

private static boolean hasFloatingPointMapKey(Type type)
{
return type instanceof MapType && isFloatingPointType(((MapType) type).getKeyType());
}

private static boolean isFloatingPointType(Type type)
{
return type.equals(DOUBLE) || type.equals(REAL);
}

private static NestedField nestedField(String name)
{
return new NestedField(name.toLowerCase(Locale.ENGLISH));
Expand Down
Loading
Loading