diff options
Diffstat (limited to 'tensorflow/python/layers/pooling_test.py')
-rw-r--r-- | tensorflow/python/layers/pooling_test.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/layers/pooling_test.py b/tensorflow/python/layers/pooling_test.py index 589fee5f71..e4d4ed4a2a 100644 --- a/tensorflow/python/layers/pooling_test.py +++ b/tensorflow/python/layers/pooling_test.py @@ -110,19 +110,19 @@ class PoolingTest(test.TestCase): def testCreateMaxPooling1DChannelsFirst(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + images = random_ops.random_uniform((5, 4, width)) layer = pooling_layers.MaxPooling1D( 2, strides=2, data_format='channels_first') output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), [5, 4, 3]) def testCreateAveragePooling1DChannelsFirst(self): width = 7 - images = random_ops.random_uniform((5, width, 4)) + images = random_ops.random_uniform((5, 4, width)) layer = pooling_layers.AveragePooling1D( 2, strides=2, data_format='channels_first') output = layer.apply(images) - self.assertListEqual(output.get_shape().as_list(), [5, 3, 4]) + self.assertListEqual(output.get_shape().as_list(), [5, 4, 3]) def testCreateMaxPooling3D(self): depth, height, width = 6, 7, 9 |