@@ -33,13 +33,13 @@ limitations under the License
33
33
#include " absl/strings/string_view.h"
34
34
#include " absl/synchronization/mutex.h"
35
35
#include " absl/types/span.h"
36
+ #include " federated_language/proto/array.pb.h"
36
37
#include " federated_language/proto/computation.pb.h"
37
- #include " tensorflow/core/framework/tensor .h"
38
+ #include " federated_language/proto/data_type.pb .h"
38
39
#include " tensorflow_federated/cc/core/impl/executors/cardinalities.h"
39
40
#include " tensorflow_federated/cc/core/impl/executors/executor.h"
40
41
#include " tensorflow_federated/cc/core/impl/executors/federated_intrinsics.h"
41
42
#include " tensorflow_federated/cc/core/impl/executors/status_macros.h"
42
- #include " tensorflow_federated/cc/core/impl/executors/tensor_serialization.h"
43
43
#include " tensorflow_federated/cc/core/impl/executors/threading.h"
44
44
#include " tensorflow_federated/cc/core/impl/executors/value_validation.h"
45
45
#include " tensorflow_federated/proto/v0/executor.pb.h"
@@ -703,24 +703,23 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
703
703
// these materialize calls don't block.
704
704
v0::Value keys_for_client_pb =
705
705
TFF_TRY (client_child_->Materialize (keys_child_id->ref ()));
706
- tensorflow::Tensor keys_for_client_tensor =
707
- TFF_TRY (DeserializeTensorValue (keys_for_client_pb));
708
- if (keys_for_client_tensor.dtype () != tensorflow::DT_INT32) {
706
+ federated_language::Array array_pb = keys_for_client_pb.array ();
707
+ if (array_pb.dtype () != federated_language::DataType::DT_INT32) {
709
708
return absl::InvalidArgumentError (
710
709
absl::StrCat (" Expected int32_t key, found key of tensor dtype " ,
711
- keys_for_client_tensor .dtype ()));
710
+ array_pb .dtype ()));
712
711
}
713
- if (keys_for_client_tensor. dims () != 1 ) {
712
+ if (array_pb. shape (). dim (). size () != 1 ) {
714
713
return absl::InvalidArgumentError (absl::StrCat (
715
714
" Expected key tensor to be rank one, but found tensor of rank " ,
716
- keys_for_client_tensor. dims ()));
715
+ array_pb. shape (). dim (). size ()));
717
716
}
718
- int64_t num_keys = keys_for_client_tensor. NumElements () ;
717
+ int64_t num_keys = array_pb. shape (). dim ()[ 0 ] ;
719
718
std::vector<int32_t > keys_for_client;
720
719
keys_for_client.reserve (num_keys);
721
- auto keys_for_client_eigen = keys_for_client_tensor. flat < int32_t > ();
720
+ auto keys_for_client_eigen = array_pb. int32_list (). value ();
722
721
for (int64_t i = 0 ; i < num_keys; i++) {
723
- int32_t key = keys_for_client_eigen (i) ;
722
+ int32_t key = keys_for_client_eigen[i] ;
724
723
keys_for_client.push_back (key);
725
724
keys.all .insert (key);
726
725
}
@@ -732,8 +731,12 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
732
731
absl::StatusOr<OwnedValueId> SelectSliceForKey (int32_t key,
733
732
ValueId server_val_child_id,
734
733
ValueId select_fn_child_id) {
734
+ federated_language::Array array_pb;
735
+ array_pb.set_dtype (federated_language::DataType::DT_INT32);
736
+ array_pb.mutable_shape ();
737
+ array_pb.mutable_int32_list ()->add_value (key);
735
738
v0::Value key_pb;
736
- TFF_TRY ( SerializeTensorValue ( tensorflow::Tensor (key), &key_pb)) ;
739
+ *key_pb. mutable_array () = array_pb ;
737
740
OwnedValueId key_id = TFF_TRY (server_child_->CreateValue (key_pb));
738
741
OwnedValueId arg_id =
739
742
TFF_TRY (server_child_->CreateStruct ({server_val_child_id, key_id}));
0 commit comments