aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-03 16:48:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-03 16:52:15 -0700
commit57626dd38a7867b76c44f3933e7810190174a2ee (patch)
treecc9d6833e5107fe6fb915ca28c597b53c4b0bb3e /tensorflow/c/c_api_test.cc
parent11d900686fd6aa65c26d75ac92e11b437ef4c48c (diff)
Allow specifying colocation constraints through TF_SetAttr*
Before this change, colocation constraint semantics were not well-defined when they were set using TF_ColocateWith and TF_SetAttrStringList and/or TF_SetAttrValueProto. One could get an exception if multiple methods were used. After this change all changes to colocation attribute (i.e. _class) are executed on TF_OperationDescription.colocation_constraints leading to consistent semantics. PiperOrigin-RevId: 164202666
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc200
1 files changed, 161 insertions, 39 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 1d191fc36d..0aa60fb45d 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -837,39 +837,172 @@ TEST(CAPI, ShapeInferenceError) {
TF_DeleteStatus(status);
}
-TEST(CAPI, ColocateWith) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
+void StringVectorToArrays(const std::vector<string>& v,
+ std::unique_ptr<const void* []>* ptrs,
+ std::unique_ptr<size_t[]>* lens) {
+ ptrs->reset(new const void*[v.size()]);
+ lens->reset(new size_t[v.size()]);
+ for (size_t i = 0; i < v.size(); ++i) {
+ (*ptrs)[i] = v[i].data();
+ (*lens)[i] = v[i].size();
+ }
+}
- TF_Operation* feed = Placeholder(graph, s);
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+class CApiColocationTest : public ::testing::Test {
+ protected:
+ CApiColocationTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
- TF_Operation* constant = ScalarConst(10, graph, s);
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ void SetUp() override {
+ feed1_ = Placeholder(graph_, s_, "feed1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add");
- TF_Output inputs[] = {{feed, 0}, {constant, 0}};
- TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
- TF_ColocateWith(desc, feed);
- TF_Operation* add = TF_FinishOperation(desc, s);
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ feed2_ = Placeholder(graph_, s_, "feed2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_AttrMetadata m =
- TF_OperationGetAttrMetadata(add, tensorflow::kColocationAttrName, s);
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(1, m.is_list);
- EXPECT_EQ(1, m.list_size);
- EXPECT_EQ(TF_ATTR_STRING, m.type);
- void* values[1];
- size_t lens[1];
- std::unique_ptr<char[]> storage(new char[m.total_size]);
- TF_OperationGetAttrStringList(add, tensorflow::kColocationAttrName, values,
- lens, 1, storage.get(), m.total_size, s);
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ("loc:@feed", string(static_cast<const char*>(values[0]), lens[0]));
+ constant_ = ScalarConst(10, graph_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ desc_ = TF_NewOperation(graph_, "AddN", "add");
+ TF_Output inputs[] = {{feed1_, 0}, {constant_, 0}};
+ TF_AddInputList(desc_, inputs, TF_ARRAYSIZE(inputs));
+ }
+
+ ~CApiColocationTest() override {
+ TF_DeleteGraph(graph_);
+ TF_DeleteStatus(s_);
+ }
+
+ void SetViaStringList(TF_OperationDescription* desc,
+ const std::vector<string>& list) {
+ std::unique_ptr<const void* []> list_ptrs;
+ std::unique_ptr<size_t[]> list_lens;
+ StringVectorToArrays(list, &list_ptrs, &list_lens);
+ TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(),
+ list_lens.get(), list.size());
+ }
+
+ void SetViaProto(TF_OperationDescription* desc,
+ const std::vector<string>& list) {
+ tensorflow::AttrValue attr;
+ for (const string& v : list) {
+ attr.mutable_list()->add_s(v);
+ }
+ string bytes;
+ attr.SerializeToString(&bytes);
+ TF_SetAttrValueProto(desc, tensorflow::kColocationAttrName, bytes.data(),
+ bytes.size(), s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ }
+
+ void VerifyCollocation(TF_Operation* op,
+ const std::vector<string>& expected) {
+ TF_AttrMetadata m =
+ TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
+ if (expected.empty()) {
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+ EXPECT_EQ(std::string("Operation has no attr named '_class'."),
+ std::string(TF_Message(s_)));
+ return;
+ }
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ EXPECT_EQ(1, m.is_list);
+ EXPECT_EQ(expected.size(), m.list_size);
+ EXPECT_EQ(TF_ATTR_STRING, m.type);
+ std::vector<void*> values(expected.size());
+ std::vector<size_t> lens(expected.size());
+ std::unique_ptr<char[]> storage(new char[m.total_size]);
+ TF_OperationGetAttrStringList(op, tensorflow::kColocationAttrName,
+ values.data(), lens.data(), expected.size(),
+ storage.get(), m.total_size, s_);
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ for (int i = 0; i < expected.size(); ++i) {
+ EXPECT_EQ(expected[i],
+ string(static_cast<const char*>(values[i]), lens[i]));
+ }
+ }
+
+ void FinishAndVerify(TF_OperationDescription* desc,
+ const std::vector<string>& expected) {
+ TF_Operation* op = TF_FinishOperation(desc_, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ VerifyCollocation(op, expected);
+ }
+
+ TF_Status* s_;
+ TF_Graph* graph_;
+ TF_Operation* feed1_;
+ TF_Operation* feed2_;
+ TF_Operation* constant_;
+ TF_OperationDescription* desc_;
+};
+
+TEST_F(CApiColocationTest, ColocateWith) {
+ TF_ColocateWith(desc_, feed1_);
+ FinishAndVerify(desc_, {"loc:@feed1"});
+}
+
+TEST_F(CApiColocationTest, StringList) {
+ SetViaStringList(desc_, {"loc:@feed1"});
+ FinishAndVerify(desc_, {"loc:@feed1"});
+}
+
+TEST_F(CApiColocationTest, Proto) {
+ SetViaProto(desc_, {"loc:@feed1"});
+ FinishAndVerify(desc_, {"loc:@feed1"});
+}
+
+TEST_F(CApiColocationTest, ColocateWith_StringList) {
+ TF_ColocateWith(desc_, feed1_);
+ SetViaStringList(desc_, {"loc:@feed2"});
+ FinishAndVerify(desc_, {"loc:@feed2"});
+}
+
+TEST_F(CApiColocationTest, ColocateWith_Proto) {
+ TF_ColocateWith(desc_, feed1_);
+ SetViaProto(desc_, {"loc:@feed2"});
+ FinishAndVerify(desc_, {"loc:@feed2"});
+}
+
+TEST_F(CApiColocationTest, StringList_ColocateWith) {
+ SetViaStringList(desc_, {"loc:@feed2"});
+ TF_ColocateWith(desc_, feed1_);
+ FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
+}
+
+TEST_F(CApiColocationTest, Proto_ColocateWith) {
+ SetViaProto(desc_, {"loc:@feed2"});
+ TF_ColocateWith(desc_, feed1_);
+ FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
+}
+
+TEST_F(CApiColocationTest, ColocateWith_ColocateWith) {
+ TF_ColocateWith(desc_, feed1_);
+ TF_ColocateWith(desc_, feed2_);
+ FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
+}
+
+TEST_F(CApiColocationTest, Proto_StringList) {
+ SetViaProto(desc_, {"loc:@feed1"});
+ SetViaStringList(desc_, {"loc:@feed2"});
+ FinishAndVerify(desc_, {"loc:@feed2"});
+}
+
+TEST_F(CApiColocationTest, StringList_Proto) {
+ SetViaStringList(desc_, {"loc:@feed1"});
+ SetViaProto(desc_, {"loc:@feed2"});
+ FinishAndVerify(desc_, {"loc:@feed2"});
+}
+
+TEST_F(CApiColocationTest, ClearViaStringList) {
+ TF_ColocateWith(desc_, feed1_);
+ SetViaStringList(desc_, {});
+ FinishAndVerify(desc_, {});
+}
+
+TEST_F(CApiColocationTest, ClearViaProto) {
+ TF_ColocateWith(desc_, feed1_);
+ SetViaProto(desc_, {});
+ FinishAndVerify(desc_, {});
}
TEST(CAPI, SavedModel) {
@@ -1245,17 +1378,6 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
-void StringVectorToArrays(const std::vector<string>& v,
- std::unique_ptr<const void* []>* ptrs,
- std::unique_ptr<size_t[]>* lens) {
- ptrs->reset(new const void*[v.size()]);
- lens->reset(new size_t[v.size()]);
- for (size_t i = 0; i < v.size(); ++i) {
- (*ptrs)[i] = v[i].data();
- (*lens)[i] = v[i].size();
- }
-}
-
// REGISTER_OP for CApiTestAttributesTest test cases.
// Registers two ops, each with a single attribute called 'v'.
// The attribute in one op will have a type 'type', the other