aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/node_def_util_test.cc
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-20 14:49:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 14:52:53 -0700
commit424f0556ad8acde8f912a67e46421957a71dcef2 (patch)
tree053a1ec1c0993d4d6778ad3dda79dfd9425a0d7f /tensorflow/core/framework/node_def_util_test.cc
parent800cc654de0bb99c5753fc4ab26a9293547ee0b3 (diff)
[tf.data] Some vectorization cleanup
PiperOrigin-RevId: 213886813
Diffstat (limited to 'tensorflow/core/framework/node_def_util_test.cc')
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594863..d9d437024a 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) {
"Illegal op input name 'a:00");
}
+TEST(InputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_FLOAT);
+ EXPECT_EQ(types[1], DT_INT32);
+
+ DataType type;
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_FLOAT);
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_INT32);
+ EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_STRING);
+ EXPECT_EQ(types[1], DT_BOOL);
+
+ DataType type;
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_STRING);
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_BOOL);
+ EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
TEST(NameRangesForNodeTest, Simple) {
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
.Input("a: float")