diff options
author | 2018-07-20 15:41:04 -0700 | |
---|---|---|
committer | 2018-07-20 15:44:47 -0700 | |
commit | f4f37efdc95adc4b2c6235479b89ddfbaf4b3eed (patch) | |
tree | de67f8bd9d72716567eaa989fde113b18fb21f10 /tensorflow/core/grappler | |
parent | 9151e6139881033253fb671ea615b8b8f2529380 (diff) |
Update Grappler to use existing functions for retrieving a node's
name and position.
PiperOrigin-RevId: 205465354
Diffstat (limited to 'tensorflow/core/grappler')
3 files changed, 285 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 83a8326e79..231c7c63be 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -496,18 +496,11 @@ class SymbolicShapeRefiner { "supported."); } + // It is guaranteed that output_tensors does not contain any control + // inputs, so port_id >= 0. string out_tensor = out_arg.output_tensors[0]; - auto out_tensor_pieces = str_util::Split(out_tensor, ","); - string node_name = out_tensor_pieces[0]; int port_id; - - // Check if port_id was included in out_tensor - if (out_tensor_pieces.size() <= 1) { - port_id = 0; - } else if (!strings::safe_strto32(out_tensor_pieces[1], &port_id)) { - return errors::FailedPrecondition( - "Failed string to integer conversion for ", out_tensor_pieces[1]); - } + string node_name = ParseNodeName(out_tensor, &port_id); const NodeDef* retnode = gv.GetNode(node_name); if (retnode == nullptr) { @@ -516,6 +509,11 @@ class SymbolicShapeRefiner { } auto output_properties = gp.GetOutputProperties(retnode->name()); + if (port_id >= output_properties.size()) { + return errors::InvalidArgument( + out_tensor, " has invalid position ", port_id, + " (output_properties.size() = ", output_properties.size(), ")."); + } auto const& outprop = output_properties[port_id]; const TensorShapeProto& shape = outprop.shape(); ShapeHandle out; diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 1be19d291a..5acfb56b05 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -887,6 +887,44 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) { EXPECT_EQ(8, in_prop3.shape().dim(3).size()); } +TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) { + // Test graph produced in python using: + /* + @function.Defun(noinline=True) + def MyFunc(): + @function.Defun(*[tf.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[tf.float32] * 2) + def Body(n, x): + return n - 1, x + n + + i = tf.constant(10) + return functional_ops.While([i, 0.], Cond, Body) + + with tf.Graph().as_default(): + z = MyFunc() + */ + GrapplerItem item; + string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, + "function_functional_while.pbtxt"); + TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + + const auto out_props = properties.GetOutputProperties("MyFunc_AenMyWWx1Us"); + EXPECT_EQ(2, out_props.size()); + + const OpInfo::TensorProperties& out_prop0 = out_props[0]; + EXPECT_EQ(DT_INT32, out_prop0.dtype()); + EXPECT_FALSE(out_prop0.shape().unknown_rank()); + + const OpInfo::TensorProperties& out_prop1 = out_props[1]; + EXPECT_EQ(DT_FLOAT, out_prop1.dtype()); + EXPECT_FALSE(out_prop1.shape().unknown_rank()); +} + TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) { GrapplerItem item; string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt new file mode 100644 index 0000000000..c94ee2f227 --- /dev/null +++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt @@ -0,0 +1,239 @@ +node { + name: "MyFunc_AenMyWWx1Us" + op: "MyFunc_AenMyWWx1Us" +} +library { + function { + signature { + name: "MyFunc_AenMyWWx1Us" + output_arg { + name: "while" + type: DT_INT32 + } + output_arg { + name: "while_0" + type: DT_FLOAT + } + is_stateful: true + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } + } + node_def { + name: "While/input_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node_def { + name: "While" + op: "While" + input: "Const:output:0" + input: "While/input_1:output:0" + attr { + key: "T" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "body" + value { + func { + name: "Body_8GOMGeZeK5c" + } + } + } + attr { + key: "cond" + value { + func { + name: "Cond_Xf5ttAHgUCg" + } + } + } + } + ret { + key: "while" + value: "While:output:0" + } + ret { + key: "while_0" + value: "While:output:1" + } + attr { + key: "_noinline" + value { + b: true + } + } + } + function { + signature { + name: "Body_8GOMGeZeK5c" + input_arg { + name: "n" + type: DT_FLOAT + } + input_arg { + name: "x" + type: DT_FLOAT + } + output_arg { + name: "sub" + type: DT_FLOAT + } + output_arg { + name: "add" + type: DT_FLOAT + } + } + node_def { + name: "sub/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node_def { + name: "sub_0" + op: "Sub" + input: "n" + input: "sub/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "add_0" + op: "Add" + input: "x" + input: "n" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "add" + value: "add_0:z:0" + } + ret { + key: "sub" + value: "sub_0:z:0" + } + } + function { + signature { + name: "Cond_Xf5ttAHgUCg" + input_arg { + name: "n" + type: DT_FLOAT + } + input_arg { + name: "unused_x" + type: DT_FLOAT + } + output_arg { + name: "greater" + type: DT_BOOL + } + } + node_def { + name: "Greater/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node_def { + name: "Greater" + op: "Greater" + input: "n" + input: "Greater/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "greater" + value: "Greater:z:0" + } + } +} +versions { + producer: 26 + min_consumer: 12 +} |