Skip to content

Commit c9047eb

Browse files
michaelreneercopybara-github
authored andcommitted
Remove usages of tff.structure.update_struct in examples.
We want to remove this functionality from the API. PiperOrigin-RevId: 746124000
1 parent 0c17626 commit c9047eb

File tree

2 files changed

+11
-22
lines changed

2 files changed

+11
-22
lines changed

examples/simple_fedavg/simple_fedavg_tf.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,15 @@
2626
https://arxiv.org/abs/1602.05629
2727
"""
2828

29-
from typing import Any
29+
from typing import Any, NamedTuple
3030

31-
import attrs
3231
import numpy as np
3332
import tensorflow as tf
3433
import tensorflow_federated as tff
3534

3635

3736
# TODO: b/295181362 - Update from `Any` to a more specific type.
38-
@attrs.define(eq=False, frozen=True)
39-
class ClientOutput:
37+
class ClientOutput(NamedTuple):
4038
"""Structure for outputs returned from clients during federated optimization.
4139
4240
Attributes:
@@ -53,8 +51,7 @@ class ClientOutput:
5351
model_output: Any
5452

5553

56-
@attrs.define(eq=False, frozen=True)
57-
class ServerState:
54+
class ServerState(NamedTuple):
5855
"""Structure for state on the server.
5956
6057
Attributes:
@@ -69,8 +66,7 @@ class ServerState:
6966
round_num: int
7067

7168

72-
@attrs.define(eq=False, frozen=True)
73-
class BroadcastMessage:
69+
class BroadcastMessage(NamedTuple):
7470
"""Structure for tensors broadcasted by server during federated optimization.
7571
7672
Attributes:
@@ -119,8 +115,7 @@ def server_update(model, server_optimizer, server_state, weights_delta):
119115
)
120116

121117
# Create a new state based on the updated model.
122-
return tff.structure.update_struct(
123-
server_state,
118+
return ServerState(
124119
model=model_weights,
125120
optimizer_state=server_optimizer.variables(),
126121
round_num=server_state.round_num + 1,

examples/stateful_clients/stateful_fedavg_tf.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
"""
1818

1919
import collections
20-
from typing import Any, Union
20+
from typing import Any, NamedTuple, Union
2121

22-
import attrs
2322
import tensorflow as tf
2423
import tensorflow_federated as tff
2524

@@ -98,8 +97,7 @@ def keras_evaluate(model, test_data, metric):
9897
return metric.result()
9998

10099

101-
@attrs.define(eq=False, frozen=True)
102-
class ClientState:
100+
class ClientState(NamedTuple):
103101
"""Structure for state on the client.
104102
105103
Fields:
@@ -113,8 +111,7 @@ class ClientState:
113111
iters_count: int
114112

115113

116-
@attrs.define(eq=False, frozen=True)
117-
class ClientOutput:
114+
class ClientOutput(NamedTuple):
118115
"""Structure for outputs returned from clients during federated optimization.
119116
120117
Fields:
@@ -135,8 +132,7 @@ class ClientOutput:
135132
client_state: ClientState
136133

137134

138-
@attrs.define(eq=False, frozen=True)
139-
class ServerState:
135+
class ServerState(NamedTuple):
140136
"""Structure for state on the server.
141137
142138
Fields:
@@ -152,8 +148,7 @@ class ServerState:
152148
total_iters_count: int
153149

154150

155-
@attrs.define(eq=False, frozen=True)
156-
class BroadcastMessage:
151+
class BroadcastMessage(NamedTuple):
157152
"""Structure for tensors broadcasted by server during federated optimization.
158153
159154
Fields:
@@ -204,8 +199,7 @@ def server_update(
204199
)
205200

206201
# Create a new state based on the updated model.
207-
return tff.structure.update_struct(
208-
server_state,
202+
return ServerState(
209203
model_weights=model_weights,
210204
optimizer_state=server_optimizer.variables(),
211205
round_num=server_state.round_num + 1,

0 commit comments

Comments
 (0)