|
19 | 19 | import java.util.Comparator; |
20 | 20 | import java.util.List; |
21 | 21 | import java.util.function.Function; |
| 22 | +import java.util.function.IntPredicate; |
22 | 23 | import java.util.stream.Collectors; |
23 | 24 |
|
24 | 25 | import io.delta.kernel.data.ArrayValue; |
|
33 | 34 | * Utility methods used by the default expression evaluator. |
34 | 35 | */ |
35 | 36 | class DefaultExpressionUtils { |
| 37 | + |
| 38 | + static final Comparator<BigDecimal> BIGDECIMAL_COMPARATOR = Comparator.naturalOrder(); |
| 39 | + static final Comparator<String> STRING_COMPARATOR = Comparator.naturalOrder(); |
| 40 | + static final Comparator<byte[]> BINARY_COMPARTOR = (leftOp, rightOp) -> { |
| 41 | + int i = 0; |
| 42 | + while (i < leftOp.length && i < rightOp.length) { |
| 43 | + if (leftOp[i] != rightOp[i]) { |
| 44 | + return Byte.compare(leftOp[i], rightOp[i]); |
| 45 | + } |
| 46 | + i++; |
| 47 | + } |
| 48 | + return Integer.compare(leftOp.length, rightOp.length); |
| 49 | + }; |
| 50 | + |
36 | 51 | private DefaultExpressionUtils() {} |
37 | 52 |
|
38 | 53 | /** |
@@ -87,138 +102,91 @@ public boolean getBoolean(int rowId) { |
87 | 102 | } |
88 | 103 |
|
89 | 104 | /** |
90 | | - * Utility method to compare the left and right according to the natural ordering |
91 | | - * and return an integer array where each row contains the comparison result (-1, 0, 1) for |
92 | | - * corresponding rows in the input vectors compared. |
| 105 | + * Utility method to create a column vector that lazily evaluate the |
| 106 | + * comparator ex. (ie. ==, >=, <=......) for left and right |
| 107 | + * column vector according to the natural ordering of numbers |
93 | 108 | * <p> |
94 | 109 | * Only primitive data types are supported. |
95 | 110 | */ |
96 | | - static int[] compare(ColumnVector left, ColumnVector right) { |
| 111 | + static ColumnVector comparatorVector( |
| 112 | + ColumnVector left, |
| 113 | + ColumnVector right, |
| 114 | + IntPredicate booleanComparator) { |
97 | 115 | checkArgument( |
98 | | - left.getSize() == right.getSize(), |
99 | | - "Left and right operand have different vector sizes."); |
100 | | - DataType dataType = left.getDataType(); |
| 116 | + left.getSize() == right.getSize(), |
| 117 | + "Left and right operand have different vector sizes."); |
101 | 118 |
|
102 | | - int numRows = left.getSize(); |
103 | | - int[] result = new int[numRows]; |
| 119 | + DataType dataType = left.getDataType(); |
| 120 | + IntPredicate vectorValueComparator; |
104 | 121 | if (dataType instanceof BooleanType) { |
105 | | - compareBoolean(left, right, result); |
| 122 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 123 | + Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId))); |
106 | 124 | } else if (dataType instanceof ByteType) { |
107 | | - compareByte(left, right, result); |
| 125 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 126 | + Byte.compare(left.getByte(rowId), right.getByte(rowId))); |
108 | 127 | } else if (dataType instanceof ShortType) { |
109 | | - compareShort(left, right, result); |
| 128 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 129 | + Short.compare(left.getShort(rowId), right.getShort(rowId))); |
110 | 130 | } else if (dataType instanceof IntegerType || dataType instanceof DateType) { |
111 | | - compareInt(left, right, result); |
| 131 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 132 | + Integer.compare(left.getInt(rowId), right.getInt(rowId))); |
112 | 133 | } else if (dataType instanceof LongType || |
113 | 134 | dataType instanceof TimestampType || |
114 | 135 | dataType instanceof TimestampNTZType) { |
115 | | - compareLong(left, right, result); |
| 136 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 137 | + Long.compare(left.getLong(rowId), right.getLong(rowId))); |
116 | 138 | } else if (dataType instanceof FloatType) { |
117 | | - compareFloat(left, right, result); |
| 139 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 140 | + Float.compare(left.getFloat(rowId), right.getFloat(rowId))); |
118 | 141 | } else if (dataType instanceof DoubleType) { |
119 | | - compareDouble(left, right, result); |
| 142 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 143 | + Double.compare(left.getDouble(rowId), right.getDouble(rowId))); |
120 | 144 | } else if (dataType instanceof DecimalType) { |
121 | | - compareDecimal(left, right, result); |
| 145 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 146 | + BIGDECIMAL_COMPARATOR.compare( |
| 147 | + left.getDecimal(rowId), right.getDecimal(rowId))); |
122 | 148 | } else if (dataType instanceof StringType) { |
123 | | - compareString(left, right, result); |
| 149 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 150 | + STRING_COMPARATOR.compare( |
| 151 | + left.getString(rowId), right.getString(rowId))); |
124 | 152 | } else if (dataType instanceof BinaryType) { |
125 | | - compareBinary(left, right, result); |
| 153 | + vectorValueComparator = rowId -> booleanComparator.test( |
| 154 | + BINARY_COMPARTOR.compare( |
| 155 | + left.getBinary(rowId), right.getBinary(rowId))); |
126 | 156 | } else { |
127 | 157 | throw new UnsupportedOperationException(dataType + " can not be compared."); |
128 | 158 | } |
129 | | - return result; |
130 | | - } |
131 | | - |
132 | | - static void compareBoolean(ColumnVector left, ColumnVector right, int[] result) { |
133 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
134 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
135 | | - result[rowId] = Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId)); |
136 | | - } |
137 | | - } |
138 | | - } |
139 | 159 |
|
140 | | - static void compareByte(ColumnVector left, ColumnVector right, int[] result) { |
141 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
142 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
143 | | - result[rowId] = Byte.compare(left.getByte(rowId), right.getByte(rowId)); |
144 | | - } |
145 | | - } |
146 | | - } |
147 | | - |
148 | | - static void compareShort(ColumnVector left, ColumnVector right, int[] result) { |
149 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
150 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
151 | | - result[rowId] = Short.compare(left.getShort(rowId), right.getShort(rowId)); |
152 | | - } |
153 | | - } |
154 | | - } |
155 | | - |
156 | | - static void compareInt(ColumnVector left, ColumnVector right, int[] result) { |
157 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
158 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
159 | | - result[rowId] = Integer.compare(left.getInt(rowId), right.getInt(rowId)); |
160 | | - } |
161 | | - } |
162 | | - } |
163 | | - |
164 | | - static void compareLong(ColumnVector left, ColumnVector right, int[] result) { |
165 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
166 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
167 | | - result[rowId] = Long.compare(left.getLong(rowId), right.getLong(rowId)); |
168 | | - } |
169 | | - } |
170 | | - } |
| 160 | + return new ColumnVector() { |
171 | 161 |
|
172 | | - static void compareFloat(ColumnVector left, ColumnVector right, int[] result) { |
173 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
174 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
175 | | - result[rowId] = Float.compare(left.getFloat(rowId), right.getFloat(rowId)); |
| 162 | + @Override |
| 163 | + public DataType getDataType() { |
| 164 | + return BooleanType.BOOLEAN; |
176 | 165 | } |
177 | | - } |
178 | | - } |
179 | 166 |
|
180 | | - static void compareDouble(ColumnVector left, ColumnVector right, int[] result) { |
181 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
182 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
183 | | - result[rowId] = Double.compare(left.getDouble(rowId), right.getDouble(rowId)); |
| 167 | + @Override |
| 168 | + public void close() { |
| 169 | + Utils.closeCloseables(left, right); |
184 | 170 | } |
185 | | - } |
186 | | - } |
187 | 171 |
|
188 | | - static void compareString(ColumnVector left, ColumnVector right, int[] result) { |
189 | | - Comparator<String> comparator = Comparator.naturalOrder(); |
190 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
191 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
192 | | - result[rowId] = comparator.compare(left.getString(rowId), right.getString(rowId)); |
| 172 | + @Override |
| 173 | + public int getSize() { |
| 174 | + return left.getSize(); |
193 | 175 | } |
194 | | - } |
195 | | - } |
196 | 176 |
|
197 | | - static void compareDecimal(ColumnVector left, ColumnVector right, int[] result) { |
198 | | - Comparator<BigDecimal> comparator = Comparator.naturalOrder(); |
199 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
200 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
201 | | - result[rowId] = comparator.compare(left.getDecimal(rowId), right.getDecimal(rowId)); |
| 177 | + @Override |
| 178 | + public boolean isNullAt(int rowId) { |
| 179 | + return left.isNullAt(rowId) || right.isNullAt(rowId); |
202 | 180 | } |
203 | | - } |
204 | | - } |
205 | 181 |
|
206 | | - static void compareBinary(ColumnVector left, ColumnVector right, int[] result) { |
207 | | - Comparator<byte[]> comparator = (leftOp, rightOp) -> { |
208 | | - int i = 0; |
209 | | - while (i < leftOp.length && i < rightOp.length) { |
210 | | - if (leftOp[i] != rightOp[i]) { |
211 | | - return Byte.compare(leftOp[i], rightOp[i]); |
| 182 | + @Override |
| 183 | + public boolean getBoolean(int rowId) { |
| 184 | + if (isNullAt(rowId)) { |
| 185 | + return false; |
212 | 186 | } |
213 | | - i++; |
| 187 | + return vectorValueComparator.test(rowId); |
214 | 188 | } |
215 | | - return Integer.compare(leftOp.length, rightOp.length); |
216 | 189 | }; |
217 | | - for (int rowId = 0; rowId < left.getSize(); rowId++) { |
218 | | - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { |
219 | | - result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId)); |
220 | | - } |
221 | | - } |
222 | 190 | } |
223 | 191 |
|
224 | 192 | static Expression childAt(Expression expression, int index) { |
|
0 commit comments