@@ -164,12 +164,17 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
164164
165165 def __init__ (
166166 self ,
167- statement : str ,
168- engine : str ,
167+ statement : str | None = None ,
168+ engine : str = "base" ,
169169 ast : InternalRepresentation | None = None ,
170170 ):
171- self ._sql = statement
172- self ._parsed = ast or self ._parse_statement (statement , engine )
171+ if ast :
172+ self ._parsed = ast
173+ elif statement :
174+ self ._parsed = self ._parse_statement (statement , engine )
175+ else :
176+ raise SupersetParseError ("Either statement or ast must be provided" )
177+
173178 self .engine = engine
174179 self .tables = self ._extract_tables_from_statement (self ._parsed , self .engine )
175180
@@ -284,8 +289,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
284289
285290 def __init__ (
286291 self ,
287- statement : str ,
288- engine : str ,
292+ statement : str | None = None ,
293+ engine : str = "base" ,
289294 ast : exp .Expression | None = None ,
290295 ):
291296 self ._dialect = SQLGLOT_DIALECTS .get (engine )
@@ -423,7 +428,10 @@ def is_mutating(self) -> bool:
423428 and self ._parsed .expression .name .upper ().startswith ("ANALYZE " )
424429 ):
425430 analyzed_sql = self ._parsed .expression .name [len ("ANALYZE " ) :]
426- return SQLStatement (analyzed_sql , self .engine ).is_mutating ()
431+ return SQLStatement (
432+ statement = analyzed_sql ,
433+ engine = self .engine ,
434+ ).is_mutating ()
427435
428436 return False
429437
@@ -459,12 +467,11 @@ def optimize(self) -> SQLStatement:
459467 """
460468 # only optimize statements that have a custom dialect
461469 if not self ._dialect :
462- return SQLStatement (self ._sql , self . engine , self . _parsed .copy ())
470+ return SQLStatement (ast = self ._parsed .copy (), engine = self . engine )
463471
464472 optimized = pushdown_predicates (self ._parsed , dialect = self ._dialect )
465- sql = optimized .sql (dialect = self ._dialect )
466473
467- return SQLStatement (sql , self .engine , optimized )
474+ return SQLStatement (ast = optimized , engine = self .engine )
468475
469476 def check_functions_present (self , functions : set [str ]) -> bool :
470477 """
@@ -668,6 +675,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
668675 details about it.
669676 """
670677
678+ def __init__ (
679+ self ,
680+ statement : str | None = None ,
681+ engine : str = "kustokql" ,
682+ ast : str | None = None ,
683+ ):
684+ super ().__init__ (statement , engine , ast )
685+
671686 @classmethod
672687 def split_script (
673688 cls ,
@@ -725,7 +740,7 @@ def format(self, comments: bool = True) -> str:
725740 """
726741 Pretty-format the SQL statement.
727742 """
728- return self ._sql .strip ()
743+ return self ._parsed .strip ()
729744
730745 def get_settings (self ) -> dict [str , str | bool ]:
731746 """
@@ -756,7 +771,7 @@ def optimize(self) -> KustoKQLStatement:
756771
757772 Kusto KQL doesn't support optimization, so this method is a no-op.
758773 """
759- return KustoKQLStatement (self ._sql , self . engine , self ._parsed )
774+ return KustoKQLStatement (ast = self ._parsed , engine = self .engine )
760775
761776 def check_functions_present (self , functions : set [str ]) -> bool :
762777 """
@@ -774,7 +789,7 @@ def get_limit_value(self) -> int | None:
774789 """
775790 tokens = [
776791 token
777- for token in tokenize_kql (self ._sql )
792+ for token in tokenize_kql (self ._parsed )
778793 if token [0 ] != KQLTokenType .WHITESPACE
779794 ]
780795 for idx , (ttype , val ) in enumerate (tokens ):
@@ -796,7 +811,7 @@ def set_limit_value(
796811 if method != LimitMethod .FORCE_LIMIT :
797812 raise SupersetParseError ("Kusto KQL only supports the FORCE_LIMIT method." )
798813
799- tokens = tokenize_kql (self ._sql )
814+ tokens = tokenize_kql (self ._parsed )
800815 found_limit_token = False
801816 for idx , (ttype , val ) in enumerate (tokens ):
802817 if ttype != KQLTokenType .STRING and val .lower () in {"take" , "limit" }:
@@ -817,7 +832,7 @@ def set_limit_value(
817832 ]
818833 )
819834
820- self ._parsed = self . _sql = "" .join (val for _ , val in tokens )
835+ self ._parsed = "" .join (val for _ , val in tokens )
821836
822837
823838class SQLScript :
0 commit comments