Skip to content

Commit 68a101d

Browse files
committed
typing: add return type annotations
1 parent 59bb016 commit 68a101d

File tree

1 file changed

+108
-11
lines changed

1 file changed

+108
-11
lines changed

connectorx-python/connectorx/__init__.py

Lines changed: 108 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import Any, Literal, TYPE_CHECKING, overload
44

55
from importlib.metadata import version
66

@@ -11,6 +11,13 @@
1111
get_meta as _get_meta,
1212
)
1313

14+
if TYPE_CHECKING:
15+
import pandas as pd
16+
import polars as pl
17+
import modin.pandas as mpd
18+
import dask.dataframe as dd
19+
import pyarrow as pa
20+
1421
__version__ = version(__name__)
1522

1623
import os
@@ -27,8 +34,10 @@
2734
"CX_REWRITER_PATH", os.path.join(dir_path, "dependencies/federated-rewriter.jar")
2835
)
2936

37+
Protocol = Literal["csv", "binary", "cursor", "simple", "text"]
3038

31-
def rewrite_conn(conn: str, protocol: str | None = None):
39+
40+
def rewrite_conn(conn: str, protocol: Protocol | None = None) -> tuple[str, Protocol]:
3241
if not protocol:
3342
# note: redshift/clickhouse are not compatible with the 'binary' protocol, and use other database
3443
# drivers to connect. set a compatible protocol and masquerade as the appropriate backend.
@@ -47,8 +56,8 @@ def rewrite_conn(conn: str, protocol: str | None = None):
4756
def get_meta(
4857
conn: str,
4958
query: str,
50-
protocol: str | None = None,
51-
):
59+
protocol: Protocol | None = None,
60+
) -> pd.DataFrame:
5261
"""
5362
Get metadata (header) of the given query (only for pandas)
5463
@@ -75,7 +84,7 @@ def partition_sql(
7584
partition_on: str,
7685
partition_num: int,
7786
partition_range: tuple[int, int] | None = None,
78-
):
87+
) -> list[str]:
7988
"""
8089
Partition the sql query
8190
@@ -110,7 +119,7 @@ def read_sql_pandas(
110119
partition_on: str | None = None,
111120
partition_range: tuple[int, int] | None = None,
112121
partition_num: int | None = None,
113-
):
122+
) -> pd.DataFrame:
114123
"""
115124
Run the SQL query, download the data from database into a dataframe.
116125
The first several parameters have the same names and order as in `pandas.read_sql`.
@@ -142,17 +151,103 @@ def read_sql_pandas(
142151
)
143152

144153

154+
# default overload: when return_type is omitted, the result is a pd.DataFrame
155+
@overload
145156
def read_sql(
146157
conn: str | dict[str, str],
147158
query: list[str] | str,
148159
*,
149-
return_type: str = "pandas",
150160
protocol: str | None = None,
151161
partition_on: str | None = None,
152162
partition_range: tuple[int, int] | None = None,
153163
partition_num: int | None = None,
154164
index_col: str | None = None,
155-
):
165+
) -> pd.DataFrame: ...
166+
167+
168+
@overload
169+
def read_sql(
170+
conn: str | dict[str, str],
171+
query: list[str] | str,
172+
*,
173+
return_type: Literal["pandas"],
174+
protocol: str | None = None,
175+
partition_on: str | None = None,
176+
partition_range: tuple[int, int] | None = None,
177+
partition_num: int | None = None,
178+
index_col: str | None = None,
179+
) -> pd.DataFrame: ...
180+
181+
182+
@overload
183+
def read_sql(
184+
conn: str | dict[str, str],
185+
query: list[str] | str,
186+
*,
187+
return_type: Literal["arrow", "arrow2"],
188+
protocol: str | None = None,
189+
partition_on: str | None = None,
190+
partition_range: tuple[int, int] | None = None,
191+
partition_num: int | None = None,
192+
index_col: str | None = None,
193+
) -> pa.Table: ...
194+
195+
196+
@overload
197+
def read_sql(
198+
conn: str | dict[str, str],
199+
query: list[str] | str,
200+
*,
201+
return_type: Literal["modin"],
202+
protocol: str | None = None,
203+
partition_on: str | None = None,
204+
partition_range: tuple[int, int] | None = None,
205+
partition_num: int | None = None,
206+
index_col: str | None = None,
207+
) -> mpd.DataFrame: ...
208+
209+
210+
@overload
211+
def read_sql(
212+
conn: str | dict[str, str],
213+
query: list[str] | str,
214+
*,
215+
return_type: Literal["dask"],
216+
protocol: str | None = None,
217+
partition_on: str | None = None,
218+
partition_range: tuple[int, int] | None = None,
219+
partition_num: int | None = None,
220+
index_col: str | None = None,
221+
) -> dd.DataFrame: ...
222+
223+
224+
@overload
225+
def read_sql(
226+
conn: str | dict[str, str],
227+
query: list[str] | str,
228+
*,
229+
return_type: Literal["polars", "polars2"],
230+
protocol: str | None = None,
231+
partition_on: str | None = None,
232+
partition_range: tuple[int, int] | None = None,
233+
partition_num: int | None = None,
234+
index_col: str | None = None,
235+
) -> pl.DataFrame: ...
236+
237+
238+
def read_sql(
239+
conn: str | dict[str, str],
240+
query: list[str] | str,
241+
*,
242+
return_type: Literal[
243+
"pandas", "polars", "polars2", "arrow", "arrow2", "modin", "dask"
244+
] = "pandas",
245+
protocol: str | None = None,
246+
partition_on: str | None = None,
247+
partition_range: tuple[int, int] | None = None,
248+
partition_num: int | None = None,
249+
index_col: str | None = None,
250+
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
156251
"""
157252
Run the SQL query, download the data from database into a dataframe.
158253
@@ -318,7 +413,9 @@ def read_sql(
318413
return df
319414

320415

321-
def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
416+
def reconstruct_arrow(
417+
result: tuple[list[str], list[list[tuple[int, int]]]],
418+
) -> pa.Table:
322419
import pyarrow as pa
323420

324421
names, ptrs = result
@@ -334,7 +431,7 @@ def reconstruct_arrow(result: tuple[list[str], list[list[tuple[int, int]]]]):
334431
return pa.Table.from_batches(rbs)
335432

336433

337-
def reconstruct_pandas(df_infos: dict[str, Any]):
434+
def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame:
338435
import pandas as pd
339436

340437
data = df_infos["data"]
@@ -388,6 +485,6 @@ def remove_ending_semicolon(query: str) -> str:
388485
SQL query
389486
390487
"""
391-
if query.endswith(';'):
488+
if query.endswith(";"):
392489
query = query[:-1]
393490
return query

0 commit comments

Comments
 (0)