1
1
from __future__ import annotations
2
2
3
- from typing import Any
3
+ from typing import Any , Literal , TYPE_CHECKING , overload
4
4
5
5
from importlib .metadata import version
6
6
11
11
get_meta as _get_meta ,
12
12
)
13
13
14
+ if TYPE_CHECKING :
15
+ import pandas as pd
16
+ import polars as pl
17
+ import modin .pandas as mpd
18
+ import dask .dataframe as dd
19
+ import pyarrow as pa
20
+
14
21
__version__ = version (__name__ )
15
22
16
23
import os
27
34
"CX_REWRITER_PATH" , os .path .join (dir_path , "dependencies/federated-rewriter.jar" )
28
35
)
29
36
37
+ Protocol = Literal ["csv" , "binary" , "cursor" , "simple" , "text" ]
30
38
31
- def rewrite_conn (conn : str , protocol : str | None = None ):
39
+
40
+ def rewrite_conn (conn : str , protocol : Protocol | None = None ) -> tuple [str , Protocol ]:
32
41
if not protocol :
33
42
# note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database
34
43
# drivers to connect. set a compatible protocol and masquerade as the appropriate backend.
@@ -47,8 +56,8 @@ def rewrite_conn(conn: str, protocol: str | None = None):
47
56
def get_meta (
48
57
conn : str ,
49
58
query : str ,
50
- protocol : str | None = None ,
51
- ):
59
+ protocol : Protocol | None = None ,
60
+ ) -> pd . DataFrame :
52
61
"""
53
62
Get metadata (header) of the given query (only for pandas)
54
63
@@ -75,7 +84,7 @@ def partition_sql(
75
84
partition_on : str ,
76
85
partition_num : int ,
77
86
partition_range : tuple [int , int ] | None = None ,
78
- ):
87
+ ) -> list [ str ] :
79
88
"""
80
89
Partition the sql query
81
90
@@ -110,7 +119,7 @@ def read_sql_pandas(
110
119
partition_on : str | None = None ,
111
120
partition_range : tuple [int , int ] | None = None ,
112
121
partition_num : int | None = None ,
113
- ):
122
+ ) -> pd . DataFrame :
114
123
"""
115
124
Run the SQL query, download the data from database into a dataframe.
116
125
First several parameters are in the same name and order with `pandas.read_sql`.
@@ -142,17 +151,103 @@ def read_sql_pandas(
142
151
)
143
152
144
153
154
+ # default return pd.DataFrame
155
+ @overload
145
156
def read_sql (
146
157
conn : str | dict [str , str ],
147
158
query : list [str ] | str ,
148
159
* ,
149
- return_type : str = "pandas" ,
150
160
protocol : str | None = None ,
151
161
partition_on : str | None = None ,
152
162
partition_range : tuple [int , int ] | None = None ,
153
163
partition_num : int | None = None ,
154
164
index_col : str | None = None ,
155
- ):
165
+ ) -> pd .DataFrame : ...
166
+
167
+
168
+ @overload
169
+ def read_sql (
170
+ conn : str | dict [str , str ],
171
+ query : list [str ] | str ,
172
+ * ,
173
+ return_type : Literal ["pandas" ],
174
+ protocol : str | None = None ,
175
+ partition_on : str | None = None ,
176
+ partition_range : tuple [int , int ] | None = None ,
177
+ partition_num : int | None = None ,
178
+ index_col : str | None = None ,
179
+ ) -> pd .DataFrame : ...
180
+
181
+
182
+ @overload
183
+ def read_sql (
184
+ conn : str | dict [str , str ],
185
+ query : list [str ] | str ,
186
+ * ,
187
+ return_type : Literal ["arrow" , "arrow2" ],
188
+ protocol : str | None = None ,
189
+ partition_on : str | None = None ,
190
+ partition_range : tuple [int , int ] | None = None ,
191
+ partition_num : int | None = None ,
192
+ index_col : str | None = None ,
193
+ ) -> pa .Table : ...
194
+
195
+
196
+ @overload
197
+ def read_sql (
198
+ conn : str | dict [str , str ],
199
+ query : list [str ] | str ,
200
+ * ,
201
+ return_type : Literal ["modin" ],
202
+ protocol : str | None = None ,
203
+ partition_on : str | None = None ,
204
+ partition_range : tuple [int , int ] | None = None ,
205
+ partition_num : int | None = None ,
206
+ index_col : str | None = None ,
207
+ ) -> mpd .DataFrame : ...
208
+
209
+
210
+ @overload
211
+ def read_sql (
212
+ conn : str | dict [str , str ],
213
+ query : list [str ] | str ,
214
+ * ,
215
+ return_type : Literal ["dask" ],
216
+ protocol : str | None = None ,
217
+ partition_on : str | None = None ,
218
+ partition_range : tuple [int , int ] | None = None ,
219
+ partition_num : int | None = None ,
220
+ index_col : str | None = None ,
221
+ ) -> dd .DataFrame : ...
222
+
223
+
224
+ @overload
225
+ def read_sql (
226
+ conn : str | dict [str , str ],
227
+ query : list [str ] | str ,
228
+ * ,
229
+ return_type : Literal ["polars" , "polars2" ],
230
+ protocol : str | None = None ,
231
+ partition_on : str | None = None ,
232
+ partition_range : tuple [int , int ] | None = None ,
233
+ partition_num : int | None = None ,
234
+ index_col : str | None = None ,
235
+ ) -> pl .DataFrame : ...
236
+
237
+
238
+ def read_sql (
239
+ conn : str | dict [str , str ],
240
+ query : list [str ] | str ,
241
+ * ,
242
+ return_type : Literal [
243
+ "pandas" , "polars" , "polars2" , "arrow" , "arrow2" , "modin" , "dask"
244
+ ] = "pandas" ,
245
+ protocol : str | None = None ,
246
+ partition_on : str | None = None ,
247
+ partition_range : tuple [int , int ] | None = None ,
248
+ partition_num : int | None = None ,
249
+ index_col : str | None = None ,
250
+ ) -> pd .DataFrame | mpd .DataFrame | dd .DataFrame | pl .DataFrame | pa .Table :
156
251
"""
157
252
Run the SQL query, download the data from database into a dataframe.
158
253
@@ -318,7 +413,9 @@ def read_sql(
318
413
return df
319
414
320
415
321
- def reconstruct_arrow (result : tuple [list [str ], list [list [tuple [int , int ]]]]):
416
+ def reconstruct_arrow (
417
+ result : tuple [list [str ], list [list [tuple [int , int ]]]],
418
+ ) -> pa .Table :
322
419
import pyarrow as pa
323
420
324
421
names , ptrs = result
@@ -334,7 +431,7 @@ def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
334
431
return pa .Table .from_batches (rbs )
335
432
336
433
337
- def reconstruct_pandas (df_infos : dict [str , Any ]):
434
+ def reconstruct_pandas (df_infos : dict [str , Any ]) -> pd . DataFrame :
338
435
import pandas as pd
339
436
340
437
data = df_infos ["data" ]
@@ -388,6 +485,6 @@ def remove_ending_semicolon(query: str) -> str:
388
485
SQL query
389
486
390
487
"""
391
- if query .endswith (';' ):
488
+ if query .endswith (";" ):
392
489
query = query [:- 1 ]
393
490
return query
0 commit comments