Skip to content

Commit 1393f7d

Browse files
chore: sql/parse cleanup (#33515)
1 parent b7ba500 commit 1393f7d

File tree

4 files changed

+38
-20
lines changed

4 files changed

+38
-20
lines changed

superset/sql/parse.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,17 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
164164

165165
def __init__(
166166
self,
167-
statement: str,
168-
engine: str,
167+
statement: str | None = None,
168+
engine: str = "base",
169169
ast: InternalRepresentation | None = None,
170170
):
171-
self._sql = statement
172-
self._parsed = ast or self._parse_statement(statement, engine)
171+
if ast:
172+
self._parsed = ast
173+
elif statement:
174+
self._parsed = self._parse_statement(statement, engine)
175+
else:
176+
raise SupersetParseError("Either statement or ast must be provided")
177+
173178
self.engine = engine
174179
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
175180

@@ -284,8 +289,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
284289

285290
def __init__(
286291
self,
287-
statement: str,
288-
engine: str,
292+
statement: str | None = None,
293+
engine: str = "base",
289294
ast: exp.Expression | None = None,
290295
):
291296
self._dialect = SQLGLOT_DIALECTS.get(engine)
@@ -423,7 +428,10 @@ def is_mutating(self) -> bool:
423428
and self._parsed.expression.name.upper().startswith("ANALYZE ")
424429
):
425430
analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
426-
return SQLStatement(analyzed_sql, self.engine).is_mutating()
431+
return SQLStatement(
432+
statement=analyzed_sql,
433+
engine=self.engine,
434+
).is_mutating()
427435

428436
return False
429437

@@ -459,12 +467,11 @@ def optimize(self) -> SQLStatement:
459467
"""
460468
# only optimize statements that have a custom dialect
461469
if not self._dialect:
462-
return SQLStatement(self._sql, self.engine, self._parsed.copy())
470+
return SQLStatement(ast=self._parsed.copy(), engine=self.engine)
463471

464472
optimized = pushdown_predicates(self._parsed, dialect=self._dialect)
465-
sql = optimized.sql(dialect=self._dialect)
466473

467-
return SQLStatement(sql, self.engine, optimized)
474+
return SQLStatement(ast=optimized, engine=self.engine)
468475

469476
def check_functions_present(self, functions: set[str]) -> bool:
470477
"""
@@ -668,6 +675,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
668675
details about it.
669676
"""
670677

678+
def __init__(
679+
self,
680+
statement: str | None = None,
681+
engine: str = "kustokql",
682+
ast: str | None = None,
683+
):
684+
super().__init__(statement, engine, ast)
685+
671686
@classmethod
672687
def split_script(
673688
cls,
@@ -725,7 +740,7 @@ def format(self, comments: bool = True) -> str:
725740
"""
726741
Pretty-format the SQL statement.
727742
"""
728-
return self._sql.strip()
743+
return self._parsed.strip()
729744

730745
def get_settings(self) -> dict[str, str | bool]:
731746
"""
@@ -756,7 +771,7 @@ def optimize(self) -> KustoKQLStatement:
756771
757772
Kusto KQL doesn't support optimization, so this method is a no-op.
758773
"""
759-
return KustoKQLStatement(self._sql, self.engine, self._parsed)
774+
return KustoKQLStatement(ast=self._parsed, engine=self.engine)
760775

761776
def check_functions_present(self, functions: set[str]) -> bool:
762777
"""
@@ -774,7 +789,7 @@ def get_limit_value(self) -> int | None:
774789
"""
775790
tokens = [
776791
token
777-
for token in tokenize_kql(self._sql)
792+
for token in tokenize_kql(self._parsed)
778793
if token[0] != KQLTokenType.WHITESPACE
779794
]
780795
for idx, (ttype, val) in enumerate(tokens):
@@ -796,7 +811,7 @@ def set_limit_value(
796811
if method != LimitMethod.FORCE_LIMIT:
797812
raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.")
798813

799-
tokens = tokenize_kql(self._sql)
814+
tokens = tokenize_kql(self._parsed)
800815
found_limit_token = False
801816
for idx, (ttype, val) in enumerate(tokens):
802817
if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}:
@@ -817,7 +832,7 @@ def set_limit_value(
817832
]
818833
)
819834

820-
self._parsed = self._sql = "".join(val for _, val in tokens)
835+
self._parsed = "".join(val for _, val in tokens)
821836

822837

823838
class SQLScript:

superset/sql_lab.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,10 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca
242242
if not database.allow_dml:
243243
errors = []
244244
try:
245-
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
245+
parsed_statement = SQLStatement(
246+
statement=sql_statement,
247+
engine=db_engine_spec.engine,
248+
)
246249
disallowed = parsed_statement.is_mutating()
247250
except SupersetParseError as ex:
248251
# if we fail to parse the query, disallow by default

superset/sql_parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def has_table_query(expression: str, engine: str) -> bool:
533533
expression = f"({expression})"
534534

535535
sql = f"SELECT {expression}"
536-
statement = SQLStatement(sql, engine)
536+
statement = SQLStatement(statement=sql, engine=engine)
537537
return any(statement.tables)
538538

539539

tests/unit_tests/sql/parse_tests.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,9 @@ def test_split_no_dialect() -> None:
318318
sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
319319
statements = SQLScript(sql, "dremio").statements
320320
assert len(statements) == 3
321-
assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
322-
assert statements[1]._sql == "SELECT * FROM t"
323-
assert statements[2]._sql == "SELECT foo"
321+
assert statements[0].format() == "SELECT\n col\nFROM t\nWHERE\n NOT col IN (1, 2)"
322+
assert statements[1].format() == "SELECT\n *\nFROM t"
323+
assert statements[2].format() == "SELECT\n foo"
324324

325325

326326
def test_extract_tables_show_columns_from() -> None:

0 commit comments

Comments
 (0)