diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-03 16:48:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-03 16:52:15 -0700 |
commit | 57626dd38a7867b76c44f3933e7810190174a2ee (patch) | |
tree | cc9d6833e5107fe6fb915ca28c597b53c4b0bb3e /tensorflow/c/c_api_test.cc | |
parent | 11d900686fd6aa65c26d75ac92e11b437ef4c48c (diff) |
Allow specifying colocation constraints through TF_SetAttr*
Before this change, colocation constraint semantics were not
well-defined when they were set using TF_ColocateWith and
TF_SetAttrStringList and/or TF_SetAttrValueProto. One could get
an exception if multiple methods were used.
After this change all changes to colocation attribute (i.e. _class)
are executed on TF_OperationDescription.colocation_constraints leading
to consistent semantics.
PiperOrigin-RevId: 164202666
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r-- | tensorflow/c/c_api_test.cc | 200 |
1 files changed, 161 insertions, 39 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 1d191fc36d..0aa60fb45d 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -837,39 +837,172 @@ TEST(CAPI, ShapeInferenceError) { TF_DeleteStatus(status); } -TEST(CAPI, ColocateWith) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); +void StringVectorToArrays(const std::vector<string>& v, + std::unique_ptr<const void* []>* ptrs, + std::unique_ptr<size_t[]>* lens) { + ptrs->reset(new const void*[v.size()]); + lens->reset(new size_t[v.size()]); + for (size_t i = 0; i < v.size(); ++i) { + (*ptrs)[i] = v[i].data(); + (*lens)[i] = v[i].size(); + } +} - TF_Operation* feed = Placeholder(graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); +class CApiColocationTest : public ::testing::Test { + protected: + CApiColocationTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {} - TF_Operation* constant = ScalarConst(10, graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + void SetUp() override { + feed1_ = Placeholder(graph_, s_, "feed1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add"); - TF_Output inputs[] = {{feed, 0}, {constant, 0}}; - TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); - TF_ColocateWith(desc, feed); - TF_Operation* add = TF_FinishOperation(desc, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + feed2_ = Placeholder(graph_, s_, "feed2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_AttrMetadata m = - TF_OperationGetAttrMetadata(add, tensorflow::kColocationAttrName, s); - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(1, m.is_list); - EXPECT_EQ(1, m.list_size); - EXPECT_EQ(TF_ATTR_STRING, m.type); - void* values[1]; - size_t lens[1]; - std::unique_ptr<char[]> storage(new char[m.total_size]); - TF_OperationGetAttrStringList(add, tensorflow::kColocationAttrName, values, - lens, 1, storage.get(), m.total_size, s); - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ("loc:@feed", string(static_cast<const char*>(values[0]), lens[0])); + constant_ = ScalarConst(10, graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + desc_ = TF_NewOperation(graph_, "AddN", "add"); + TF_Output inputs[] = {{feed1_, 0}, {constant_, 0}}; + TF_AddInputList(desc_, inputs, TF_ARRAYSIZE(inputs)); + } + + ~CApiColocationTest() override { + TF_DeleteGraph(graph_); + TF_DeleteStatus(s_); + } + + void SetViaStringList(TF_OperationDescription* desc, + const std::vector<string>& list) { + std::unique_ptr<const void* []> list_ptrs; + std::unique_ptr<size_t[]> list_lens; + StringVectorToArrays(list, &list_ptrs, &list_lens); + TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(), + list_lens.get(), list.size()); + } + + void SetViaProto(TF_OperationDescription* desc, + const std::vector<string>& list) { + tensorflow::AttrValue attr; + for (const string& v : list) { + attr.mutable_list()->add_s(v); + } + string bytes; + attr.SerializeToString(&bytes); + TF_SetAttrValueProto(desc, tensorflow::kColocationAttrName, bytes.data(), + bytes.size(), s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void VerifyCollocation(TF_Operation* op, + const std::vector<string>& expected) { + TF_AttrMetadata m = + TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); + if (expected.empty()) { + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); + EXPECT_EQ(std::string("Operation has no attr named '_class'."), + std::string(TF_Message(s_))); + return; + } + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + EXPECT_EQ(1, m.is_list); + EXPECT_EQ(expected.size(), m.list_size); + EXPECT_EQ(TF_ATTR_STRING, m.type); + std::vector<void*> values(expected.size()); + std::vector<size_t> lens(expected.size()); + std::unique_ptr<char[]> storage(new char[m.total_size]); + TF_OperationGetAttrStringList(op, tensorflow::kColocationAttrName, + values.data(), lens.data(), expected.size(), + storage.get(), m.total_size, s_); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + for (int i = 0; i < expected.size(); ++i) { + EXPECT_EQ(expected[i], + string(static_cast<const char*>(values[i]), lens[i])); + } + } + + void FinishAndVerify(TF_OperationDescription* desc, + const std::vector<string>& expected) { + TF_Operation* op = TF_FinishOperation(desc_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + VerifyCollocation(op, expected); + } + + TF_Status* s_; + TF_Graph* graph_; + TF_Operation* feed1_; + TF_Operation* feed2_; + TF_Operation* constant_; + TF_OperationDescription* desc_; +}; + +TEST_F(CApiColocationTest, ColocateWith) { + TF_ColocateWith(desc_, feed1_); + FinishAndVerify(desc_, {"loc:@feed1"}); +} + +TEST_F(CApiColocationTest, StringList) { + SetViaStringList(desc_, {"loc:@feed1"}); + FinishAndVerify(desc_, {"loc:@feed1"}); +} + +TEST_F(CApiColocationTest, Proto) { + SetViaProto(desc_, {"loc:@feed1"}); + FinishAndVerify(desc_, {"loc:@feed1"}); +} + +TEST_F(CApiColocationTest, ColocateWith_StringList) { + TF_ColocateWith(desc_, feed1_); + SetViaStringList(desc_, {"loc:@feed2"}); + FinishAndVerify(desc_, {"loc:@feed2"}); +} + +TEST_F(CApiColocationTest, ColocateWith_Proto) { + TF_ColocateWith(desc_, feed1_); + SetViaProto(desc_, {"loc:@feed2"}); + FinishAndVerify(desc_, {"loc:@feed2"}); +} + +TEST_F(CApiColocationTest, StringList_ColocateWith) { + SetViaStringList(desc_, {"loc:@feed2"}); + TF_ColocateWith(desc_, feed1_); + FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"}); +} + +TEST_F(CApiColocationTest, Proto_ColocateWith) { + SetViaProto(desc_, {"loc:@feed2"}); + TF_ColocateWith(desc_, feed1_); + FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"}); +} + +TEST_F(CApiColocationTest, ColocateWith_ColocateWith) { + TF_ColocateWith(desc_, feed1_); + TF_ColocateWith(desc_, feed2_); + FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"}); +} + +TEST_F(CApiColocationTest, Proto_StringList) { + SetViaProto(desc_, {"loc:@feed1"}); + SetViaStringList(desc_, {"loc:@feed2"}); + FinishAndVerify(desc_, {"loc:@feed2"}); +} + +TEST_F(CApiColocationTest, StringList_Proto) { + SetViaStringList(desc_, {"loc:@feed1"}); + SetViaProto(desc_, {"loc:@feed2"}); + FinishAndVerify(desc_, {"loc:@feed2"}); +} + +TEST_F(CApiColocationTest, ClearViaStringList) { + TF_ColocateWith(desc_, feed1_); + SetViaStringList(desc_, {}); + FinishAndVerify(desc_, {}); +} + +TEST_F(CApiColocationTest, ClearViaProto) { + TF_ColocateWith(desc_, feed1_); + SetViaProto(desc_, {}); + FinishAndVerify(desc_, {}); } TEST(CAPI, SavedModel) { @@ -1245,17 +1378,6 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } -void StringVectorToArrays(const std::vector<string>& v, - std::unique_ptr<const void* []>* ptrs, - std::unique_ptr<size_t[]>* lens) { - ptrs->reset(new const void*[v.size()]); - lens->reset(new size_t[v.size()]); - for (size_t i = 0; i < v.size(); ++i) { - (*ptrs)[i] = v[i].data(); - (*lens)[i] = v[i].size(); - } -} - // REGISTER_OP for CApiTestAttributesTest test cases. // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other |