aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/subgraph_test.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-01-19 12:06:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-19 12:23:43 -0800
commita1118580819300cea8e87a1ef07a8c717b9bb459 (patch)
tree7ccdfd2c304649e08f8dd3fd86dd629372a3d768 /tensorflow/core/graph/subgraph_test.cc
parent69b8051d9410b83b110107efbdd87661c17bf8d2 (diff)
Change RewriteGraphForExecution to propagate an _output_shapes annotation to feed nodes, if any.
Change: 144991987
Diffstat (limited to 'tensorflow/core/graph/subgraph_test.cc')
-rw-r--r--tensorflow/core/graph/subgraph_test.cc38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
index e3f6504ff9..ee4960121f 100644
--- a/tensorflow/core/graph/subgraph_test.cc
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
@@ -274,6 +275,43 @@ TEST_F(SubgraphTest, Errors) {
EXPECT_TRUE(HasSubstr(Subgraph("", "", ""), "at least one target"));
}
+TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
+ ExpectOK(
+ R"proto(
+ node { name: 'W1' op: 'TestParams' }
+ node { name: 'W2' op: 'TestParams' }
+ node {
+ name: 'input'
+ op: 'TestInput'
+ attr {
+ key: '_output_shapes'
+ value {
+ list {
+ shape { unknown_rank: true }
+ shape { dim { size: 23 } }
+ }
+ }
+ }
+ }
+ node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }
+ node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }
+ node { name: 't3_a' op: 'TestRelu' input: 't2' }
+ node { name: 't3_b' op: 'TestRelu' input: 't2' }
+ )proto");
+ EXPECT_EQ("OK", Subgraph("input:1", "", "t2"));
+ ExpectNodes("W1,W2,_recv_input_1,t1,t2");
+
+ for (Node* node : graph()->nodes()) {
+ if (node->name() == "_recv_input_1") {
+ std::vector<PartialTensorShape> shapes;
+ TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes));
+ ASSERT_EQ(1, shapes.size());
+ EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0]));
+ break;
+ }
+ }
+}
+
REGISTER_OP("In").Output("o: float");
REGISTER_OP("Op").Input("i: float").Output("o: float");