about summary refs log tree commit diff homepage
path: root/tensorflow/python/keras/layers/pooling.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/pooling.py')
-rw-r--r-- tensorflow/python/keras/layers/pooling.py 185
1 files changed, 137 insertions, 48 deletions
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 912e8bd619..72a9c1d629 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@@ -41,16 +44,18 @@ class Pooling1D(Layer):
strides of the pooling operation.
padding: A string. The padding method, either 'valid' or 'same'.
Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
name: A string, the name of the layer.
"""
def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format=None,
+ padding='valid', data_format='channels_last',
name=None, **kwargs):
super(Pooling1D, self).__init__(name=name, **kwargs)
if data_format is None:
@@ -65,45 +70,39 @@ class Pooling1D(Layer):
self.input_spec = InputSpec(ndim=3)
def call(self, inputs):
- # There is no TF op for 1D pooling, hence we make the inputs 4D.
- if self.data_format == 'channels_last':
- # input is NWC, make it NHWC
- inputs = array_ops.expand_dims(inputs, 1)
- # pool on the W dim
- pool_shape = (1, 1) + self.pool_size + (1,)
- strides = (1, 1) + self.strides + (1,)
- data_format = 'NHWC'
- else:
- # input is NCW, make it NCHW
- inputs = array_ops.expand_dims(inputs, 2)
- # pool on the W dim
- pool_shape = (1, 1, 1) + self.pool_size
- strides = (1, 1, 1) + self.strides
- data_format = 'NCHW'
-
+ pad_axis = 2 if self.data_format == 'channels_last' else 3
+ inputs = array_ops.expand_dims(inputs, pad_axis)
outputs = self.pool_function(
inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper(),
- data_format=data_format)
-
- if self.data_format == 'channels_last':
- return array_ops.squeeze(outputs, 1)
- else:
- return array_ops.squeeze(outputs, 2)
+ self.pool_size + (1,),
+ strides=self.strides + (1,),
+ padding=self.padding,
+ data_format=self.data_format)
+ return array_ops.squeeze(outputs, pad_axis)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0],
- self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
+ if self.data_format == 'channels_first':
+ steps = input_shape[2]
+ features = input_shape[1]
+ else:
+ steps = input_shape[1]
+ features = input_shape[2]
+ length = conv_utils.conv_output_length(steps,
+ self.pool_size[0],
+ self.padding,
+ self.strides[0])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], features, length])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], length, features])
def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
- 'padding': self.padding
+ 'padding': self.padding,
+ 'data_format': self.data_format,
}
base_config = super(Pooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -119,19 +118,36 @@ class MaxPooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(MaxPooling1D, self).__init__(
- nn.max_pool,
+ functools.partial(backend.pool2d, pool_mode='max'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -149,18 +165,35 @@ class AveragePooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(AveragePooling1D, self).__init__(
- nn.avg_pool,
+ functools.partial(backend.pool2d, pool_mode='avg'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -561,41 +594,96 @@ class GlobalPooling1D(Layer):
"""Abstract class for different global pooling 1D layers.
"""
- def __init__(self, **kwargs):
+ def __init__(self, data_format='channels_last', **kwargs):
super(GlobalPooling1D, self).__init__(**kwargs)
self.input_spec = InputSpec(ndim=3)
+ self.data_format = conv_utils.normalize_data_format(data_format)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], input_shape[1]])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
def call(self, inputs):
raise NotImplementedError
+ def get_config(self):
+ config = {'data_format': self.data_format}
+ base_config = super(GlobalPooling1D, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
@tf_export('keras.layers.GlobalAveragePooling1D',
'keras.layers.GlobalAvgPool1D')
class GlobalAveragePooling1D(GlobalPooling1D):
"""Global average pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
`(batch_size, features)`
"""
- def call(self, inputs):
- return backend.mean(inputs, axis=1)
+ def __init__(self, data_format='channels_last', **kwargs):
+ super(GlobalAveragePooling1D, self).__init__(data_format=data_format,
+ **kwargs)
+ self.supports_masking = True
+
+  def call(self, inputs, mask=None):
+    """Averages `inputs` over the steps axis, skipping masked-out steps."""
+    steps_axis = 1 if self.data_format == 'channels_last' else 2
+    if mask is None:
+      return backend.mean(inputs, axis=steps_axis)
+    mask = math_ops.cast(mask, backend.floatx())
+    # `mask` is (batch, steps): add the features axis where `inputs` has it
+    # (2 for channels_last, 1 for channels_first) so the two broadcast.
+    mask = array_ops.expand_dims(mask, 3 - steps_axis)
+    inputs *= mask
+    return backend.sum(inputs, axis=steps_axis) / math_ops.reduce_sum(
+        mask, axis=steps_axis)
+
+ def compute_mask(self, inputs, mask=None):
+ return None
@tf_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D')
class GlobalMaxPooling1D(GlobalPooling1D):
"""Global max pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
@@ -603,7 +691,8 @@ class GlobalMaxPooling1D(GlobalPooling1D):
"""
def call(self, inputs):
- return backend.max(inputs, axis=1)
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ return backend.max(inputs, axis=steps_axis)
class GlobalPooling2D(Layer):