aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/attr_value_util_test.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-01-11 08:53:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-11 09:06:12 -0800
commitd4a9d91bc68ffa9f0148ae6fe344b7d3e3de7221 (patch)
treead5ade6fb6e18725559457872451657608d2689c /tensorflow/core/framework/attr_value_util_test.cc
parent024507b55a488a4432a8f74f399c63a4f3debe24 (diff)
Add support for list(func) AttrValues.
Change: 144211572
Diffstat (limited to 'tensorflow/core/framework/attr_value_util_test.cc')
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc29
1 files changed, 22 insertions, 7 deletions
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
index 8679044f76..c14ea9b322 100644
--- a/tensorflow/core/framework/attr_value_util_test.cc
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include <vector>
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -35,19 +36,32 @@ AttrValue P(const string& p) {
}
AttrValue F(const string& name,
- std::vector<std::pair<string, AttrValue> > pairs) {
+ std::vector<std::pair<string, AttrValue>> pairs) {
AttrValue ret;
ret.mutable_func()->set_name(name);
ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end());
return ret;
}
+AttrValue Fs(
+ std::vector<std::pair<string, std::vector<std::pair<string, AttrValue>>>>
+ funcs) {
+ AttrValue ret;
+ for (const auto& func : funcs) {
+ NameAttrList* entry = ret.mutable_list()->add_func();
+ entry->set_name(func.first);
+ entry->mutable_attr()->insert(func.second.begin(), func.second.end());
+ }
+ return ret;
+}
+
TEST(AttrValueUtil, HasType) {
// OK
EXPECT_TRUE(AttrValueHasType(V(123), "int").ok());
EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok());
EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok());
EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok());
+ EXPECT_TRUE(AttrValueHasType(Fs({{"f", {}}, {"g", {}}}), "list(func)").ok());
// not OK.
EXPECT_FALSE(AttrValueHasType(V(123), "func").ok());
@@ -76,7 +90,7 @@ TEST(AttrValueUtil, Basic) {
{"transpose_a", V(false)},
{"transpose_b", V(true)},
{"use_cublas", V(true)}});
- TF_CHECK_OK(AttrValueHasType(v, "func"));
+ TF_EXPECT_OK(AttrValueHasType(v, "func"));
EXPECT_TRUE(HasPlaceHolder(v));
EXPECT_EQ(
@@ -94,7 +108,7 @@ TEST(AttrValueUtil, Shaped) {
auto v =
F("OpRequiresShape", {{"shape_full", V(TensorShape({1, 0}))},
{"shape_part", V(PartialTensorShape({-1, 1, 0}))}});
- TF_CHECK_OK(AttrValueHasType(v, "func"));
+ TF_EXPECT_OK(AttrValueHasType(v, "func"));
EXPECT_FALSE(HasPlaceHolder(v));
EXPECT_EQ(SummarizeAttrValue(v),
@@ -102,20 +116,21 @@ TEST(AttrValueUtil, Shaped) {
}
TEST(AttrValueUtil, DeepAttr) {
- auto v = F("f", {{"T", P("T")}});
- TF_CHECK_OK(AttrValueHasType(v, "func"));
+ auto v = Fs({{"f", {{"T", P("T")}}}, {"g", {{"T", P("T")}}}});
+ TF_EXPECT_OK(AttrValueHasType(v, "list(func)"));
EXPECT_TRUE(HasPlaceHolder(v));
for (int i = 0; i < 3; ++i) {
v = F("f", {{"T", P("T")}, {"F", v}});
EXPECT_TRUE(HasPlaceHolder(v));
}
- EXPECT_EQ(SummarizeAttrValue(v), "f[F=f[F=f[F=f[T=$T], T=$T], T=$T], T=$T]");
+ EXPECT_EQ(SummarizeAttrValue(v),
+ "f[F=f[F=f[F=[f[T=$T], g[T=$T]], T=$T], T=$T], T=$T]");
SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v);
EXPECT_TRUE(!HasPlaceHolder(v));
EXPECT_EQ(SummarizeAttrValue(v),
- "f[F=f[F=f[F=f[T=x[]], T=x[]], T=x[]], T=x[]]");
+ "f[F=f[F=f[F=[f[T=x[]], g[T=x[]]], T=x[]], T=x[]], T=x[]]");
}
} // namespace tensorflow