[Typing][B-93] Add type annotations for python/paddle/reader/decorator.py #66305

Merged: 2 commits, Jul 22, 2024

133 changes: 110 additions & 23 deletions python/paddle/reader/decorator.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import itertools
import logging
import multiprocessing
@@ -21,9 +23,27 @@
from itertools import zip_longest
from queue import Queue
from threading import Thread
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Sequence,
TypedDict,
TypeVar,
overload,
)

from typing_extensions import NotRequired, TypeAlias, Unpack

from paddle.base.reader import QUEUE_GET_TIMEOUT

if TYPE_CHECKING:

class _ComposeOptions(TypedDict):
check_alignment: NotRequired[bool]


__all__ = []

# On macOS, the 'spawn' start method is now the default in Python3.8 multiprocessing,
@@ -41,8 +61,18 @@
else:
fork_context = multiprocessing

_T = TypeVar('_T')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_T3 = TypeVar('_T3')
_T4 = TypeVar('_T4')
_U = TypeVar('_U')


def cache(reader):
_Reader: TypeAlias = Callable[[], Generator[_T, None, None]]


def cache(reader: _Reader[_T]) -> _Reader[_T]:
"""
Cache the reader data into memory.

@@ -77,12 +107,60 @@ def cache(reader):
"""
all_data = tuple(reader())

def __impl__():
def __impl__() -> Generator[_T, None, None]:
yield from all_data

return __impl__
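
# Illustrative sketch, not part of this PR's diff: what a function matching the
# new `_Reader[_T]` alias looks like, and how the annotation on `cache` preserves
# the element type. Assumes Paddle is installed; the import path mirrors this
# file's module path, and `int_reader` is a hypothetical example reader.
from typing import Generator

from paddle.reader.decorator import cache


def int_reader() -> Generator[int, None, None]:  # conforms to _Reader[int]
    yield from range(5)


cached_reader = cache(int_reader)  # a type checker infers _Reader[int]
assert list(cached_reader()) == [0, 1, 2, 3, 4]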


# A temporary solution like builtin map function.
# `Map` maybe the final solution in the future.
# See https://github.com/python/typing/issues/1383
@overload
def map_readers(
func: Callable[[_T1], _U], reader1: _Reader[_T1], /
) -> _Reader[_U]:
...


@overload
def map_readers(
func: Callable[[_T1, _T2], _U],
reader1: _Reader[_T1],
reader2: _Reader[_T2],
/,
) -> _Reader[_U]:
...


@overload
def map_readers(
func: Callable[[_T1, _T2, _T3], _U],
reader1: _Reader[_T1],
reader2: _Reader[_T2],
reader3: _Reader[_T3],
/,
) -> _Reader[_U]:
...


@overload
def map_readers(
func: Callable[[_T1, _T2, _T3, _T4], _U],
reader1: _Reader[_T1],
reader2: _Reader[_T2],
reader3: _Reader[_T3],
reader4: _Reader[_T4],
/,
) -> _Reader[_U]:
...


@overload
def map_readers(func: Callable[..., _U], *readers: _Reader[Any]) -> _Reader[_U]:
...


def map_readers(func, *readers):
"""
Creates a data reader that outputs return value of function using
@@ -124,7 +202,7 @@ def reader():
return reader
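
# Illustrative sketch, not part of this PR's diff: with the overloads above, a
# type checker infers the element type of the combined reader from the mapper's
# return type. Assumes Paddle is installed; `xs`, `ys`, and `concat` are
# hypothetical example definitions.
from typing import Generator

from paddle.reader.decorator import map_readers


def xs() -> Generator[int, None, None]:
    yield from (1, 2, 3)


def ys() -> Generator[int, None, None]:
    yield from (10, 20, 30)


def concat(a: int, b: int) -> str:
    return f"{a}-{b}"


# Matched by the two-reader overload: Callable[[int, int], str] -> _Reader[str]
pairs = map_readers(concat, xs, ys)
assert list(pairs()) == ["1-10", "2-20", "3-30"]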


def shuffle(reader, buf_size):
def shuffle(reader: _Reader[_T], buf_size: int) -> _Reader[_T]:
"""
This API creates a decorated reader that outputs the shuffled data.

@@ -151,7 +229,7 @@ def shuffle(reader, buf_size):
>>> # outputs are 0~4 unordered arrangement
"""

def data_reader():
def data_reader() -> Generator[_T, None, None]:
buf = []
for e in reader():
buf.append(e)
@@ -169,7 +247,7 @@ def data_reader():
return data_reader
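
# Illustrative sketch, not part of this PR's diff: `shuffle` is now typed
# `_Reader[_T] -> _Reader[_T]`, so the element type survives the decoration.
# Assumes Paddle is installed; `ints` is a hypothetical example reader.
from typing import Generator

from paddle.reader.decorator import shuffle


def ints() -> Generator[int, None, None]:
    yield from range(5)


shuffled = shuffle(ints, buf_size=5)  # still _Reader[int]
assert sorted(shuffled()) == [0, 1, 2, 3, 4]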


def chain(*readers):
def chain(*readers: _Reader[_T]) -> _Reader[_T]:
"""
Use the input data readers to create a chained data reader. The new created reader
chains the outputs of input readers together as its output, and it do not change
@@ -218,8 +296,8 @@ def chain(*readers):

"""

def reader():
rs = []
def reader() -> Generator[_T, None, None]:
rs: list[Generator[_T, None, None]] = []
for r in readers:
rs.append(r())

@@ -232,7 +310,9 @@ class ComposeNotAligned(ValueError):
pass


def compose(*readers, **kwargs):
def compose(
*readers: _Reader[Any], **kwargs: Unpack[_ComposeOptions]
) -> _Reader[Any]:
"""
Creates a data reader whose output is the combination of input readers.

@@ -289,7 +369,7 @@ def reader():
return reader
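
# Illustrative sketch, not part of this PR's diff: a standalone replica of the
# `Unpack[_ComposeOptions]` pattern used above, showing how it lets a type
# checker validate `compose`-style keyword arguments. `_Options` and `combine`
# are hypothetical names, not Paddle APIs.
from typing import Any, Callable, Generator, TypedDict

from typing_extensions import NotRequired, Unpack


class _Options(TypedDict):
    check_alignment: NotRequired[bool]


def combine(
    *readers: Callable[[], Generator[Any, None, None]],
    **kwargs: Unpack[_Options],
) -> None:
    check_alignment = kwargs.get('check_alignment', True)
    print(len(readers), check_alignment)


combine(check_alignment=False)    # accepted: known key with a bool value
# combine(check_alignment="no")   # a type checker flags the wrong value type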


def buffered(reader, size):
def buffered(reader: _Reader[_T], size: int) -> _Reader[_T]:
"""
Creates a buffered data reader.

@@ -339,10 +419,7 @@ def data_reader():
q = Queue(maxsize=size)
t = Thread(
target=read_worker,
args=(
r,
q,
),
args=(r, q),
)
t.daemon = True
t.start()
@@ -354,7 +431,7 @@ def data_reader():
return data_reader


def firstn(reader, n):
def firstn(reader: _Reader[_T], n: int) -> _Reader[_T]:
"""

This API creates a decorated reader, and limits the max number of
@@ -399,7 +476,13 @@ class XmapEndSignal:
pass


def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
def xmap_readers(
mapper: Callable[[_T], _U],
reader: _Reader[_T],
process_num: int,
buffer_size: int,
order: bool = False,
) -> _Reader[_U]:
"""
Use multi-threads to map samples from reader by a mapper defined by user.

@@ -495,7 +578,11 @@ def xreader():
return xreader
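
# Illustrative sketch, not part of this PR's diff: under the new signature, the
# mapper's return type determines the element type of the reader produced by
# `xmap_readers`. Assumes Paddle is installed; `lines` and `parse` are
# hypothetical example definitions, and no samples are produced until the
# returned reader is consumed.
from typing import Generator

from paddle.reader.decorator import xmap_readers


def lines() -> Generator[str, None, None]:
    yield from ("1", "2", "3")


def parse(text: str) -> int:
    return int(text)


# Callable[[str], int] + _Reader[str] -> _Reader[int]
parsed = xmap_readers(parse, lines, process_num=2, buffer_size=16)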


def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
def multiprocess_reader(
readers: Sequence[_Reader[_T]],
use_pipe: bool = True,
queue_size: int = 1000,
) -> _Reader[list[_T]]:
"""
This API use python ``multiprocessing`` to read data from ``readers`` parallelly,
and then ``multiprocess.Queue`` or ``multiprocess.Pipe`` is used to merge
@@ -508,13 +595,13 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
in some platforms.

Parameters:
readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
used to read input data
use_pipe (bool, optional): control the inner API used to implement the multi-processing,
default True - use ``multiprocess.Pipe`` which is recommended
queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
is used, default 1000. Increase this value can speed up the data reading, and more memory
will be consumed.
readers (list( ``generator`` ) | tuple( ``generator`` )): a python ``generator`` list
used to read input data
use_pipe (bool, optional): control the inner API used to implement the multi-processing,
default True - use ``multiprocess.Pipe`` which is recommended
queue_size (int, optional): only useful when ``use_pipe`` is False - ``multiprocess.Queue``
is used, default 1000. Increase this value can speed up the data reading, and more memory
will be consumed.

Returns:
``generator``: a new reader which can be run parallelly
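
# Illustrative sketch, not part of this PR's diff: the new signature types the
# parallel reader as Sequence[_Reader[_T]] -> _Reader[list[_T]]. Assumes Paddle
# is installed on a platform where ``multiprocess_reader`` is supported; `r1`
# and `r2` are hypothetical example readers, and only the documented parameters
# are shown.
from typing import Generator

from paddle.reader.decorator import multiprocess_reader


def r1() -> Generator[int, None, None]:
    yield from (1, 2, 3)


def r2() -> Generator[int, None, None]:
    yield from (4, 5, 6)


# Queue-based merge; per the docstring, `queue_size` only matters when
# `use_pipe` is False. Typed, per the annotation above, as _Reader[list[int]].
merged = multiprocess_reader([r1, r2], use_pipe=False, queue_size=100)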