Skip to content

Commit daa9844

Browse files
itholic authored and zhengruifeng committed
[SPARK-43665][CONNECT][PS] Enable PandasSQLStringFormatter.vformat to work with Spark Connect
### What changes were proposed in this pull request? This PR enables the SQL parity test `test_sql_with_python_objects` for pandas API on Spark with Spark Connect. ### Why are the changes needed? To increase the API coverage of pandas API on Spark with Spark Connect. ### Does this PR introduce _any_ user-facing change? This enables `ps.sql` with Python objects on Spark Connect. ### How was this patch tested? By reusing the existing SQL tests, with additional cases for strings that contain single quotes. Closes #41931 from itholic/SPARK-43665. Authored-by: itholic <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent ce359bc commit daa9844

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

python/pyspark/pandas/sql_formatter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from pyspark.pandas.internal import InternalFrame
2727
from pyspark.pandas.namespace import _get_index_map
28-
from pyspark.sql.functions import lit
2928
from pyspark import pandas as ps
3029
from pyspark.sql import SparkSession
3130
from pyspark.pandas.utils import default_session
@@ -265,7 +264,10 @@ def _convert_value(self, val: Any, name: str) -> Optional[str]:
265264
val._to_spark().createOrReplaceTempView(df_name)
266265
return df_name
267266
elif isinstance(val, str):
268-
return lit(val)._jc.expr().sql() # for escaped characters.
267+
# This is matched to behavior from JVM implementation.
268+
# See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
269+
# sql/catalyst/expressions/literals.scala`
270+
return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
269271
else:
270272
return val
271273

python/pyspark/pandas/tests/connect/test_parity_sql.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ def test_sql_with_index_col(self):
3030
def test_sql_with_pandas_on_spark_objects(self):
3131
super().test_sql_with_pandas_on_spark_objects()
3232

33-
@unittest.skip(
34-
"TODO(SPARK-43665): Enable PandasSQLStringFormatter.vformat to work with Spark Connect."
35-
)
36-
def test_sql_with_python_objects(self):
37-
super().test_sql_with_python_objects()
38-
3933

4034
if __name__ == "__main__":
4135
from pyspark.pandas.tests.connect.test_parity_sql import * # noqa: F401

python/pyspark/pandas/tests/test_sql.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def test_sql_with_python_objects(self):
8181
ps.sql("SELECT id FROM range(10) WHERE id IN {pred}", col="lit", pred=(1, 2, 3)),
8282
ps.DataFrame({"id": [1, 2, 3]}),
8383
)
84+
self.assert_eq(
85+
ps.sql("SELECT {col} as a FROM range(1)", col="a'''c''d"),
86+
ps.DataFrame({"a": ["a'''c''d"]}),
87+
)
88+
self.assert_eq(
89+
ps.sql("SELECT id FROM range(10) WHERE id IN {pred}", col="a'''c''d", pred=(1, 2, 3)),
90+
ps.DataFrame({"id": [1, 2, 3]}),
91+
)
8492

8593
def test_sql_with_pandas_on_spark_objects(self):
8694
psdf = ps.DataFrame({"a": [1, 2, 3, 4]})

0 commit comments

Comments
 (0)