aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-20 14:49:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 14:52:53 -0700
commit424f0556ad8acde8f912a67e46421957a71dcef2 (patch)
tree053a1ec1c0993d4d6778ad3dda79dfd9425a0d7f /tensorflow/core/framework
parent800cc654de0bb99c5753fc4ab26a9293547ee0b3 (diff)
[tf.data] Some vectorization cleanup
PiperOrigin-RevId: 213886813
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/node_def_util.cc12
-rw-r--r--tensorflow/core/framework/node_def_util.h4
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc42
3 files changed, 55 insertions, 3 deletions
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 42ec315a32..43ac1d0ada 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
node_def.name());
}
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ return Status::OK();
+}
+
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int output_port, DataType* output_type) {
DataTypeVector output_types;
@@ -397,9 +405,7 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs) {
- for (const auto& arg : op_def.input_arg()) {
- TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
- }
+ TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
return OutputTypesForNode(node_def, op_def, outputs);
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 7528d3d306..187bfa2c88 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -249,6 +249,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int input_port, DataType* input_type);
+// Computes the input types for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594863..d9d437024a 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) {
"Illegal op input name 'a:00");
}
+TEST(InputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_FLOAT);
+ EXPECT_EQ(types[1], DT_INT32);
+
+ DataType type;
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_FLOAT);
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_INT32);
+ EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_STRING);
+ EXPECT_EQ(types[1], DT_BOOL);
+
+ DataType type;
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_STRING);
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_BOOL);
+ EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
TEST(NameRangesForNodeTest, Simple) {
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
.Input("a: float")