aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-02-08 23:17:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-08 23:21:54 -0800
commit72b1a058613c26938a57670b3f32e29ba0e58d23 (patch)
treeea1ded10d70f507df2b6ec472a1d7d96af9a0bd1 /tensorflow/python/grappler
parent7f1f8b6f75c03c80dd153d508ffe255da1b151c1 (diff)
Only convert format if input is of layout-agnostic type.
PiperOrigin-RevId: 185103227
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 5bc9e4b803..25b1cdcbc5 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1127,6 +1127,37 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_vec_nchw_to_nhwc('ShapeN-0-0', nodes)
self.assertAllEqual(output_val_ref, output_val)
+ def testShapeNFollowedByNotConvertibleNodeReshape(self):
+ if test.is_gpu_available(cuda_only=True):
+ x = array_ops.placeholder(dtype='float32')
+ conv = _two_layer_model(x)
+ conv_reshape = array_ops.reshape(conv, [1, 1, 1, -1])
+ shapen = array_ops.shape_n([conv, conv_reshape])
+ shape = array_ops.identity(shapen[1])
+ ones = array_ops.ones(shape)
+ output = math_ops.add_n([conv_reshape, ones])
+
+ x_val = [1.7] * 784
+ with session.Session() as sess:
+ output_val_ref = sess.run(output, feed_dict={x: x_val})
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(
+ output, run_metadata=metadata, feed_dict={x: x_val})
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if _is_transpose(node.name):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+ self.assertAllEqual(output_val_ref, output_val)
+
def testLoop(self):
if test.is_gpu_available(cuda_only=True):
output = _loop()