aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/pooling_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/pooling_test.py')
-rw-r--r--tensorflow/python/layers/pooling_test.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/layers/pooling_test.py
index e4d4ed4a2a..7533674e5a 100644
--- a/tensorflow/python/layers/pooling_test.py
+++ b/tensorflow/python/layers/pooling_test.py
@@ -96,33 +96,41 @@ class PoolingTest(test.TestCase):
def testCreateMaxPooling1D(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, width, channels))
layer = pooling_layers.MaxPooling1D(2, strides=2)
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, width // 2, channels])
def testCreateAveragePooling1D(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ channels = 3
+ images = random_ops.random_uniform((5, width, channels))
layer = pooling_layers.AveragePooling1D(2, strides=2)
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, width // 2, channels])
def testCreateMaxPooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, 4, width))
+ channels = 3
+ images = random_ops.random_uniform((5, channels, width))
layer = pooling_layers.MaxPooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 4, 3])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, channels, width // 2])
def testCreateAveragePooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, 4, width))
+ channels = 3
+ images = random_ops.random_uniform((5, channels, width))
layer = pooling_layers.AveragePooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 4, 3])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, channels, width // 2])
def testCreateMaxPooling3D(self):
depth, height, width = 6, 7, 9