Skip to content

Commit b049689

Browse files
michaelreneercopybara-github
authored andcommitted
Migrate calls to tff.structure.to_elements to use the API on federated_language.StructType or federated_language.framework.Struct.
PiperOrigin-RevId: 746173864
1 parent 24ac2b3 commit b049689

File tree

10 files changed

+24
-30
lines changed

10 files changed

+24
-30
lines changed

tensorflow_federated/python/core/backends/mapreduce/form_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,7 @@ def _has_placement(type_spec):
13291329
# Add a reference to the injected broadcast call in the result so that it
13301330
# does not get pruned by various tree transformations. We will remove this
13311331
# additional element in the result after the first split operation.
1332-
revised_block_result = structure.to_elements(comp_tree.result.result) + [
1332+
revised_block_result = list(comp_tree.result.result.items()) + [
13331333
federated_language.framework.Reference(
13341334
'injected_broadcast_ref',
13351335
injected_broadcast.type_signature,
@@ -1438,7 +1438,7 @@ def _unnest_lambda_parameter(comp):
14381438
federated_language.framework.Block(
14391439
after_broadcast.result.locals,
14401440
federated_language.framework.Struct(
1441-
structure.to_elements(after_broadcast.result.result)[:-1]
1441+
list(after_broadcast.result.result.items())[:-1]
14421442
),
14431443
),
14441444
)

tensorflow_federated/python/core/environments/jax_frontend/jax_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _make(
134134
return obj, next_unused_tensor_index
135135
elif isinstance(type_spec, federated_language.StructType):
136136
elements = []
137-
for k, v in structure.to_elements(type_spec):
137+
for k, v in type_spec.items():
138138
obj, next_unused_tensor_index = _make(v, next_unused_tensor_index)
139139
elements.append((k, obj))
140140
obj = _XlaSerializerStructArg(type_spec, elements)

tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _stamp_value_into_graph(
6060
if isinstance(value, (list, dict)):
6161
value = structure.from_container(value)
6262
stamped_elements = []
63-
named_type_signatures = structure.to_elements(type_signature)
63+
named_type_signatures = type_signature.items()
6464
for (name, type_signature), element in zip(named_type_signatures, value):
6565
stamped_element = _stamp_value_into_graph(element, type_signature, graph)
6666
stamped_elements.append((name, stamped_element))

tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def stamp_parameter_in_graph(parameter_name, parameter_type, graph):
9494
elif isinstance(parameter_type, federated_language.StructType):
9595
# The parameter_type could be a StructTypeWithPyContainer, however, we
9696
# ignore that for now. Instead, the proper containers will be inserted at
97-
# call time by federated_language.framework.wrap_as_zero_or_one_arg_callable.
97+
# call time by federated_language.framework.wrap_as_zero_or_one_arg
98+
# _callable.
9899
if not parameter_type:
99100
# Stamps whimsy element to "populate" graph, as TensorFlow does not
100101
# support empty graphs.
@@ -533,7 +534,7 @@ def _assemble_result_from_graph(type_spec, binding, output_map):
533534
'Expected a struct binding, found {}.'.format(binding_oneof)
534535
)
535536
else:
536-
type_elements = structure.to_elements(type_spec)
537+
type_elements = type_spec.items()
537538
if len(binding.struct.element) != len(type_elements):
538539
raise ValueError(
539540
'Mismatching tuple sizes in type ({}) and binding ({}).'.format(
@@ -601,7 +602,7 @@ def _make_empty_list_structure_for_element_type_spec(type_spec):
601602
if isinstance(type_spec, federated_language.TensorType):
602603
return []
603604
elif isinstance(type_spec, federated_language.StructType):
604-
elements = structure.to_elements(type_spec)
605+
elements = type_spec.items()
605606
if all(k is not None for k, _ in elements):
606607
return collections.OrderedDict([
607608
(k, _make_empty_list_structure_for_element_type_spec(v))
@@ -684,7 +685,7 @@ def _handle_none_dimension(x):
684685
return np.empty(whimsy_shape, dtype=np.str_)
685686
return np.zeros(whimsy_shape, type_spec.dtype)
686687
elif isinstance(type_spec, federated_language.StructType):
687-
elements = structure.to_elements(type_spec)
688+
elements = type_spec.items()
688689
elem_list = []
689690
for _, elem_type in elements:
690691
elem_list.append(_make_whimsy_element_for_type_spec(elem_type))
@@ -735,7 +736,7 @@ def _append_to_list_structure_for_element_type_spec(nested, value, type_spec):
735736
# tf.data.Dataset.from_tensor_slices.
736737
nested.append(tf.convert_to_tensor(value, type_spec.dtype))
737738
elif isinstance(type_spec, federated_language.StructType):
738-
elements = structure.to_elements(type_spec)
739+
elements = type_spec.items()
739740
if isinstance(nested, collections.OrderedDict):
740741
if isinstance(value, py_typecheck.SupportsNamedTuple):
741742
# In Python 3.8 and later `_asdict` no longer return OrdereDict, rather
@@ -814,7 +815,7 @@ def _replace_empty_leaf_lists_with_numpy_arrays(lists, type_spec):
814815
else:
815816
return np.array([], dtype=type_spec.dtype)
816817
elif isinstance(type_spec, federated_language.StructType):
817-
elements = structure.to_elements(type_spec)
818+
elements = type_spec.items()
818819
if isinstance(lists, collections.OrderedDict):
819820
to_return = []
820821
for elem_name, elem_type in elements:
@@ -1074,7 +1075,7 @@ def _to_representative_value(type_spec, elements):
10741075
return py_type(*values)
10751076
return py_type(values) # pylint: disable=too-many-function-args
10761077
elif isinstance(type_spec, federated_language.StructType):
1077-
field_types = structure.to_elements(type_spec)
1078+
field_types = type_spec.items()
10781079
is_all_named = all([name is not None for name, _ in field_types])
10791080
if is_all_named:
10801081
if isinstance(elements, py_typecheck.SupportsNamedTuple):

tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _assert_binding_matches_type_and_value(
7070
if not isinstance(val, (list, tuple, structure.Struct)):
7171
self.assertIsInstance(val, dict)
7272
val = list(val.values())
73-
for idx, e in enumerate(structure.to_elements(type_spec)):
73+
for idx, e in enumerate(type_spec.items()):
7474
self._assert_binding_matches_type_and_value(
7575
binding.struct.element[idx], e[1], val[idx], graph, is_output
7676
)

tensorflow_federated/python/core/environments/tensorflow_backend/type_conversions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _type_to_tf_dtypes_and_shapes(type_spec: federated_language.Type):
131131
shape = tf.TensorShape(type_spec.shape)
132132
return (type_spec.dtype, shape)
133133
elif isinstance(type_spec, federated_language.StructType):
134-
elements = structure.to_elements(type_spec)
134+
elements = type_spec.items()
135135
if not elements:
136136
output_dtypes = []
137137
output_shapes = []
@@ -238,7 +238,7 @@ def type_to_tf_structure(type_spec: federated_language.Type):
238238
if isinstance(type_spec, federated_language.TensorType):
239239
return tf.TensorSpec(type_spec.shape, type_spec.dtype)
240240
elif isinstance(type_spec, federated_language.StructType):
241-
elements = structure.to_elements(type_spec)
241+
elements = type_spec.items()
242242
if not elements:
243243
return ()
244244
element_outputs = [(k, type_to_tf_structure(v)) for k, v in elements]

tensorflow_federated/python/core/impl/compiler/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ py_test(
5151
":building_block_test_utils",
5252
":transformations",
5353
":tree_transformations",
54-
"//tensorflow_federated/python/common_libs:structure",
5554
"@federated_language//federated_language",
5655
"@federated_language//federated_language/proto:computation_py_pb2",
5756
],

tensorflow_federated/python/core/impl/compiler/transformations.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,9 +1444,7 @@ def divisive_force_align_and_split_by_intrinsics(
14441444
]
14451445

14461446
# Update the before comp result to produce the extended intermediate state.
1447-
before_result_elements = structure.to_elements(
1448-
preliminary_before_comp.result.result
1449-
)
1447+
before_result_elements = list(preliminary_before_comp.result.result.items())
14501448
intermediate_state_index_in_before_result = 1
14511449
intermediate_state_name, intermediate_state_vals = before_result_elements[
14521450
intermediate_state_index_in_before_result
@@ -1461,8 +1459,7 @@ def divisive_force_align_and_split_by_intrinsics(
14611459
for local_name, local_value in duplicated_locals
14621460
]
14631461
extended_intermediate_state_vals = federated_language.framework.Struct(
1464-
structure.to_elements(intermediate_state_vals)
1465-
+ duplicate_intermediate_state_vals
1462+
list(intermediate_state_vals.items()) + duplicate_intermediate_state_vals
14661463
)
14671464
before_result_elements[intermediate_state_index_in_before_result] = (
14681465
intermediate_state_name,

tensorflow_federated/python/core/impl/compiler/transformations_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from federated_language.proto import computation_pb2
2020
import numpy as np
2121

22-
from tensorflow_federated.python.common_libs import structure
2322
from tensorflow_federated.python.core.impl.compiler import building_block_test_utils
2423
from tensorflow_federated.python.core.impl.compiler import transformations
2524
from tensorflow_federated.python.core.impl.compiler import tree_transformations
@@ -751,7 +750,7 @@ def _check_transformed_comp_validity(
751750
transformed_comp.parameter_type, len(original_comp.parameter_type) + 1
752751
)
753752
self.assertEqual(
754-
structure.to_elements(transformed_comp.parameter_type)[
753+
transformed_comp.parameter_type.items()[
755754
len(transformed_comp.parameter_type) - 1
756755
][0],
757756
lambda_parameter_extension_name,
@@ -1011,7 +1010,7 @@ def check_split_signatures(self, original_comp, before, intrinsic, after):
10111010
before.type_signature.result, federated_language.StructType
10121011
)
10131012
self.assertEqual(
1014-
[x for x, _ in structure.to_elements(before.type_signature.result)],
1013+
[x for x, _ in before.type_signature.result.items()],
10151014
['intrinsic_args_from_before_comp', 'intermediate_state'],
10161015
)
10171016
self.assertIsInstance(
@@ -1025,7 +1024,7 @@ def check_split_signatures(self, original_comp, before, intrinsic, after):
10251024
len(intrinsic.type_signature.parameter) - 1
10261025
)
10271026
intrinsic_arg_names = [
1028-
x for x, _ in structure.to_elements(intrinsic.type_signature.parameter)
1027+
x for x, _ in intrinsic.type_signature.parameter.items()
10291028
]
10301029
self.assertEqual(
10311030
intrinsic_arg_names[intrinsic_args_from_before_comp_index],
@@ -1047,7 +1046,7 @@ def check_split_signatures(self, original_comp, before, intrinsic, after):
10471046
intrinsic_results_index = len(after.type_signature.parameter) - 2
10481047
intermediate_state_index = len(after.type_signature.parameter) - 1
10491048
after_comp_parameter_names = [
1050-
x for x, _ in structure.to_elements(after.type_signature.parameter)
1049+
x for x, _ in after.type_signature.parameter.items()
10511050
]
10521051
self.assertEqual(
10531052
after_comp_parameter_names[intrinsic_results_index], 'intrinsic_results'

tensorflow_federated/python/learning/models/model_weights_test.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,9 @@ def test_converts_struct(self):
271271
lambda item: self.assertIsInstance(item, np.ndarray),
272272
converted.non_trainable,
273273
)
274+
self.assertAllEqual(converted.trainable, [(None, np.array([1.0]))])
274275
self.assertAllEqual(
275-
structure.to_elements(converted.trainable), [(None, np.array([1.0]))]
276-
)
277-
self.assertAllEqual(
278-
structure.to_elements(converted.non_trainable),
276+
converted.non_trainable,
279277
[(None, np.array([2.0])), (None, np.array([3.0]))],
280278
)
281279

@@ -296,7 +294,7 @@ def test_converts_heterogeneous_struct(self):
296294
converted.trainable,
297295
)
298296
self.assertAllEqual(
299-
structure.to_elements(converted.trainable),
297+
converted.trainable,
300298
[
301299
('a', 1),
302300
('b', 2.0),

0 commit comments

Comments
 (0)