diff options
author | 2018-01-24 23:36:20 -0800 | |
---|---|---|
committer | 2018-01-24 23:40:14 -0800 | |
commit | 86967885684433c86d4764d82e5d975e3ef4ab8e (patch) | |
tree | 31dd9bb1f812ae65c46c77e1f9217b2bc7968bbd /tensorflow/python/layers/pooling_test.py | |
parent | 72c420a32702f7a7638c0130a7d7dc1db4469840 (diff) |
Fix eager Pooling1D unit test for data_format='channels_first'
PiperOrigin-RevId: 183196050
Diffstat (limited to 'tensorflow/python/layers/pooling_test.py')
-rw-r--r-- | tensorflow/python/layers/pooling_test.py | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/layers/pooling_test.py index e4d4ed4a2a..7533674e5a 100644 --- a/tensorflow/python/layers/pooling_test.py +++ b/tensorflow/python/layers/pooling_test.py @@ -96,33 +96,41 @@ class PoolingTest(test.TestCase): def testCreateMaxPooling1D(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + channels = 3 + images = random_ops.random_uniform((5, width, channels)) layer = pooling_layers.MaxPooling1D(2, strides=2) output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), + [5, width // 2, channels]) def testCreateAveragePooling1D(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + channels = 3 + images = random_ops.random_uniform((5, width, channels)) layer = pooling_layers.AveragePooling1D(2, strides=2) output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), + [5, width // 2, channels]) def testCreateMaxPooling1DChannelsFirst(self): width = 7 - images = random_ops.random_uniform((5, 4, width)) + channels = 3 + images = random_ops.random_uniform((5, channels, width)) layer = pooling_layers.MaxPooling1D( 2, strides=2, data_format='channels_first') output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 4, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, channels, width // 2]) def testCreateAveragePooling1DChannelsFirst(self): width = 7 - images = random_ops.random_uniform((5, 4, width)) + channels = 3 + images = random_ops.random_uniform((5, channels, width)) layer = pooling_layers.AveragePooling1D( 2, strides=2, data_format='channels_first') output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 4, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, channels, width // 2]) def testCreateMaxPooling3D(self): depth, height, width = 6, 7, 9 |