aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/pooling_test.py
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-01-24 23:36:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-24 23:40:14 -0800
commit86967885684433c86d4764d82e5d975e3ef4ab8e (patch)
tree31dd9bb1f812ae65c46c77e1f9217b2bc7968bbd /tensorflow/python/layers/pooling_test.py
parent72c420a32702f7a7638c0130a7d7dc1db4469840 (diff)
Fix eager Pooling1D unit test for data_format='channels_first'
PiperOrigin-RevId: 183196050
Diffstat (limited to 'tensorflow/python/layers/pooling_test.py')
-rw-r--r--tensorflow/python/layers/pooling_test.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/layers/pooling_test.py
index e4d4ed4a2a..7533674e5a 100644
--- a/tensorflow/python/layers/pooling_test.py
+++ b/tensorflow/python/layers/pooling_test.py
@@ -96,33 +96,41 @@ class PoolingTest(test.TestCase):
def testCreateMaxPooling1D(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, width, channels))
layer = pooling_layers.MaxPooling1D(2, strides=2)
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, width // 2, channels])
def testCreateAveragePooling1D(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, width, channels))
layer = pooling_layers.AveragePooling1D(2, strides=2)
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, width // 2, channels])
def testCreateMaxPooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, 4, width))
+ channels = 3
+ images = random_ops.random_uniform((5, channels, width))
layer = pooling_layers.MaxPooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 4, 3])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, channels, width // 2])
def testCreateAveragePooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, 4, width))
+ channels = 3
+ images = random_ops.random_uniform((5, channels, width))
layer = pooling_layers.AveragePooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 4, 3])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, channels, width // 2])
def testCreateMaxPooling3D(self):
depth, height, width = 6, 7, 9