diff options
author | 2017-12-20 16:13:59 -0800 | |
---|---|---|
committer | 2017-12-20 16:17:19 -0800 | |
commit | 92f21cd359d0253f5870688442d77a41868d8ffe (patch) | |
tree | 97c46eb9c167e23cbf21ebe8d3cf4450982e2cfc /tensorflow/python/grappler | |
parent | 4edaeba3b14e706740ef4160afa257de565ffd6e (diff) |
Support max pool v2.
PiperOrigin-RevId: 179747281
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/layout_optimizer_test.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 1d9b87519a..2c15e340c9 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops @@ -416,6 +417,81 @@ class LayoutOptimizerTest(test.TestCase): self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_Pad_1', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testMaxPoolV2(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + ksize = constant_op.constant([1, 2, 3, 1], shape=[4]) + strides = array_ops.placeholder(dtype='int32', shape=[4]) + max_pool = gen_nn_ops._max_pool_v2(conv, ksize, strides, 'VALID') + output = array_ops.identity(max_pool) + + strides_val = [1, 3, 2, 1] + with session.Session() as sess: + output_val_ref = sess.run(output, feed_dict={strides: strides_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={ + strides: strides_val + }) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if node.name.startswith('LayoutOptimizerTranspose'): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes) + self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-MaxPoolV2-0-0', nodes) + self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_MaxPoolV2_2', nodes) + self.assertIn('LayoutOptimizer-MaxPoolV2-Const_2', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testMaxPoolGradV2(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + ksize = constant_op.constant([1, 2, 3, 1], shape=[4]) + strides = array_ops.placeholder(dtype='int32', shape=[4]) + max_pool_grad = gen_nn_ops.max_pool_grad_v2(conv, conv, conv, ksize, + strides, 'VALID') + output = array_ops.identity(max_pool_grad) + + strides_val = [1, 3, 2, 1] + with session.Session() as sess: + output_val_ref = sess.run(output, feed_dict={strides: strides_val}) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run( + output, run_metadata=metadata, feed_dict={ + strides: strides_val + }) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if node.name.startswith('LayoutOptimizerTranspose'): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes) + self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-MaxPoolGradV2-0-0', + nodes) + self.assertIn('LayoutOptimizerVecPermuteNHWCToNCHW_MaxPoolGradV2_4', + nodes) + self.assertIn('LayoutOptimizer-MaxPoolGradV2-Const_2', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testSliceWithNonConstAxis(self): if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(0) |