Skip to content

Commit 085f117

Browse files
authored
[Kernel] Change comparator expression to lazy evaluation (#2853)
Description: Resolves #2541. How was this patch tested? Existing tests.
1 parent 39e91af commit 085f117

File tree

2 files changed

+88
-128
lines changed

2 files changed

+88
-128
lines changed

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,9 @@
3535
import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector;
3636
import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector;
3737
import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException;
38+
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*;
3839
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector;
3940
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.childAt;
40-
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.compare;
41-
import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.evalNullability;
4241
import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo;
4342

4443
/**
@@ -421,44 +420,37 @@ ColumnVector visitAlwaysFalse(AlwaysFalse alwaysFalse) {
421420
@Override
422421
ColumnVector visitComparator(Predicate predicate) {
423422
PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(predicate);
424-
425-
int numRows = argResults.rowCount;
426-
boolean[] result = new boolean[numRows];
427-
boolean[] nullability = evalNullability(argResults.leftResult, argResults.rightResult);
428-
int[] compareResult = compare(argResults.leftResult, argResults.rightResult);
429423
switch (predicate.getName()) {
430424
case "=":
431-
for (int rowId = 0; rowId < numRows; rowId++) {
432-
result[rowId] = compareResult[rowId] == 0;
433-
}
434-
break;
425+
return comparatorVector(
426+
argResults.leftResult,
427+
argResults.rightResult,
428+
(compareResult) -> (compareResult == 0));
435429
case ">":
436-
for (int rowId = 0; rowId < numRows; rowId++) {
437-
result[rowId] = compareResult[rowId] > 0;
438-
}
439-
break;
430+
return comparatorVector(
431+
argResults.leftResult,
432+
argResults.rightResult,
433+
(compareResult) -> (compareResult > 0));
440434
case ">=":
441-
for (int rowId = 0; rowId < numRows; rowId++) {
442-
result[rowId] = compareResult[rowId] >= 0;
443-
}
444-
break;
435+
return comparatorVector(
436+
argResults.leftResult,
437+
argResults.rightResult,
438+
(compareResult) -> (compareResult >= 0));
445439
case "<":
446-
for (int rowId = 0; rowId < numRows; rowId++) {
447-
result[rowId] = compareResult[rowId] < 0;
448-
}
449-
break;
440+
return comparatorVector(
441+
argResults.leftResult,
442+
argResults.rightResult,
443+
(compareResult) -> (compareResult < 0));
450444
case "<=":
451-
for (int rowId = 0; rowId < numRows; rowId++) {
452-
result[rowId] = compareResult[rowId] <= 0;
453-
}
454-
break;
445+
return comparatorVector(
446+
argResults.leftResult,
447+
argResults.rightResult,
448+
(compareResult) -> (compareResult <= 0));
455449
default:
456450
// We should never reach this based on the ExpressionVisitor
457451
throw new IllegalStateException(
458452
String.format("%s is not a recognized comparator", predicate.getName()));
459453
}
460-
461-
return new DefaultBooleanVector(numRows, Optional.of(nullability), result);
462454
}
463455

464456
@Override

kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java

Lines changed: 67 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Comparator;
2020
import java.util.List;
2121
import java.util.function.Function;
22+
import java.util.function.IntPredicate;
2223
import java.util.stream.Collectors;
2324

2425
import io.delta.kernel.data.ArrayValue;
@@ -33,6 +34,20 @@
3334
* Utility methods used by the default expression evaluator.
3435
*/
3536
class DefaultExpressionUtils {
37+
38+
static final Comparator<BigDecimal> BIGDECIMAL_COMPARATOR = Comparator.naturalOrder();
39+
static final Comparator<String> STRING_COMPARATOR = Comparator.naturalOrder();
40+
static final Comparator<byte[]> BINARY_COMPARTOR = (leftOp, rightOp) -> {
41+
int i = 0;
42+
while (i < leftOp.length && i < rightOp.length) {
43+
if (leftOp[i] != rightOp[i]) {
44+
return Byte.compare(leftOp[i], rightOp[i]);
45+
}
46+
i++;
47+
}
48+
return Integer.compare(leftOp.length, rightOp.length);
49+
};
50+
3651
private DefaultExpressionUtils() {}
3752

3853
/**
@@ -87,138 +102,91 @@ public boolean getBoolean(int rowId) {
87102
}
88103

89104
/**
90-
* Utility method to compare the left and right according to the natural ordering
91-
* and return an integer array where each row contains the comparison result (-1, 0, 1) for
92-
* corresponding rows in the input vectors compared.
105+
* Utility method to create a column vector that lazily evaluates the
106+
* comparator expression (one of =, >, >=, <, <=) for the left and right
107+
* column vectors according to the natural ordering of their values.
93108
* <p>
94109
* Only primitive data types are supported.
95110
*/
96-
static int[] compare(ColumnVector left, ColumnVector right) {
111+
static ColumnVector comparatorVector(
112+
ColumnVector left,
113+
ColumnVector right,
114+
IntPredicate booleanComparator) {
97115
checkArgument(
98-
left.getSize() == right.getSize(),
99-
"Left and right operand have different vector sizes.");
100-
DataType dataType = left.getDataType();
116+
left.getSize() == right.getSize(),
117+
"Left and right operand have different vector sizes.");
101118

102-
int numRows = left.getSize();
103-
int[] result = new int[numRows];
119+
DataType dataType = left.getDataType();
120+
IntPredicate vectorValueComparator;
104121
if (dataType instanceof BooleanType) {
105-
compareBoolean(left, right, result);
122+
vectorValueComparator = rowId -> booleanComparator.test(
123+
Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId)));
106124
} else if (dataType instanceof ByteType) {
107-
compareByte(left, right, result);
125+
vectorValueComparator = rowId -> booleanComparator.test(
126+
Byte.compare(left.getByte(rowId), right.getByte(rowId)));
108127
} else if (dataType instanceof ShortType) {
109-
compareShort(left, right, result);
128+
vectorValueComparator = rowId -> booleanComparator.test(
129+
Short.compare(left.getShort(rowId), right.getShort(rowId)));
110130
} else if (dataType instanceof IntegerType || dataType instanceof DateType) {
111-
compareInt(left, right, result);
131+
vectorValueComparator = rowId -> booleanComparator.test(
132+
Integer.compare(left.getInt(rowId), right.getInt(rowId)));
112133
} else if (dataType instanceof LongType ||
113134
dataType instanceof TimestampType ||
114135
dataType instanceof TimestampNTZType) {
115-
compareLong(left, right, result);
136+
vectorValueComparator = rowId -> booleanComparator.test(
137+
Long.compare(left.getLong(rowId), right.getLong(rowId)));
116138
} else if (dataType instanceof FloatType) {
117-
compareFloat(left, right, result);
139+
vectorValueComparator = rowId -> booleanComparator.test(
140+
Float.compare(left.getFloat(rowId), right.getFloat(rowId)));
118141
} else if (dataType instanceof DoubleType) {
119-
compareDouble(left, right, result);
142+
vectorValueComparator = rowId -> booleanComparator.test(
143+
Double.compare(left.getDouble(rowId), right.getDouble(rowId)));
120144
} else if (dataType instanceof DecimalType) {
121-
compareDecimal(left, right, result);
145+
vectorValueComparator = rowId -> booleanComparator.test(
146+
BIGDECIMAL_COMPARATOR.compare(
147+
left.getDecimal(rowId), right.getDecimal(rowId)));
122148
} else if (dataType instanceof StringType) {
123-
compareString(left, right, result);
149+
vectorValueComparator = rowId -> booleanComparator.test(
150+
STRING_COMPARATOR.compare(
151+
left.getString(rowId), right.getString(rowId)));
124152
} else if (dataType instanceof BinaryType) {
125-
compareBinary(left, right, result);
153+
vectorValueComparator = rowId -> booleanComparator.test(
154+
BINARY_COMPARTOR.compare(
155+
left.getBinary(rowId), right.getBinary(rowId)));
126156
} else {
127157
throw new UnsupportedOperationException(dataType + " can not be compared.");
128158
}
129-
return result;
130-
}
131-
132-
static void compareBoolean(ColumnVector left, ColumnVector right, int[] result) {
133-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
134-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
135-
result[rowId] = Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId));
136-
}
137-
}
138-
}
139159

140-
static void compareByte(ColumnVector left, ColumnVector right, int[] result) {
141-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
142-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
143-
result[rowId] = Byte.compare(left.getByte(rowId), right.getByte(rowId));
144-
}
145-
}
146-
}
147-
148-
static void compareShort(ColumnVector left, ColumnVector right, int[] result) {
149-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
150-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
151-
result[rowId] = Short.compare(left.getShort(rowId), right.getShort(rowId));
152-
}
153-
}
154-
}
155-
156-
static void compareInt(ColumnVector left, ColumnVector right, int[] result) {
157-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
158-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
159-
result[rowId] = Integer.compare(left.getInt(rowId), right.getInt(rowId));
160-
}
161-
}
162-
}
163-
164-
static void compareLong(ColumnVector left, ColumnVector right, int[] result) {
165-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
166-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
167-
result[rowId] = Long.compare(left.getLong(rowId), right.getLong(rowId));
168-
}
169-
}
170-
}
160+
return new ColumnVector() {
171161

172-
static void compareFloat(ColumnVector left, ColumnVector right, int[] result) {
173-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
174-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
175-
result[rowId] = Float.compare(left.getFloat(rowId), right.getFloat(rowId));
162+
@Override
163+
public DataType getDataType() {
164+
return BooleanType.BOOLEAN;
176165
}
177-
}
178-
}
179166

180-
static void compareDouble(ColumnVector left, ColumnVector right, int[] result) {
181-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
182-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
183-
result[rowId] = Double.compare(left.getDouble(rowId), right.getDouble(rowId));
167+
@Override
168+
public void close() {
169+
Utils.closeCloseables(left, right);
184170
}
185-
}
186-
}
187171

188-
static void compareString(ColumnVector left, ColumnVector right, int[] result) {
189-
Comparator<String> comparator = Comparator.naturalOrder();
190-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
191-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
192-
result[rowId] = comparator.compare(left.getString(rowId), right.getString(rowId));
172+
@Override
173+
public int getSize() {
174+
return left.getSize();
193175
}
194-
}
195-
}
196176

197-
static void compareDecimal(ColumnVector left, ColumnVector right, int[] result) {
198-
Comparator<BigDecimal> comparator = Comparator.naturalOrder();
199-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
200-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
201-
result[rowId] = comparator.compare(left.getDecimal(rowId), right.getDecimal(rowId));
177+
@Override
178+
public boolean isNullAt(int rowId) {
179+
return left.isNullAt(rowId) || right.isNullAt(rowId);
202180
}
203-
}
204-
}
205181

206-
static void compareBinary(ColumnVector left, ColumnVector right, int[] result) {
207-
Comparator<byte[]> comparator = (leftOp, rightOp) -> {
208-
int i = 0;
209-
while (i < leftOp.length && i < rightOp.length) {
210-
if (leftOp[i] != rightOp[i]) {
211-
return Byte.compare(leftOp[i], rightOp[i]);
182+
@Override
183+
public boolean getBoolean(int rowId) {
184+
if (isNullAt(rowId)) {
185+
return false;
212186
}
213-
i++;
187+
return vectorValueComparator.test(rowId);
214188
}
215-
return Integer.compare(leftOp.length, rightOp.length);
216189
};
217-
for (int rowId = 0; rowId < left.getSize(); rowId++) {
218-
if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) {
219-
result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId));
220-
}
221-
}
222190
}
223191

224192
static Expression childAt(Expression expression, int index) {

0 commit comments

Comments
 (0)