Skip to content

Commit 587e75f

Browse files
authored
tensorflow: add tf.train.CheckpointOptions and other tf.train members. (#11327)
1 parent 547cbc7 commit 587e75f

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

stubs/tensorflow/@tests/stubtest_allowlist.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,12 @@ tensorflow.io.SparseFeature.__new__
8585

8686
# Metaclass inconsistency. The runtime metaclass is defined from c++ extension and is undocumented.
8787
tensorflow.io.TFRecordWriter
88+
89+
# stubtest does not pass for protobuf generated stubs.
90+
tensorflow.train.Example.*
91+
tensorflow.train.BytesList.*
92+
tensorflow.train.Feature.*
93+
tensorflow.train.FloatList.*
94+
tensorflow.train.Int64List.*
95+
tensorflow.train.ClusterDef.*
96+
tensorflow.train.ServerDef.*
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from _typeshed import Incomplete
2+
from collections.abc import Callable
3+
from typing import Any, TypeVar
4+
from typing_extensions import Self
5+
6+
import numpy as np
7+
import tensorflow as tf
8+
from tensorflow.core.example.example_pb2 import Example as Example
9+
from tensorflow.core.example.feature_pb2 import (
10+
BytesList as BytesList,
11+
Feature as Feature,
12+
Features as Features,
13+
FloatList as FloatList,
14+
Int64List as Int64List,
15+
)
16+
from tensorflow.core.protobuf.cluster_pb2 import ClusterDef as ClusterDef
17+
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef as ServerDef
18+
from tensorflow.python.trackable.base import Trackable
19+
20+
class CheckpointOptions:
21+
experimental_io_device: None | str
22+
experimental_enable_async_checkpoint: bool
23+
# Uncomment when the stubs' TF version is updated to 2.15
24+
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]]
25+
enable_async: bool
26+
27+
def __init__(
28+
self,
29+
experimental_io_device: None | str = None,
30+
experimental_enable_async_checkpoint: bool = False,
31+
# Uncomment when the stubs' TF version is updated to 2.15
32+
# experimental_write_callbacks: None | list[Callable[[str], Any] | Callable[[], Any]] = None,
33+
enable_async: bool = False,
34+
) -> None: ...
35+
36+
_T = TypeVar("_T", bound=list[str] | tuple[str] | dict[int, str])
37+
38+
class ClusterSpec:
39+
def __init__(self, cluster: dict[str, _T] | ClusterDef | ClusterSpec) -> None: ...
40+
def as_dict(self) -> dict[str, list[str] | tuple[str] | dict[int, str]]: ...
41+
def num_tasks(self, job_name: str) -> int: ...
42+
43+
class _CheckpointLoadStatus:
44+
def assert_consumed(self) -> Self: ...
45+
def assert_existing_objects_matched(self) -> Self: ...
46+
def assert_nontrivial_match(self) -> Self: ...
47+
def expect_partial(self) -> Self: ...
48+
49+
class Checkpoint:
50+
def __init__(self, root: Trackable | None = None, **kwargs: Trackable) -> None: ...
51+
def read(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...
52+
def restore(self, save_path: str, options: CheckpointOptions | None = None) -> _CheckpointLoadStatus: ...
53+
def save(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ...
54+
# def sync(self) -> None: ... # Uncomment when the stubs' TF version is updated to 2.15
55+
def write(self, file_prefix: str, options: CheckpointOptions | None = None) -> str: ...
56+
57+
class CheckpointManager:
58+
def __init__(
59+
self,
60+
checkpoint: Checkpoint,
61+
directory: str,
62+
max_to_keep: int,
63+
keep_checkpoint_every_n_hours: int | None = None,
64+
checkpoint_name: str = "ckpt",
65+
step_counter: tf.Variable | None = None,
66+
checkpoint_interval: int | None = None,
67+
init_fn: Callable[[], object] | None = None,
68+
) -> None: ...
69+
def _sweep(self) -> None: ...
70+
71+
def latest_checkpoint(checkpoint_dir: str, latest_filename: str | None = None) -> str: ...
72+
def load_variable(ckpt_dir_or_file: str, name: str) -> np.ndarray[Any, Any]: ...
73+
def list_variables(ckpt_dir_or_file: str) -> list[tuple[str, list[int]]]: ...
74+
def __getattr__(name: str) -> Incomplete: ...
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import abc
2+
from _typeshed import Incomplete
3+
from typing_extensions import Self
4+
5+
from tensorflow.python.trackable.base import Trackable
6+
7+
class PythonState(Trackable, metaclass=abc.ABCMeta):
8+
@abc.abstractmethod
9+
def serialize(self) -> str: ...
10+
@abc.abstractmethod
11+
def deserialize(self, string_value: str) -> Self: ...
12+
13+
def __getattr__(name: str) -> Incomplete: ...

0 commit comments

Comments
 (0)