Skip to content

Commit ce359bc

Browse files
vinodkc authored and yaooqinn committed
[SPARK-44325][SQL] Use PartitionEvaluator API in SortMergeJoinExec
### What changes were proposed in this pull request?
The SQL operator `SortMergeJoinExec` is updated to use the `PartitionEvaluator` API for execution.

### Why are the changes needed?
To avoid the use of lambdas during distributed execution. See SPARK-43061 for more details.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Updated one test case. Once all the SQL operators are migrated, the flag `spark.sql.execution.useTaskEvaluator` will be enabled by default, to avoid running the tests both with and without the TaskEvaluator.

Closes #41884 from vinodkc/br_refactorSortMergeJoinEvaluatorFactory.

Authored-by: Vinod KC <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
1 parent 443b49e commit ce359bc

File tree

3 files changed

+332
-252
lines changed

3 files changed

+332
-252
lines changed
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.joins
19+
20+
import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
21+
import org.apache.spark.sql.catalyst.InternalRow
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Predicate, Projection, RowOrdering, UnsafeProjection, UnsafeRow}
23+
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
24+
import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
25+
import org.apache.spark.sql.execution.metric.SQLMetric
26+
27+
/**
 * Builds [[PartitionEvaluator]]s that run one partition of a sort-merge join. Each evaluator
 * receives exactly two row iterators (left side first, right side second) and returns the
 * iterator of joined rows for the configured join type.
 *
 * @param leftKeys join key expressions, bound against `left.output`
 * @param rightKeys join key expressions, bound against `right.output`
 * @param joinType selects which join iterator is built; unsupported types raise
 *                 `IllegalArgumentException` when the evaluator runs
 * @param condition optional extra predicate, bound against `left.output ++ right.output`;
 *                  when absent every key match is accepted
 * @param left left child plan; used for its output attributes and for resource cleanup
 * @param right right child plan; used for its output attributes and for resource cleanup
 * @param output attributes used to build the projection applied to emitted rows
 * @param inMemoryThreshold forwarded unchanged to `SortMergeJoinScanner`
 * @param spillThreshold forwarded unchanged to `SortMergeJoinScanner`
 * @param numOutputRows metric incremented once per emitted row by the iterators defined here,
 *                      and handed to the outer-join iterators
 * @param spillSize metric forwarded unchanged to `SortMergeJoinScanner`
 * @param onlyBufferFirstMatchedRow forwarded to `SortMergeJoinScanner` for LeftSemi, LeftAnti
 *                                  and ExistenceJoin only
 */
class SortMergeJoinEvaluatorFactory(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    joinType: JoinType,
    condition: Option[Expression],
    left: SparkPlan,
    right: SparkPlan,
    output: Seq[Attribute],
    inMemoryThreshold: Int,
    spillThreshold: Int,
    numOutputRows: SQLMetric,
    spillSize: SQLMetric,
    onlyBufferFirstMatchedRow: Boolean)
  extends PartitionEvaluatorFactory[InternalRow, InternalRow] {

  override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
    new SortMergeJoinEvaluator

  private class SortMergeJoinEvaluator extends PartitionEvaluator[InternalRow, InternalRow] {

    /** Releases the resources of both child plans; handed to the scanner as a callback. */
    private def cleanupResources(): Unit = {
      IndexedSeq(left, right).foreach(_.cleanupResources())
    }

    private def createLeftKeyGenerator(): Projection =
      UnsafeProjection.create(leftKeys, left.output)

    private def createRightKeyGenerator(): Projection =
      UnsafeProjection.create(rightKeys, right.output)

    /**
     * Joins one partition.
     *
     * @param partitionIndex index of the partition being evaluated (not used by the join itself)
     * @param inputs exactly two iterators: the left rows followed by the right rows
     * @return iterator over the joined rows for `joinType`
     */
    override def eval(
        partitionIndex: Int,
        inputs: Iterator[InternalRow]*): Iterator[InternalRow] = {
      assert(inputs.length == 2)
      val leftIter = inputs(0)
      val rightIter = inputs(1)

      val boundCondition: InternalRow => Boolean = {
        condition.map { cond =>
          Predicate.create(cond, left.output ++ right.output).eval _
        }.getOrElse {
          (_: InternalRow) => true
        }
      }

      // An ordering that can be used to compare keys from both sides.
      val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType))
      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

      // Single place where the streamed/buffered scanner is assembled.
      // `streamLeftSide` picks which input is streamed (the other one is buffered).
      // `forwardFirstMatchFlag` controls whether `onlyBufferFirstMatchedRow` is passed on;
      // when it is not, the scanner's own default for that argument applies, exactly as in
      // the call sites that omit the argument.
      def createScanner(
          streamLeftSide: Boolean,
          forwardFirstMatchFlag: Boolean): SortMergeJoinScanner = {
        val streamedKeyGen =
          if (streamLeftSide) createLeftKeyGenerator() else createRightKeyGenerator()
        val bufferedKeyGen =
          if (streamLeftSide) createRightKeyGenerator() else createLeftKeyGenerator()
        val streamedRows = RowIterator.fromScala(if (streamLeftSide) leftIter else rightIter)
        val bufferedRows = RowIterator.fromScala(if (streamLeftSide) rightIter else leftIter)
        if (forwardFirstMatchFlag) {
          new SortMergeJoinScanner(
            streamedKeyGen,
            bufferedKeyGen,
            keyOrdering,
            streamedRows,
            bufferedRows,
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources,
            onlyBufferFirstMatchedRow)
        } else {
          new SortMergeJoinScanner(
            streamedKeyGen,
            bufferedKeyGen,
            keyOrdering,
            streamedRows,
            bufferedRows,
            inMemoryThreshold,
            spillThreshold,
            spillSize,
            cleanupResources)
        }
      }

      // True when at least one buffered row, joined with `streamedRow`, passes the join
      // condition. A null or empty buffer yields false. `joined` is reused as scratch space
      // and is left pointing at the last pair that was tested.
      def anyMatchPassesCondition(
          streamedRow: InternalRow,
          bufferedMatches: ExternalAppendOnlyUnsafeRowArray,
          joined: JoinedRow): Boolean = {
        if (bufferedMatches == null || bufferedMatches.length == 0) {
          false
        } else {
          val candidates = bufferedMatches.generateIterator()
          var found = false
          while (!found && candidates.hasNext) {
            joined(streamedRow, candidates.next())
            found = boundCondition(joined)
          }
          found
        }
      }

      joinType match {
        case _: InnerLike =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
            // null means "no more matches": either none were found up front or the
            // scanner has been exhausted.
            private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
            private[this] val smjScanner =
              createScanner(streamLeftSide = true, forwardFirstMatchFlag = false)
            private[this] val joinRow = new JoinedRow

            // Prime the iterator with the first group of matching rows, if any.
            if (smjScanner.findNextInnerJoinRows()) {
              currentRightMatches = smjScanner.getBufferedMatches
              currentLeftRow = smjScanner.getStreamedRow
              rightMatchesIterator = currentRightMatches.generateIterator()
            }

            override def advanceNext(): Boolean = {
              while (rightMatchesIterator != null) {
                if (!rightMatchesIterator.hasNext) {
                  // Current group is used up; move on to the next streamed row with matches.
                  if (smjScanner.findNextInnerJoinRows()) {
                    currentRightMatches = smjScanner.getBufferedMatches
                    currentLeftRow = smjScanner.getStreamedRow
                    rightMatchesIterator = currentRightMatches.generateIterator()
                  } else {
                    currentRightMatches = null
                    currentLeftRow = null
                    rightMatchesIterator = null
                    return false
                  }
                }
                joinRow(currentLeftRow, rightMatchesIterator.next())
                if (boundCondition(joinRow)) {
                  numOutputRows += 1
                  return true
                }
              }
              false
            }

            override def getRow: InternalRow = resultProj(joinRow)
          }.toScala

        case LeftOuter =>
          val smjScanner = createScanner(streamLeftSide = true, forwardFirstMatchFlag = false)
          val rightNullRow = new GenericInternalRow(right.output.length)
          new LeftOuterIterator(
            smjScanner,
            rightNullRow,
            boundCondition,
            resultProj,
            numOutputRows).toScala

        case RightOuter =>
          // The right side is streamed here, so the key generators and inputs are swapped.
          val smjScanner = createScanner(streamLeftSide = false, forwardFirstMatchFlag = false)
          val leftNullRow = new GenericInternalRow(left.output.length)
          new RightOuterIterator(
            smjScanner,
            leftNullRow,
            boundCondition,
            resultProj,
            numOutputRows).toScala

        case FullOuter =>
          val leftNullRow = new GenericInternalRow(left.output.length)
          val rightNullRow = new GenericInternalRow(right.output.length)
          val smjScanner = new SortMergeFullOuterJoinScanner(
            leftKeyGenerator = createLeftKeyGenerator(),
            rightKeyGenerator = createRightKeyGenerator(),
            keyOrdering,
            leftIter = RowIterator.fromScala(leftIter),
            rightIter = RowIterator.fromScala(rightIter),
            boundCondition,
            leftNullRow,
            rightNullRow)

          new FullOuterIterator(smjScanner, resultProj, numOutputRows).toScala

        case LeftSemi =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            private[this] val smjScanner =
              createScanner(streamLeftSide = true, forwardFirstMatchFlag = true)
            private[this] val joinRow = new JoinedRow

            // Emits a left row as soon as one of its buffered matches passes the condition.
            override def advanceNext(): Boolean = {
              while (smjScanner.findNextInnerJoinRows()) {
                val currentRightMatches = smjScanner.getBufferedMatches
                currentLeftRow = smjScanner.getStreamedRow
                if (anyMatchPassesCondition(currentLeftRow, currentRightMatches, joinRow)) {
                  numOutputRows += 1
                  return true
                }
              }
              false
            }

            override def getRow: InternalRow = currentLeftRow
          }.toScala

        case LeftAnti =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            private[this] val smjScanner =
              createScanner(streamLeftSide = true, forwardFirstMatchFlag = true)
            private[this] val joinRow = new JoinedRow

            // Emits a left row only when no buffered match passes the condition
            // (which includes the case of having no buffered matches at all).
            override def advanceNext(): Boolean = {
              while (smjScanner.findNextOuterJoinRows()) {
                currentLeftRow = smjScanner.getStreamedRow
                val currentRightMatches = smjScanner.getBufferedMatches
                if (!anyMatchPassesCondition(currentLeftRow, currentRightMatches, joinRow)) {
                  numOutputRows += 1
                  return true
                }
              }
              false
            }

            override def getRow: InternalRow = currentLeftRow
          }.toScala

        case _: ExistenceJoin =>
          new RowIterator {
            private[this] var currentLeftRow: InternalRow = _
            // Single-column row holding the "a match exists" flag for the current left row.
            private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
            private[this] val smjScanner =
              createScanner(streamLeftSide = true, forwardFirstMatchFlag = true)
            private[this] val joinRow = new JoinedRow

            // Every streamed row is emitted exactly once, tagged with whether a match exists.
            override def advanceNext(): Boolean = {
              if (smjScanner.findNextOuterJoinRows()) {
                currentLeftRow = smjScanner.getStreamedRow
                val currentRightMatches = smjScanner.getBufferedMatches
                val found =
                  anyMatchPassesCondition(currentLeftRow, currentRightMatches, joinRow)
                result.setBoolean(0, found)
                numOutputRows += 1
                true
              } else {
                false
              }
            }

            override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
          }.toScala

        case x =>
          throw new IllegalArgumentException(s"SortMergeJoin should not take $x as the JoinType")
      }
    }
  }
}

0 commit comments

Comments
 (0)