Commit 060d3cb

ZacharyGarrett authored and copybara-github committed
Add support for Tensor learning rates and gradients with mixed types.
PiperOrigin-RevId: 671726026
1 parent c3e36d7 commit 060d3cb

7 files changed: +96 -41 lines changed

RELEASE.md

Lines changed: 4 additions & 2 deletions
@@ -26,8 +26,10 @@ and this project adheres to
 
 ### Fixed
 
-* A bug where `tff.learning.optimizers.build_adafactor(...)` would update its
-  step counter twice upon every invocation of `.next()`.
+* A bug where `tff.learning.optimizers.build_adafactor` would update its step
+  counter twice upon every invocation of `.next()`.
+* A bug where tensor learning rates for `tff.learning.optimizers.build_sgdm`
+  would fail with mixed dtype gradients.
 
 ### Removed
 
tensorflow_federated/python/learning/optimizers/adagrad.py

Lines changed: 14 additions & 8 deletions
@@ -27,7 +27,7 @@
 _HPARAMS_KEYS = [optimizer.LEARNING_RATE_KEY, _EPSILON_KEY]
 
 State = TypeVar('State', bound=collections.OrderedDict[str, Any])
-Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
+Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])
 
 
 class _Adagrad(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -40,31 +40,35 @@ def __init__(
       epsilon: optimizer.Float = 1e-7,
   ):
     """Initializes SGD optimizer."""
-    if learning_rate < 0.0:
+    if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
       raise ValueError(
           f'Adagrad `learning_rate` must be nonnegative, found {learning_rate}.'
       )
-    if initial_preconditioner_value < 0.0:
+    if (
+        not tf.is_symbolic_tensor(initial_preconditioner_value)
+        and initial_preconditioner_value < 0.0
+    ):
       raise ValueError(
           'Adagrad `initial_preconditioner_value` must be nonnegative, found '
           f'{initial_preconditioner_value}.'
       )
-    if epsilon < 0.0:
+    if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
       raise ValueError(f'Adagrad epsilon must be nonnegative, found {epsilon}.')
     self._lr = learning_rate
     self._initial_precond = initial_preconditioner_value
     self._epsilon = epsilon
 
   def initialize(self, specs: Any) -> State:
     initial_preconditioner = tf.nest.map_structure(
-        lambda s: tf.ones(s.shape, s.dtype) * self._initial_precond, specs
+        lambda s: tf.ones(s.shape, s.dtype)
+        * tf.cast(self._initial_precond, s.dtype),
+        specs,
     )
-    state = collections.OrderedDict([
+    return collections.OrderedDict([
        (optimizer.LEARNING_RATE_KEY, self._lr),
        (_EPSILON_KEY, self._epsilon),
        (_PRECONDITIONER_KEY, initial_preconditioner),
    ])
-    return state
 
   def next(
       self, state: State, weights: optimizer.Weights, gradients: Any
@@ -82,7 +86,9 @@ def _adagrad_update(w, p, g):
       if g is None:
         return w, p
       p = p + tf.math.square(g)
-      w = w - lr * g / tf.math.sqrt(p + epsilon)
+      w = w - tf.cast(lr, g.dtype) * g / tf.math.sqrt(
+          p + tf.cast(epsilon, p.dtype)
+      )
       return w, p
 
     updated_weights, updated_preconditioner = nest_utils.map_at_leaves(
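A minimal standalone sketch of the per-leaf update after this change (eager TensorFlow; `adagrad_step` is an illustrative helper, not part of the diff or the TFF API): the hyperparameters may be Python floats or tensors of a different float dtype than a given weight leaf, and each one is cast into the dtype of the value it touches.

import tensorflow as tf

def adagrad_step(w, p, g, lr, epsilon):
  # Mirrors `_adagrad_update` above: accumulate the squared gradient, then
  # cast `lr` and `epsilon` into the dtypes of the tensors they combine with.
  p = p + tf.math.square(g)
  w = w - tf.cast(lr, g.dtype) * g / tf.math.sqrt(p + tf.cast(epsilon, p.dtype))
  return w, p

# float64 weights driven by a float32 learning-rate tensor.
w = tf.constant([1.0, 2.0], dtype=tf.float64)
p = tf.fill([2], tf.constant(0.1, dtype=tf.float64))
g = tf.constant([0.5, -0.5], dtype=tf.float64)
w, p = adagrad_step(w, p, g, lr=tf.constant(0.1, tf.float32), epsilon=1e-7)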

tensorflow_federated/python/learning/optimizers/adagrad_test.py

Lines changed: 18 additions & 2 deletions
@@ -145,8 +145,8 @@ def random_vector():
           genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
       ]
 
-    intial_weight = random_vector()
-    model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
+    initial_weight = random_vector()
+    model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
     gradients = [random_vector() for _ in range(steps)]
     tff_optimizer_fn = lambda: adagrad.build_adagrad(0.01)
     keras_optimizer_fn = lambda: tf.keras.optimizers.Adagrad(0.01)
@@ -227,6 +227,22 @@ def test_set_get_hparams_is_no_op(self, spec):
     updated_state = optimizer.set_hparams(state, hparams)
     self.assertEqual(state, updated_state)
 
+  def test_lr_with_different_weight_dtypes(self):
+    weights = (
+        tf.constant([0.1], dtype=tf.float32),
+        tf.constant(1.0, dtype=tf.float64),
+        tf.constant([10.0, 10.0], dtype=tf.bfloat16),
+    )
+    adagrad_optimizer = adagrad.build_adagrad(
+        learning_rate=tf.constant(0.1, dtype=tf.float32),
+        initial_preconditioner_value=tf.constant(0.1, dtype=tf.float32),
+        epsilon=tf.constant(0.1, dtype=tf.float64),
+    )
+    state = adagrad_optimizer.initialize(weights)
+    adagrad_optimizer.next(
+        state, weights, tf.nest.map_structure(tf.zeros_like, weights)
+    )
+
 
 if __name__ == '__main__':
   tf.test.main()

tensorflow_federated/python/learning/optimizers/adam.py

Lines changed: 17 additions & 12 deletions
@@ -36,7 +36,7 @@
 ]
 
 State = TypeVar('State', bound=collections.OrderedDict[str, Any])
-Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
+Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])
 
 
 class _Adam(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -50,19 +50,19 @@ def __init__(
       epsilon: optimizer.Float = 1e-7,
   ):
     """Initializes Adam optimizer."""
-    if learning_rate < 0.0:
+    if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
       raise ValueError(
           f'Adam `learning_rate` must be nonnegative, found {learning_rate}.'
       )
-    if beta_1 < 0.0 or beta_1 > 1.0:
+    if not tf.is_symbolic_tensor(beta_1) and (beta_1 < 0.0 or beta_1 > 1.0):
       raise ValueError(
           f'Adam `beta_1` must be in the range [0.0, 1.0], found {beta_1}.'
       )
-    if beta_2 < 0.0 or beta_2 > 1.0:
+    if not tf.is_symbolic_tensor(beta_2) and (beta_2 < 0.0 or beta_2 > 1.0):
       raise ValueError(
           f'Adam `beta_2` must be in the range [0.0, 1.0], found {beta_2}.'
       )
-    if epsilon < 0.0:
+    if not tf.is_symbolic_tensor(epsilon) and epsilon < 0.0:
       raise ValueError(f'Adam `epsilon` must be nonnegative, found {epsilon}.')
     self._lr = learning_rate
     self._beta_1 = beta_1
@@ -76,7 +76,7 @@ def initialize(self, specs: Any) -> State:
     initial_preconditioner = tf.nest.map_structure(
         lambda s: tf.zeros(s.shape, s.dtype), specs
     )
-    state = collections.OrderedDict([
+    return collections.OrderedDict([
        (optimizer.LEARNING_RATE_KEY, self._lr),
        (_BETA_1_KEY, self._beta_1),
        (_BETA_2_KEY, self._beta_2),
@@ -85,7 +85,6 @@ def initialize(self, specs: Any) -> State:
        (_ACCUMULATOR_KEY, initial_accumulator),
        (_PRECONDITIONER_KEY, initial_preconditioner),
    ])
-    return state
 
   def next(
       self, state: State, weights: optimizer.Weights, gradients: Any
@@ -103,18 +102,24 @@ def next(
     optimizer.check_weights_state_match(
         weights, preconditioner, 'preconditioner'
     )
+    if tf.is_tensor(beta_1):
+      casted_step = tf.cast(step, beta_1.dtype)
+    else:
+      casted_step = step
     normalized_lr = (
         lr
-        * tf.math.sqrt((1 - tf.math.pow(beta_2, tf.cast(step, tf.float32))))
-        / (1 - tf.math.pow(beta_1, tf.cast(step, tf.float32)))
+        * tf.math.sqrt((1.0 - tf.math.pow(beta_2, casted_step)))
+        / (1.0 - tf.math.pow(beta_1, casted_step))
     )
 
     def _adam_update(w, a, p, g):
       if g is None:
         return w, a, p
-      a = a + (g - a) * (1 - beta_1)
-      p = p + (tf.math.square(g) - p) * (1 - beta_2)
-      w = w - normalized_lr * a / (tf.math.sqrt(p) + epsilon)
+      a = a + (g - a) * (1 - tf.cast(beta_1, a.dtype))
+      p = p + (tf.math.square(g) - p) * (1 - tf.cast(beta_2, p.dtype))
+      w = w - tf.cast(normalized_lr, a.dtype) * a / (
+          tf.math.sqrt(p) + tf.cast(epsilon, p.dtype)
+      )
       return w, a, p
 
     updated_weights, updated_accumulator, updated_preconditioner = (
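The `normalized_lr` computed in `next` above is the standard Adam bias correction, lr_t = lr * sqrt(1 - beta_2^t) / (1 - beta_1^t). The step counter `step` is an integer tensor, and `tf.math.pow` requires its base and exponent to share a dtype, so when `beta_1`/`beta_2` are tensors the step is cast to their float dtype (the new `casted_step`) rather than to a hard-coded `tf.float32`. A short standalone sketch with illustrative values:

import tensorflow as tf

lr = tf.constant(0.01, dtype=tf.float64)
beta_1 = tf.constant(0.9, dtype=tf.float64)
beta_2 = tf.constant(0.999, dtype=tf.float64)
step = tf.constant(3, dtype=tf.int32)  # integer step counter from the state

# Cast the integer step to the betas' float dtype so tf.math.pow sees
# matching dtypes, then apply the bias correction.
casted_step = tf.cast(step, beta_1.dtype)
normalized_lr = (
    lr
    * tf.math.sqrt(1.0 - tf.math.pow(beta_2, casted_step))
    / (1.0 - tf.math.pow(beta_1, casted_step))
)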

tensorflow_federated/python/learning/optimizers/adam_test.py

Lines changed: 20 additions & 5 deletions
@@ -55,9 +55,7 @@ def test_math(self):
     for _ in range(4):
       state, weights = optimizer.next(state, weights, gradients)
       history.append(weights)
-    self.assertAllClose(
-        [[1.0], [0.9000007], [0.8000017], [0.700002], [0.600003]], history
-    )
+    self.assertAllClose([[1.0], [0.9], [0.8], [0.7], [0.6]], history)
 
   @parameterized.named_parameters(
       ('scalar_spec', _SCALAR_SPEC),
@@ -142,8 +140,8 @@ def random_vector():
           genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
       ]
 
-    intial_weight = random_vector()
-    model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
+    initial_weight = random_vector()
+    model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
     gradients = [random_vector() for _ in range(steps)]
     tff_optimizer_fn = lambda: adam.build_adam(0.01, 0.9, 0.999)
     keras_optimizer_fn = lambda: tf.keras.optimizers.Adam(0.01, 0.9, 0.999)
@@ -225,6 +223,23 @@ def test_set_get_hparams_is_no_op(self, spec):
     updated_state = optimizer.set_hparams(state, hparams)
     self.assertEqual(state, updated_state)
 
+  def test_lr_with_different_weight_dtypes(self):
+    weights = (
+        tf.constant([0.1], dtype=tf.float32),
+        tf.constant(1.0, dtype=tf.float64),
+        tf.constant([10.0, 10.0], dtype=tf.bfloat16),
+    )
+    adam_optimizer = adam.build_adam(
+        learning_rate=tf.constant(0.1, dtype=tf.float32),
+        beta_1=tf.constant(0.1, dtype=tf.float32),
+        beta_2=tf.constant(0.1, dtype=tf.float32),
+        epsilon=tf.constant(0.1, dtype=tf.float64),
+    )
+    state = adam_optimizer.initialize(weights)
+    adam_optimizer.next(
+        state, weights, tf.nest.map_structure(tf.zeros_like, weights)
+    )
+
 
 if __name__ == '__main__':
   tf.test.main()

tensorflow_federated/python/learning/optimizers/sgdm.py

Lines changed: 6 additions & 9 deletions
@@ -26,7 +26,7 @@
 _ACCUMULATOR_KEY = 'accumulator'
 
 State = TypeVar('State', bound=collections.OrderedDict[str, Any])
-Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, float])
+Hparams = TypeVar('Hparams', bound=collections.OrderedDict[str, Any])
 
 
 class _SGD(optimizer.Optimizer[State, optimizer.Weights, Hparams]):
@@ -38,14 +38,16 @@ def __init__(
       momentum: Optional[optimizer.Float] = None,
   ):
     """Initializes SGD optimizer."""
-    if learning_rate < 0.0:
+    if not tf.is_symbolic_tensor(learning_rate) and learning_rate < 0.0:
       raise ValueError(
           f'SGD `learning_rate` must be nonnegative, found {learning_rate}.'
       )
     if momentum:
       # We should only track momentum as a hparam in the case that it is both
       # specified and nonzero.
-      if momentum < 0.0 or momentum > 1.0:
+      if not tf.is_symbolic_tensor(momentum) and (
+          momentum < 0.0 or momentum > 1.0
+      ):
         raise ValueError(
             'SGD `momentum` must be `None` or in the range [0, 1], found '
             f'{momentum}.'
@@ -77,7 +79,7 @@ def next(
     def _sgd_update(w, g):
       if g is None:
         return w
-      return w - lr * g
+      return w - tf.cast(lr, dtype=g.dtype) * g
 
     updated_weights = nest_utils.map_at_leaves(
         _sgd_update, weights, gradients
@@ -111,11 +113,6 @@ def get_hparams(self, state: State) -> Hparams:
     return collections.OrderedDict([(k, state[k]) for k in self._hparams_keys])
 
   def set_hparams(self, state: State, hparams: Hparams) -> State:
-    # TODO: b/245962555 - Find an alternative to `update_struct` if it
-    # interferes with typing guarantees.
-    # We use `tff.structure.update_struct` (rather than something like
-    # `copy.deepcopy`) to ensure that this can be called within a
-    # `tff.Computation`.
     return structure.update_struct(state, **hparams)
 
 
tensorflow_federated/python/learning/optimizers/sgdm_test.py

Lines changed: 17 additions & 3 deletions
@@ -43,7 +43,7 @@ def test_get_hparams_momentum(self, momentum_value):
     optimizer = sgdm.build_sgdm(0.01, momentum=momentum_value)
     state = optimizer.initialize(_SCALAR_SPEC)
     hparams = optimizer.get_hparams(state)
-    # Whether we specify None momentum or momentum 0.0, we shouldnt track the
+    # Whether we specify None momentum or momentum 0.0, we shouldn't track the
     # extra accumulator state. The implementation of next checks for the
     # presence or absence of momentum key--it should not be there in either
     # case.
@@ -177,8 +177,8 @@ def random_vector():
           genarator.normal(shape=s.shape, dtype=s.dtype) for s in weight_spec
       ]
 
-    intial_weight = random_vector()
-    model_variables_fn = lambda: [tf.Variable(v) for v in intial_weight]
+    initial_weight = random_vector()
+    model_variables_fn = lambda: [tf.Variable(v) for v in initial_weight]
     gradients = [random_vector() for _ in range(steps)]
     tff_optimizer_fn = lambda: sgdm.build_sgdm(learning_rate, momentum)
 
@@ -306,6 +306,20 @@ def test_set_get_hparams_is_no_op_with_momentum(self, spec):
     updated_state = optimizer.set_hparams(state, hparams)
     self.assertEqual(state, updated_state)
 
+  def test_lr_with_different_weight_dtypes(self):
+    weights = (
+        tf.constant([0.1], dtype=tf.float32),
+        tf.constant(1.0, dtype=tf.float64),
+        tf.constant([10.0, 10.0], dtype=tf.bfloat16),
+    )
+    sgdm_optimizer = sgdm.build_sgdm(
+        learning_rate=tf.constant(0.1, dtype=tf.float32)
+    )
+    state = sgdm_optimizer.initialize(weights)
+    sgdm_optimizer.next(
+        state, weights, tf.nest.map_structure(tf.zeros_like, weights)
+    )
+
 
 if __name__ == '__main__':
  tf.test.main()
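One pattern shared by all three optimizers in this commit is worth calling out: hyperparameter range checks now run only when the value can actually be inspected at construction time. A symbolic tensor (for example, a placeholder created while tracing a tf.function or a TFF computation) has no concrete value yet, so comparing it against 0.0 is skipped rather than evaluated. A minimal sketch of that guard (the helper name is hypothetical, not from the diff):

import tensorflow as tf

def check_nonnegative(name, value):
  # Validate eagerly-available values; skip the check for symbolic tensors,
  # whose concrete value is unknown until the traced computation runs.
  if not tf.is_symbolic_tensor(value) and value < 0.0:
    raise ValueError(f'`{name}` must be nonnegative, found {value}.')

check_nonnegative('learning_rate', 0.1)               # passes
check_nonnegative('learning_rate', tf.constant(0.1))  # passes (eager tensor)

@tf.function
def build_with_symbolic_lr(lr):
  check_nonnegative('learning_rate', lr)  # skipped: `lr` is symbolic during tracing
  return lr

build_with_symbolic_lr(tf.constant(-1.0))  # no error raised at trace time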
