Skip to content

Commit 0c17626

Browse files
michaelreneercopybara-github
authored andcommitted
Remove TensorFlow dependency from the federating_executor module.
PiperOrigin-RevId: 746084493
1 parent 8963fda commit 0c17626

File tree

5 files changed

+606
-259
lines changed

5 files changed

+606
-259
lines changed

tensorflow_federated/cc/core/impl/executors/BUILD

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,6 @@ cc_library(
453453
":executor",
454454
":federated_intrinsics",
455455
":status_macros",
456-
":tensor_serialization",
457456
":threading",
458457
":value_validation",
459458
"//tensorflow_federated/proto/v0:executor_cc_proto",
@@ -466,29 +465,30 @@ cc_library(
466465
"@com_google_absl//absl/strings:string_view",
467466
"@com_google_absl//absl/synchronization",
468467
"@com_google_absl//absl/types:span",
468+
"@federated_language//federated_language/proto:array_cc_proto",
469469
"@federated_language//federated_language/proto:computation_cc_proto",
470-
"@org_tensorflow//tensorflow/core:framework",
470+
"@federated_language//federated_language/proto:data_type_cc_proto",
471471
],
472472
)
473473

474474
cc_test(
475475
name = "federating_executor_test",
476476
srcs = ["federating_executor_test.cc"],
477477
deps = [
478+
":array_shape_test_utils",
479+
":array_test_utils",
478480
":executor",
479481
":executor_test_base",
480482
":federating_executor",
481483
":mock_executor",
482484
":status_macros",
483-
":tensorflow_test_utils",
484485
":value_test_utils",
485486
"//tensorflow_federated/cc/testing:oss_test_main",
486487
"//tensorflow_federated/cc/testing:status_matchers",
487488
"//tensorflow_federated/proto/v0:executor_cc_proto",
488489
"@com_google_absl//absl/status",
489490
"@com_google_absl//absl/status:statusor",
490491
"@com_google_absl//absl/types:span",
491-
"@org_tensorflow//tensorflow/core:framework",
492492
"@org_tensorflow//tensorflow/core:tensorflow",
493493
],
494494
)
@@ -1000,7 +1000,7 @@ cc_library(
10001000
deps = [
10011001
":dataset_utils",
10021002
":status_macros",
1003-
":tensor_serialization",
1003+
":tensorflow_utils",
10041004
"//tensorflow_federated/cc/testing:protobuf_matchers",
10051005
"//tensorflow_federated/proto/v0:executor_cc_proto",
10061006
"@com_google_absl//absl/status",

tensorflow_federated/cc/core/impl/executors/federating_executor.cc

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ limitations under the License
3333
#include "absl/strings/string_view.h"
3434
#include "absl/synchronization/mutex.h"
3535
#include "absl/types/span.h"
36+
#include "federated_language/proto/array.pb.h"
3637
#include "federated_language/proto/computation.pb.h"
37-
#include "tensorflow/core/framework/tensor.h"
38+
#include "federated_language/proto/data_type.pb.h"
3839
#include "tensorflow_federated/cc/core/impl/executors/cardinalities.h"
3940
#include "tensorflow_federated/cc/core/impl/executors/executor.h"
4041
#include "tensorflow_federated/cc/core/impl/executors/federated_intrinsics.h"
4142
#include "tensorflow_federated/cc/core/impl/executors/status_macros.h"
42-
#include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h"
4343
#include "tensorflow_federated/cc/core/impl/executors/threading.h"
4444
#include "tensorflow_federated/cc/core/impl/executors/value_validation.h"
4545
#include "tensorflow_federated/proto/v0/executor.pb.h"
@@ -703,24 +703,23 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
703703
// these materialize calls don't block.
704704
v0::Value keys_for_client_pb =
705705
TFF_TRY(client_child_->Materialize(keys_child_id->ref()));
706-
tensorflow::Tensor keys_for_client_tensor =
707-
TFF_TRY(DeserializeTensorValue(keys_for_client_pb));
708-
if (keys_for_client_tensor.dtype() != tensorflow::DT_INT32) {
706+
federated_language::Array array_pb = keys_for_client_pb.array();
707+
if (array_pb.dtype() != federated_language::DataType::DT_INT32) {
709708
return absl::InvalidArgumentError(
710709
absl::StrCat("Expected int32_t key, found key of tensor dtype ",
711-
keys_for_client_tensor.dtype()));
710+
array_pb.dtype()));
712711
}
713-
if (keys_for_client_tensor.dims() != 1) {
712+
if (array_pb.shape().dim().size() != 1) {
714713
return absl::InvalidArgumentError(absl::StrCat(
715714
"Expected key tensor to be rank one, but found tensor of rank ",
716-
keys_for_client_tensor.dims()));
715+
array_pb.shape().dim().size()));
717716
}
718-
int64_t num_keys = keys_for_client_tensor.NumElements();
717+
int64_t num_keys = array_pb.shape().dim()[0];
719718
std::vector<int32_t> keys_for_client;
720719
keys_for_client.reserve(num_keys);
721-
auto keys_for_client_eigen = keys_for_client_tensor.flat<int32_t>();
720+
auto keys_for_client_eigen = array_pb.int32_list().value();
722721
for (int64_t i = 0; i < num_keys; i++) {
723-
int32_t key = keys_for_client_eigen(i);
722+
int32_t key = keys_for_client_eigen[i];
724723
keys_for_client.push_back(key);
725724
keys.all.insert(key);
726725
}
@@ -732,8 +731,12 @@ class FederatingExecutor : public ExecutorBase<ExecutorValue> {
732731
absl::StatusOr<OwnedValueId> SelectSliceForKey(int32_t key,
733732
ValueId server_val_child_id,
734733
ValueId select_fn_child_id) {
734+
federated_language::Array array_pb;
735+
array_pb.set_dtype(federated_language::DataType::DT_INT32);
736+
array_pb.mutable_shape();
737+
array_pb.mutable_int32_list()->add_value(key);
735738
v0::Value key_pb;
736-
TFF_TRY(SerializeTensorValue(tensorflow::Tensor(key), &key_pb));
739+
*key_pb.mutable_array() = array_pb;
737740
OwnedValueId key_id = TFF_TRY(server_child_->CreateValue(key_pb));
738741
OwnedValueId arg_id =
739742
TFF_TRY(server_child_->CreateStruct({server_val_child_id, key_id}));

0 commit comments

Comments
 (0)