Skip to content

Commit e1b0be3

Browse files
TensorFlow Federated Teamcopybara-github
authored andcommitted
Prevent MutableVectorData from being used for std::string.
PiperOrigin-RevId: 705972216
1 parent 8fa237b commit e1b0be3

File tree

4 files changed

+148
-1
lines changed

4 files changed

+148
-1
lines changed

tensorflow_federated/cc/core/impl/aggregation/core/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ cc_library(
7070
"agg_vector_iterator.h",
7171
"datatype.h",
7272
"input_tensor_list.h",
73+
"mutable_unowned_string_data.h",
7374
"mutable_vector_data.h",
7475
"tensor.h",
7576
"tensor_data.h",
@@ -566,6 +567,17 @@ cc_test(
566567
],
567568
)
568569

570+
cc_test(
571+
name = "mutable_unowned_string_data_test",
572+
srcs = ["mutable_unowned_string_data_test.cc"],
573+
deps = [
574+
":tensor",
575+
"//tensorflow_federated/cc/testing:oss_test_main",
576+
"//tensorflow_federated/cc/testing:status_matchers",
577+
"@com_google_absl//absl/strings:string_view",
578+
],
579+
)
580+
569581
cc_test(
570582
name = "vector_string_data_test",
571583
srcs = [
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_
18+
#define TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_
19+
20+
#include <cstddef>
21+
#include <memory>
22+
#include <string>
23+
#include <vector>
24+
25+
#include "absl/strings/string_view.h"
26+
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h"
27+
28+
namespace tensorflow_federated {
29+
namespace aggregation {
30+
31+
// MutableUnownedStringData implements TensorData by wrapping std::vector and
32+
// using it as backing storage for string_view objects. MutableUnownedStringData
33+
// can be mutated using std::vector methods. The MutableUnownedStringData object
34+
// does not own the string values. Use MutableStringData instead if you want
35+
// a TensorData object that owns the strings.
36+
class MutableUnownedStringData : public std::vector<absl::string_view>,
37+
public TensorData {
38+
public:
39+
// Derive constructors from the base vector class.
40+
using std::vector<absl::string_view>::vector;
41+
42+
~MutableUnownedStringData() override = default;
43+
44+
// Implementation of the base class methods.
45+
size_t byte_size() const override {
46+
return this->size() * sizeof(absl::string_view);
47+
}
48+
const void* data() const override {
49+
return this->std::vector<absl::string_view>::data();
50+
}
51+
52+
// Copy the MutableUnownedStringData into a string.
53+
std::string EncodeContent() {
54+
return std::string(reinterpret_cast<const char*>(this->data()),
55+
this->byte_size());
56+
}
57+
58+
// Create and return a new MutableUnownedStringData populated with the data
59+
// from content.
60+
static std::unique_ptr<MutableUnownedStringData> CreateFromEncodedContent(
61+
const std::string& content) {
62+
const absl::string_view* data =
63+
reinterpret_cast<const absl::string_view*>(content.data());
64+
return std::make_unique<MutableUnownedStringData>(
65+
data, data + content.size() / sizeof(absl::string_view));
66+
}
67+
};
68+
69+
} // namespace aggregation
70+
} // namespace tensorflow_federated
71+
72+
#endif // TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h"
18+
19+
#include <string>
20+
#include <vector>
21+
22+
#include "googlemock/include/gmock/gmock.h"
23+
#include "googletest/include/gtest/gtest.h"
24+
#include "absl/strings/string_view.h"
25+
#include "tensorflow_federated/cc/testing/status_matchers.h"
26+
27+
namespace tensorflow_federated {
28+
namespace aggregation {
29+
namespace {
30+
31+
TEST(MutableUnownedStringDataTest, MutableUnownedStringDataValid) {
32+
std::string string_1 = "foo";
33+
std::string string_2 = "bar";
34+
std::string string_3 = "baz";
35+
MutableUnownedStringData vector_data;
36+
vector_data.push_back(absl::string_view(string_1));
37+
vector_data.push_back(absl::string_view(string_2));
38+
vector_data.push_back(absl::string_view(string_3));
39+
EXPECT_THAT(vector_data.CheckValid<absl::string_view>(), IsOk());
40+
}
41+
42+
TEST(MutableUnownedStringDataTest, EncodeDecodeSucceeds) {
43+
std::string string_1 = "foo";
44+
std::string string_2 = "bar";
45+
std::string string_3 = "baz";
46+
MutableUnownedStringData vector_data;
47+
vector_data.push_back(absl::string_view(string_1));
48+
vector_data.push_back(absl::string_view(string_2));
49+
vector_data.push_back(absl::string_view(string_3));
50+
std::string encoded_vector_data = vector_data.EncodeContent();
51+
EXPECT_THAT(vector_data.CheckValid<absl::string_view>(), IsOk());
52+
auto decoded_vector_data =
53+
MutableUnownedStringData::CreateFromEncodedContent(encoded_vector_data);
54+
EXPECT_THAT(decoded_vector_data->CheckValid<absl::string_view>(), IsOk());
55+
EXPECT_EQ((*decoded_vector_data)[0], string_1);
56+
EXPECT_EQ((*decoded_vector_data)[1], string_2);
57+
EXPECT_EQ((*decoded_vector_data)[2], string_3);
58+
}
59+
60+
} // namespace
61+
} // namespace aggregation
62+
} // namespace tensorflow_federated

tensorflow_federated/cc/core/impl/aggregation/core/mutable_vector_data.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <cstddef>
2121
#include <memory>
2222
#include <string>
23+
#include <type_traits>
2324
#include <vector>
2425

2526
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h"
@@ -30,7 +31,7 @@ namespace aggregation {
3031
// MutableVectorData implements TensorData by wrapping std::vector and using it
3132
// as a backing storage. MutableVectorData can be mutated using std::vector
3233
// methods.
33-
template <typename T>
34+
template <typename T, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0>
3435
class MutableVectorData : public std::vector<T>, public TensorData {
3536
public:
3637
// Derive constructors from the base vector class.

0 commit comments

Comments
 (0)