diff options
author | 2017-01-11 08:53:21 -0800 | |
---|---|---|
committer | 2017-01-11 09:06:12 -0800 | |
commit | d4a9d91bc68ffa9f0148ae6fe344b7d3e3de7221 (patch) | |
tree | ad5ade6fb6e18725559457872451657608d2689c /tensorflow/core/framework/attr_value_util_test.cc | |
parent | 024507b55a488a4432a8f74f399c63a4f3debe24 (diff) |
Add support for list(func) AttrValues.
Change: 144211572
Diffstat (limited to 'tensorflow/core/framework/attr_value_util_test.cc')
-rw-r--r-- | tensorflow/core/framework/attr_value_util_test.cc | 29 |
1 files changed, 22 insertions, 7 deletions
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc index 8679044f76..c14ea9b322 100644 --- a/tensorflow/core/framework/attr_value_util_test.cc +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value_util.h" #include <vector> +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -35,19 +36,32 @@ AttrValue P(const string& p) { } AttrValue F(const string& name, - std::vector<std::pair<string, AttrValue> > pairs) { + std::vector<std::pair<string, AttrValue>> pairs) { AttrValue ret; ret.mutable_func()->set_name(name); ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end()); return ret; } +AttrValue Fs( + std::vector<std::pair<string, std::vector<std::pair<string, AttrValue>>>> + funcs) { + AttrValue ret; + for (const auto& func : funcs) { + NameAttrList* entry = ret.mutable_list()->add_func(); + entry->set_name(func.first); + entry->mutable_attr()->insert(func.second.begin(), func.second.end()); + } + return ret; +} + TEST(AttrValueUtil, HasType) { // OK EXPECT_TRUE(AttrValueHasType(V(123), "int").ok()); EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok()); EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok()); EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok()); + EXPECT_TRUE(AttrValueHasType(Fs({{"f", {}}, {"g", {}}}), "list(func)").ok()); // not OK. EXPECT_FALSE(AttrValueHasType(V(123), "func").ok()); @@ -76,7 +90,7 @@ TEST(AttrValueUtil, Basic) { {"transpose_a", V(false)}, {"transpose_b", V(true)}, {"use_cublas", V(true)}}); - TF_CHECK_OK(AttrValueHasType(v, "func")); + TF_EXPECT_OK(AttrValueHasType(v, "func")); EXPECT_TRUE(HasPlaceHolder(v)); EXPECT_EQ( @@ -94,7 +108,7 @@ TEST(AttrValueUtil, Shaped) { auto v = F("OpRequiresShape", {{"shape_full", V(TensorShape({1, 0}))}, {"shape_part", V(PartialTensorShape({-1, 1, 0}))}}); - TF_CHECK_OK(AttrValueHasType(v, "func")); + TF_EXPECT_OK(AttrValueHasType(v, "func")); EXPECT_FALSE(HasPlaceHolder(v)); EXPECT_EQ(SummarizeAttrValue(v), @@ -102,20 +116,21 @@ TEST(AttrValueUtil, Shaped) { } TEST(AttrValueUtil, DeepAttr) { - auto v = F("f", {{"T", P("T")}}); - TF_CHECK_OK(AttrValueHasType(v, "func")); + auto v = Fs({{"f", {{"T", P("T")}}}, {"g", {{"T", P("T")}}}}); + TF_EXPECT_OK(AttrValueHasType(v, "list(func)")); EXPECT_TRUE(HasPlaceHolder(v)); for (int i = 0; i < 3; ++i) { v = F("f", {{"T", P("T")}, {"F", v}}); EXPECT_TRUE(HasPlaceHolder(v)); } - EXPECT_EQ(SummarizeAttrValue(v), "f[F=f[F=f[F=f[T=$T], T=$T], T=$T], T=$T]"); + EXPECT_EQ(SummarizeAttrValue(v), + "f[F=f[F=f[F=[f[T=$T], g[T=$T]], T=$T], T=$T], T=$T]"); SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v); EXPECT_TRUE(!HasPlaceHolder(v)); EXPECT_EQ(SummarizeAttrValue(v), - "f[F=f[F=f[F=f[T=x[]], T=x[]], T=x[]], T=x[]]"); + "f[F=f[F=f[F=[f[T=x[]], g[T=x[]]], T=x[]], T=x[]], T=x[]]"); } } // namespace tensorflow |