diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-20 14:49:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 14:52:53 -0700 |
commit | 424f0556ad8acde8f912a67e46421957a71dcef2 (patch) | |
tree | 053a1ec1c0993d4d6778ad3dda79dfd9425a0d7f /tensorflow/core/framework/node_def_util_test.cc | |
parent | 800cc654de0bb99c5753fc4ab26a9293547ee0b3 (diff) |
[tf.data] Some vectorization cleanup
PiperOrigin-RevId: 213886813
Diffstat (limited to 'tensorflow/core/framework/node_def_util_test.cc')
-rw-r--r-- | tensorflow/core/framework/node_def_util_test.cc | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 74cc594863..d9d437024a 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) { "Illegal op input name 'a:00"); } +TEST(InputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_FLOAT); + EXPECT_EQ(types[1], DT_INT32); + + DataType type; + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_FLOAT); + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_INT32); + EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok()); +} + +TEST(OutputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_STRING); + EXPECT_EQ(types[1], DT_BOOL); + + DataType type; + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_STRING); + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_BOOL); + EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok()); +} + TEST(NameRangesForNodeTest, Simple) { const OpDef op_def = ToOpDef(OpDefBuilder("Simple") .Input("a: float") |