diff options
author | Yao Zhang <yaozhang@google.com> | 2018-02-08 23:17:54 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-08 23:21:54 -0800 |
commit | 72b1a058613c26938a57670b3f32e29ba0e58d23 (patch) | |
tree | ea1ded10d70f507df2b6ec472a1d7d96af9a0bd1 /tensorflow/python/grappler | |
parent | 7f1f8b6f75c03c80dd153d508ffe255da1b151c1 (diff) |
Only convert format if input is of layout-agnostic type.
PiperOrigin-RevId: 185103227
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/layout_optimizer_test.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 5bc9e4b803..25b1cdcbc5 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -1127,6 +1127,37 @@ class LayoutOptimizerTest(test.TestCase): self._assert_vec_nchw_to_nhwc('ShapeN-0-0', nodes) self.assertAllEqual(output_val_ref, output_val) + def testShapeNFollowedByNotConvertibleNodeReshape(self): + if test.is_gpu_available(cuda_only=True): + x = array_ops.placeholder(dtype='float32') + conv = _two_layer_model(x) + conv_reshape = array_ops.reshape(conv, [1, 1, 1, -1]) + shapen = array_ops.shape_n([conv, conv_reshape]) + shape = array_ops.identity(shapen[1]) + ones = array_ops.ones(shape) + output = math_ops.add_n([conv_reshape, ones]) + + x_val = [1.7] * 784 + with session.Session() as sess: + output_val_ref = sess.run(output, feed_dict={x: x_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={x: x_val}) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllEqual(output_val_ref, output_val) + def testLoop(self): if test.is_gpu_available(cuda_only=True): output = _loop() |