aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/pooling_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/pooling_test.py')
-rw-r--r--tensorflow/python/keras/layers/pooling_test.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py
index 2cd9939e66..936e73ecf9 100644
--- a/tensorflow/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/layers/pooling_test.py
@@ -18,11 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
class GlobalPoolingTest(test.TestCase):
@@ -31,8 +34,26 @@ class GlobalPoolingTest(test.TestCase):
def test_globalpooling_1d(self):
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
testing_utils.layer_test(
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalAveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_globalpooling_1d_masking_support(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Masking(mask_value=0., input_shape=(3, 4)))
+ model.add(keras.layers.GlobalAveragePooling1D())
+ model.compile(loss='mae', optimizer=rmsprop.RMSPropOptimizer(0.001))
+
+ model_input = np.random.random((2, 3, 4))
+ model_input[0, 1:, :] = 0
+ output = model.predict(model_input)
+ self.assertAllClose(output[0], model_input[0, 0, :])
@tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_2d(self):
@@ -172,6 +193,10 @@ class Pooling1DTest(test.TestCase):
kwargs={'strides': stride,
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.MaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
@tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_1d(self):
@@ -183,6 +208,11 @@ class Pooling1DTest(test.TestCase):
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.AveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
+
if __name__ == '__main__':
test.main()