aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-20 15:41:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 15:44:47 -0700
commitf4f37efdc95adc4b2c6235479b89ddfbaf4b3eed (patch)
treede67f8bd9d72716567eaa989fde113b18fb21f10 /tensorflow/core/grappler
parent9151e6139881033253fb671ea615b8b8f2529380 (diff)
Update Grappler to use existing functions for retrieving a node's
name and position. PiperOrigin-RevId: 205465354
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc18
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc38
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt239
3 files changed, 285 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 83a8326e79..231c7c63be 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -496,18 +496,11 @@ class SymbolicShapeRefiner {
"supported.");
}
+ // It is guaranteed that output_tensors does not contain any control
+ // inputs, so port_id >= 0.
string out_tensor = out_arg.output_tensors[0];
- auto out_tensor_pieces = str_util::Split(out_tensor, ",");
- string node_name = out_tensor_pieces[0];
int port_id;
-
- // Check if port_id was included in out_tensor
- if (out_tensor_pieces.size() <= 1) {
- port_id = 0;
- } else if (!strings::safe_strto32(out_tensor_pieces[1], &port_id)) {
- return errors::FailedPrecondition(
- "Failed string to integer conversion for ", out_tensor_pieces[1]);
- }
+ string node_name = ParseNodeName(out_tensor, &port_id);
const NodeDef* retnode = gv.GetNode(node_name);
if (retnode == nullptr) {
@@ -516,6 +509,11 @@ class SymbolicShapeRefiner {
}
auto output_properties = gp.GetOutputProperties(retnode->name());
+ if (port_id >= output_properties.size()) {
+ return errors::InvalidArgument(
+ out_tensor, " has invalid position ", port_id,
+ " (output_properties.size() = ", output_properties.size(), ").");
+ }
auto const& outprop = output_properties[port_id];
const TensorShapeProto& shape = outprop.shape();
ShapeHandle out;
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 1be19d291a..5acfb56b05 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -887,6 +887,44 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
EXPECT_EQ(8, in_prop3.shape().dim(3).size());
}
+TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(noinline=True)
+ def MyFunc():
+ @function.Defun(*[tf.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[tf.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ i = tf.constant(10)
+ return functional_ops.While([i, 0.], Cond, Body)
+
+ with tf.Graph().as_default():
+ z = MyFunc()
+ */
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_functional_while.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us");
+ EXPECT_EQ(2, out_props.size());
+
+ const OpInfo::TensorProperties& out_prop0 = out_props[0];
+ EXPECT_EQ(DT_INT32, out_prop0.dtype());
+ EXPECT_FALSE(out_prop0.shape().unknown_rank());
+
+ const OpInfo::TensorProperties& out_prop1 = out_props[1];
+ EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
+ EXPECT_FALSE(out_prop1.shape().unknown_rank());
+}
+
TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
GrapplerItem item;
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
new file mode 100644
index 0000000000..c94ee2f227
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
@@ -0,0 +1,239 @@
+node {
+ name: "MyFunc_AenMyWWx1Us"
+ op: "MyFunc_AenMyWWx1Us"
+}
+library {
+ function {
+ signature {
+ name: "MyFunc_AenMyWWx1Us"
+ output_arg {
+ name: "while"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "while_0"
+ type: DT_FLOAT
+ }
+ is_stateful: true
+ }
+ node_def {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 10
+ }
+ }
+ }
+ }
+ node_def {
+ name: "While/input_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "While"
+ op: "While"
+ input: "Const:output:0"
+ input: "While/input_1:output:0"
+ attr {
+ key: "T"
+ value {
+ list {
+ type: DT_INT32
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ key: "body"
+ value {
+ func {
+ name: "Body_8GOMGeZeK5c"
+ }
+ }
+ }
+ attr {
+ key: "cond"
+ value {
+ func {
+ name: "Cond_Xf5ttAHgUCg"
+ }
+ }
+ }
+ }
+ ret {
+ key: "while"
+ value: "While:output:0"
+ }
+ ret {
+ key: "while_0"
+ value: "While:output:1"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+ function {
+ signature {
+ name: "Body_8GOMGeZeK5c"
+ input_arg {
+ name: "n"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "sub"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "sub/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "sub_0"
+ op: "Sub"
+ input: "n"
+ input: "sub/y:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node_def {
+ name: "add_0"
+ op: "Add"
+ input: "x"
+ input: "n"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "add"
+ value: "add_0:z:0"
+ }
+ ret {
+ key: "sub"
+ value: "sub_0:z:0"
+ }
+ }
+ function {
+ signature {
+ name: "Cond_Xf5ttAHgUCg"
+ input_arg {
+ name: "n"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "unused_x"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "greater"
+ type: DT_BOOL
+ }
+ }
+ node_def {
+ name: "Greater/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "Greater"
+ op: "Greater"
+ input: "n"
+ input: "Greater/y:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "greater"
+ value: "Greater:z:0"
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}