aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/attr_value_util_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/attr_value_util_test.cc')
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc91
1 files changed, 91 insertions, 0 deletions
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
new file mode 100644
index 0000000000..bdfbf1707a
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -0,0 +1,91 @@
+#include "tensorflow/core/framework/attr_value_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// A few helpers to construct AttrValue protos.
+template <typename T>
+AttrValue V(T value) {
+ AttrValue ret;
+ SetAttrValue(value, &ret);
+ return ret;
+}
+
+AttrValue P(const string& p) {
+ AttrValue ret;
+ ret.set_placeholder(p);
+ return ret;
+}
+
+AttrValue F(const string& name,
+ std::vector<std::pair<string, AttrValue> > pairs) {
+ AttrValue ret;
+ ret.mutable_func()->set_name(name);
+ ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end());
+ return ret;
+}
+
+TEST(AttrValueUtil, HasType) {
+ // OK
+ EXPECT_TRUE(AttrValueHasType(V(123), "int").ok());
+ EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok());
+ EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok());
+ EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok());
+
+ // not OK.
+ EXPECT_FALSE(AttrValueHasType(V(123), "func").ok());
+ EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok());
+ EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok());
+ EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok());
+ EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok());
+}
+
+SubstituteFunc ReplaceTWith(const AttrValue& val) {
+ return [val](const string& placeholder, AttrValue* target) {
+ if (placeholder == "T") {
+ *target = val;
+ return true;
+ } else {
+ return false;
+ }
+ };
+}
+
+TEST(AttrValueUtil, Basic) {
+ auto v = F("MatMul", {{"dtype", P("T")},
+ {"transpose_a", V(false)},
+ {"transpose_b", V(true)},
+ {"use_cublas", V(true)}});
+ TF_CHECK_OK(AttrValueHasType(v, "func"));
+ EXPECT_TRUE(HasPlaceHolder(v));
+
+ EXPECT_EQ(
+ SummarizeAttrValue(v),
+ "MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]");
+
+ SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v);
+ EXPECT_TRUE(!HasPlaceHolder(v));
+ EXPECT_EQ(SummarizeAttrValue(v),
+ "MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, "
+ "use_cublas=true]");
+}
+
+TEST(AttrValueUtil, DeepAttr) {
+ auto v = F("f", {{"T", P("T")}});
+ TF_CHECK_OK(AttrValueHasType(v, "func"));
+ EXPECT_TRUE(HasPlaceHolder(v));
+
+ for (int i = 0; i < 3; ++i) {
+ v = F("f", {{"T", P("T")}, {"F", v}});
+ EXPECT_TRUE(HasPlaceHolder(v));
+ }
+ EXPECT_EQ(SummarizeAttrValue(v), "f[F=f[F=f[F=f[T=$T], T=$T], T=$T], T=$T]");
+
+ SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v);
+ EXPECT_TRUE(!HasPlaceHolder(v));
+ EXPECT_EQ(SummarizeAttrValue(v),
+ "f[F=f[F=f[F=f[T=x[]], T=x[]], T=x[]], T=x[]]");
+}
+
+} // namespace tensorflow