aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/pooling_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/pooling_test.py')
-rw-r--r--tensorflow/python/layers/pooling_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/layers/pooling_test.py
index 589fee5f71..e4d4ed4a2a 100644
--- a/tensorflow/python/layers/pooling_test.py
+++ b/tensorflow/python/layers/pooling_test.py
@@ -110,19 +110,19 @@ class PoolingTest(test.TestCase):
def testCreateMaxPooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ images = random_ops.random_uniform((5, 4, width))
layer = pooling_layers.MaxPooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(), [5, 4, 3])
def testCreateAveragePooling1DChannelsFirst(self):
width = 7
- images = random_ops.random_uniform((5, width, 4))
+ images = random_ops.random_uniform((5, 4, width))
layer = pooling_layers.AveragePooling1D(
2, strides=2, data_format='channels_first')
output = layer.apply(images)
- self.assertListEqual(output.get_shape().as_list(), [5, 3, 4])
+ self.assertListEqual(output.get_shape().as_list(), [5, 4, 3])
def testCreateMaxPooling3D(self):
depth, height, width = 6, 7, 9