diff options
author | Peter Hawkins <phawkins@google.com> | 2017-01-19 12:06:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-19 12:23:43 -0800 |
commit | a1118580819300cea8e87a1ef07a8c717b9bb459 (patch) | |
tree | 7ccdfd2c304649e08f8dd3fd86dd629372a3d768 /tensorflow/core/graph/subgraph_test.cc | |
parent | 69b8051d9410b83b110107efbdd87661c17bf8d2 (diff) |
Change RewriteGraphForExecution to propagate an _output_shapes annotation to feed nodes, if any.
Change: 144991987
Diffstat (limited to 'tensorflow/core/graph/subgraph_test.cc')
-rw-r--r-- | tensorflow/core/graph/subgraph_test.cc | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index e3f6504ff9..ee4960121f 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -274,6 +275,43 @@ TEST_F(SubgraphTest, Errors) { EXPECT_TRUE(HasSubstr(Subgraph("", "", ""), "at least one target")); } +TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) { + ExpectOK( + R"proto( + node { name: 'W1' op: 'TestParams' } + node { name: 'W2' op: 'TestParams' } + node { + name: 'input' + op: 'TestInput' + attr { + key: '_output_shapes' + value { + list { + shape { unknown_rank: true } + shape { dim { size: 23 } } + } + } + } + } + node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] } + node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] } + node { name: 't3_a' op: 'TestRelu' input: 't2' } + node { name: 't3_b' op: 'TestRelu' input: 't2' } + )proto"); + EXPECT_EQ("OK", Subgraph("input:1", "", "t2")); + ExpectNodes("W1,W2,_recv_input_1,t1,t2"); + + for (Node* node : graph()->nodes()) { + if (node->name() == "_recv_input_1") { + std::vector<PartialTensorShape> shapes; + TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes)); + ASSERT_EQ(1, shapes.size()); + EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0])); + break; + } + } +} + REGISTER_OP("In").Output("o: float"); REGISTER_OP("Op").Input("i: float").Output("o: float"); |