diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-20 14:49:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 14:52:53 -0700 |
commit | 424f0556ad8acde8f912a67e46421957a71dcef2 (patch) | |
tree | 053a1ec1c0993d4d6778ad3dda79dfd9425a0d7f /tensorflow/core/framework | |
parent | 800cc654de0bb99c5753fc4ab26a9293547ee0b3 (diff) |
[tf.data] Some vectorization cleanup
PiperOrigin-RevId: 213886813
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/node_def_util.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.h | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util_test.cc | 42 |
3 files changed, 55 insertions, 3 deletions
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 42ec315a32..43ac1d0ada 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, node_def.name()); } +Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs) { + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); + } + return Status::OK(); +} + Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, int output_port, DataType* output_type) { DataTypeVector output_types; @@ -397,9 +405,7 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, DataTypeVector* inputs, DataTypeVector* outputs) { - for (const auto& arg : op_def.input_arg()) { - TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); - } + TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs)); return OutputTypesForNode(node_def, op_def, outputs); } diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 7528d3d306..187bfa2c88 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -249,6 +249,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name); // REQUIRES: ValidateOpDef(op_def).ok() Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, int input_port, DataType* input_type); +// Computes the input types for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs); // Computes the output type for a specific node output. // REQUIRES: ValidateOpDef(op_def).ok() Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 74cc594863..d9d437024a 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) { "Illegal op input name 'a:00"); } +TEST(InputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_FLOAT); + EXPECT_EQ(types[1], DT_INT32); + + DataType type; + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_FLOAT); + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_INT32); + EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok()); +} + +TEST(OutputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_STRING); + EXPECT_EQ(types[1], DT_BOOL); + + DataType type; + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_STRING); + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_BOOL); + EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok()); +} + TEST(NameRangesForNodeTest, Simple) { const OpDef op_def = ToOpDef(OpDefBuilder("Simple") .Input("a: float") |