diff options
author | 2017-01-19 12:06:44 -0800 | |
---|---|---|
committer | 2017-01-19 12:23:43 -0800 | |
commit | a1118580819300cea8e87a1ef07a8c717b9bb459 (patch) | |
tree | 7ccdfd2c304649e08f8dd3fd86dd629372a3d768 /tensorflow/core/graph | |
parent | 69b8051d9410b83b110107efbdd87661c17bf8d2 (diff) |
Change RewriteGraphForExecution to propagate an _output_shapes annotation to feed nodes, if any.
Change: 144991987
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/subgraph.cc | 14 | ||||
-rw-r--r-- | tensorflow/core/graph/subgraph_test.cc | 38 |
2 files changed, 52 insertions, 0 deletions
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 58199140d2..5622f0ca59 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -84,6 +84,20 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, .Finalize(g, &recv_node)); recv_node->set_assigned_device_name(device_info.name()); + // Copy the _output_shapes from the original node to the feed node, + // if any. + std::vector<PartialTensorShape> output_shapes; + if (GetNodeAttr(n->def(), "_output_shapes", &output_shapes).ok()) { + if (n->num_outputs() != output_shapes.size()) { + return errors::InvalidArgument( + "FeedInputs: ", t, + ": size of _output_shapes attribute does not " + "match the number of node outputs"); + } + std::vector<PartialTensorShape> feed_shapes = {output_shapes[id.second]}; + recv_node->AddAttr("_output_shapes", feed_shapes); + } + // Update name_index (*name_index)[recv_node->name()] = recv_node; g->AddControlEdge(g->source_node(), recv_node); diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index e3f6504ff9..ee4960121f 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -274,6 +275,43 @@ TEST_F(SubgraphTest, Errors) { EXPECT_TRUE(HasSubstr(Subgraph("", "", ""), "at least one target")); } +TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) { + ExpectOK( + R"proto( + node { name: 'W1' op: 'TestParams' } + node { name: 'W2' op: 'TestParams' } + node { + name: 'input' + op: 'TestInput' + attr { + key: '_output_shapes' + value { + list { + shape { unknown_rank: true } + shape { dim { size: 23 } } + } + } + } + } + node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] } + node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] } + node { name: 't3_a' op: 'TestRelu' input: 't2' } + node { name: 't3_b' op: 'TestRelu' input: 't2' } + )proto"); + EXPECT_EQ("OK", Subgraph("input:1", "", "t2")); + ExpectNodes("W1,W2,_recv_input_1,t1,t2"); + + for (Node* node : graph()->nodes()) { + if (node->name() == "_recv_input_1") { + std::vector<PartialTensorShape> shapes; + TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes)); + ASSERT_EQ(1, shapes.size()); + EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0])); + break; + } + } +} + REGISTER_OP("In").Output("o: float"); REGISTER_OP("Op").Input("i: float").Output("o: float"); |