|
| 1 | +from _typeshed import Incomplete |
| 2 | +from collections.abc import Callable |
| 3 | +from typing import Any, TypeVar |
| 4 | +from typing_extensions import Self |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import tensorflow as tf |
| 8 | +from tensorflow.core.example.example_pb2 import Example as Example |
| 9 | +from tensorflow.core.example.feature_pb2 import ( |
| 10 | + BytesList as BytesList, |
| 11 | + Feature as Feature, |
| 12 | + Features as Features, |
| 13 | + FloatList as FloatList, |
| 14 | + Int64List as Int64List, |
| 15 | +) |
| 16 | +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef |
| 17 | +from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef |
| 18 | +from tensorflow.python.trackable.base import Trackable |
| 19 | + |
| 20 | +class CheckpointOptions: |
| 21 | + experimental_io_device: None | str |
| 22 | + experimental_enable_async_checkpoint: bool |
| 23 | + # Uncomment when the stubs' TF version is updated to 2.15 |
| 24 | + # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] |
| 25 | + enable_async: bool |
| 26 | + |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + experimental_io_device: None | str = None, |
| 30 | + experimental_enable_async_checkpoint: bool = False, |
| 31 | + # Uncomment when the stubs' TF version is updated to 2.15 |
| 32 | + # experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None, |
| 33 | + enable_async: bool = False, |
| 34 | + ) -> None: ... |
| 35 | + |
| 36 | +_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str]) |
| 37 | + |
| 38 | +class ClusterSpec: |
| 39 | + def __init__(self, cluster: dict[str, _T] | ClusterDef | ClusterSpec) -> None: ... |
| 40 | + def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ... |
| 41 | + def num_tasks(self, job_name: str) -> int: ... |
| 42 | + |
| 43 | +class _CheckpointLoadStatus: |
| 44 | + def assert_consumed(self) -> Self: ... |
| 45 | + def assert_existing_objects_matched(self) -> Self: ... |
| 46 | + def assert_nontrivial_match(self) -> Self: ... |
| 47 | + def expect_partial(self) -> Self: ... |
| 48 | + |
| 49 | +class Checkpoint: |
| 50 | + def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ... |
| 51 | + def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... |
| 52 | + def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ... |
| 53 | + def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... |
| 54 | + # def sync(self) -> None: ... # Uncomment when the stubs' TF version is updated to 2.15 |
| 55 | + def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ... |
| 56 | + |
| 57 | +class CheckpointManager: |
| 58 | + def __init__( |
| 59 | + self, |
| 60 | + checkpoint: Checkpoint, |
| 61 | + directory: str, |
| 62 | + max_to_keep: int, |
| 63 | + keep_checkpoint_every_n_hours: int | None = None, |
| 64 | + checkpoint_name: str = "ckpt", |
| 65 | + step_counter: tf.Variable | None = None, |
| 66 | + checkpoint_interval: int | None = None, |
| 67 | + init_fn: Callable[[], object] | None = None, |
| 68 | + ) -> None: ... |
| 69 | + def _sweep(self) -> None: ... |
| 70 | + |
| 71 | +def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ... |
| 72 | +def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ... |
| 73 | +def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ... |
| 74 | +def __getattr__(name: str) -> Incomplete: ... |
0 commit comments