diff options
Diffstat (limited to 'tensorflow/python/keras/layers/pooling.py')
-rw-r--r-- | tensorflow/python/keras/layers/pooling.py | 185 |
1 files changed, 137 insertions, 48 deletions
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py index 912e8bd619..72a9c1d629 100644 --- a/tensorflow/python/keras/layers/pooling.py +++ b/tensorflow/python/keras/layers/pooling.py @@ -18,12 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend from tensorflow.python.keras.engine.base_layer import InputSpec from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import conv_utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.util.tf_export import tf_export @@ -41,16 +44,18 @@ class Pooling1D(Layer): strides of the pooling operation. padding: A string. The padding method, either 'valid' or 'same'. Case-insensitive. - data_format: A string, one of `channels_last` (default) or `channels_first`. + data_format: A string, + one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. + `(batch, steps, features)` while `channels_first` + corresponds to inputs with shape + `(batch, features, steps)`. name: A string, the name of the layer. """ def __init__(self, pool_function, pool_size, strides, - padding='valid', data_format=None, + padding='valid', data_format='channels_last', name=None, **kwargs): super(Pooling1D, self).__init__(name=name, **kwargs) if data_format is None: @@ -65,45 +70,39 @@ class Pooling1D(Layer): self.input_spec = InputSpec(ndim=3) def call(self, inputs): - # There is no TF op for 1D pooling, hence we make the inputs 4D. - if self.data_format == 'channels_last': - # input is NWC, make it NHWC - inputs = array_ops.expand_dims(inputs, 1) - # pool on the W dim - pool_shape = (1, 1) + self.pool_size + (1,) - strides = (1, 1) + self.strides + (1,) - data_format = 'NHWC' - else: - # input is NCW, make it NCHW - inputs = array_ops.expand_dims(inputs, 2) - # pool on the W dim - pool_shape = (1, 1, 1) + self.pool_size - strides = (1, 1, 1) + self.strides - data_format = 'NCHW' - + pad_axis = 2 if self.data_format == 'channels_last' else 3 + inputs = array_ops.expand_dims(inputs, pad_axis) outputs = self.pool_function( inputs, - ksize=pool_shape, - strides=strides, - padding=self.padding.upper(), - data_format=data_format) - - if self.data_format == 'channels_last': - return array_ops.squeeze(outputs, 1) - else: - return array_ops.squeeze(outputs, 2) + self.pool_size + (1,), + strides=self.strides + (1,), + padding=self.padding, + data_format=self.data_format) + return array_ops.squeeze(outputs, pad_axis) def compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() - length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0], - self.padding, self.strides[0]) - return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]]) + if self.data_format == 'channels_first': + steps = input_shape[2] + features = input_shape[1] + else: + steps = input_shape[1] + features = input_shape[2] + length = conv_utils.conv_output_length(steps, + self.pool_size[0], + self.padding, + self.strides[0]) + if self.data_format == 'channels_first': + return tensor_shape.TensorShape([input_shape[0], features, length]) + else: + return tensor_shape.TensorShape([input_shape[0], length, features]) def get_config(self): config = { 'strides': self.strides, 'pool_size': self.pool_size, - 'padding': self.padding + 'padding': self.padding, + 'data_format': self.data_format, } base_config = super(Pooling1D, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -119,19 +118,36 @@ class MaxPooling1D(Pooling1D): E.g. 2 will halve the input. If None, it will default to `pool_size`. padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, steps, features)` while `channels_first` + corresponds to inputs with shape + `(batch, features, steps)`. Input shape: - 3D tensor with shape: `(batch_size, steps, features)`. + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, steps)` Output shape: - 3D tensor with shape: `(batch_size, downsampled_steps, features)`. + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, downsampled_steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, downsampled_steps)` """ def __init__(self, pool_size=2, strides=None, - padding='valid', data_format=None, **kwargs): + padding='valid', data_format='channels_last', **kwargs): super(MaxPooling1D, self).__init__( - nn.max_pool, + functools.partial(backend.pool2d, pool_mode='max'), pool_size=pool_size, strides=strides, padding=padding, @@ -149,18 +165,35 @@ class AveragePooling1D(Pooling1D): E.g. 2 will halve the input. If None, it will default to `pool_size`. padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, steps, features)` while `channels_first` + corresponds to inputs with shape + `(batch, features, steps)`. Input shape: - 3D tensor with shape: `(batch_size, steps, features)`. + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, steps)` Output shape: - 3D tensor with shape: `(batch_size, downsampled_steps, features)`. + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, downsampled_steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, downsampled_steps)` """ def __init__(self, pool_size=2, strides=None, - padding='valid', data_format=None, **kwargs): + padding='valid', data_format='channels_last', **kwargs): super(AveragePooling1D, self).__init__( - nn.avg_pool, + functools.partial(backend.pool2d, pool_mode='avg'), pool_size=pool_size, strides=strides, padding=padding, @@ -561,41 +594,96 @@ class GlobalPooling1D(Layer): """Abstract class for different global pooling 1D layers. """ - def __init__(self, **kwargs): + def __init__(self, data_format='channels_last', **kwargs): super(GlobalPooling1D, self).__init__(**kwargs) self.input_spec = InputSpec(ndim=3) + self.data_format = conv_utils.normalize_data_format(data_format) def compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() - return tensor_shape.TensorShape([input_shape[0], input_shape[2]]) + if self.data_format == 'channels_first': + return tensor_shape.TensorShape([input_shape[0], input_shape[1]]) + else: + return tensor_shape.TensorShape([input_shape[0], input_shape[2]]) def call(self, inputs): raise NotImplementedError + def get_config(self): + config = {'data_format': self.data_format} + base_config = super(GlobalPooling1D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + @tf_export('keras.layers.GlobalAveragePooling1D', 'keras.layers.GlobalAvgPool1D') class GlobalAveragePooling1D(GlobalPooling1D): """Global average pooling operation for temporal data. + Arguments: + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, steps, features)` while `channels_first` + corresponds to inputs with shape + `(batch, features, steps)`. + Input shape: - 3D tensor with shape: `(batch_size, steps, features)`. + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, steps)` Output shape: 2D tensor with shape: `(batch_size, features)` """ - def call(self, inputs): - return backend.mean(inputs, axis=1) + def __init__(self, data_format='channels_last', **kwargs): + super(GlobalAveragePooling1D, self).__init__(data_format=data_format, + **kwargs) + self.supports_masking = True + + def call(self, inputs, mask=None): + steps_axis = 1 if self.data_format == 'channels_last' else 2 + if mask is not None: + mask = math_ops.cast(mask, backend.floatx()) + input_shape = inputs.shape.as_list() + broadcast_shape = [-1, input_shape[steps_axis], 1] + mask = array_ops.reshape(mask, broadcast_shape) + inputs *= mask + return backend.sum(inputs, axis=steps_axis) / math_ops.reduce_sum( + mask, axis=steps_axis) + else: + return backend.mean(inputs, axis=steps_axis) + + def compute_mask(self, inputs, mask=None): + return None @tf_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D') class GlobalMaxPooling1D(GlobalPooling1D): """Global max pooling operation for temporal data. + Arguments: + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, steps, features)` while `channels_first` + corresponds to inputs with shape + `(batch, features, steps)`. + Input shape: - 3D tensor with shape: `(batch_size, steps, features)`. + - If `data_format='channels_last'`: + 3D tensor with shape: + `(batch_size, steps, features)` + - If `data_format='channels_first'`: + 3D tensor with shape: + `(batch_size, features, steps)` Output shape: 2D tensor with shape: @@ -603,7 +691,8 @@ class GlobalMaxPooling1D(GlobalPooling1D): """ def call(self, inputs): - return backend.max(inputs, axis=1) + steps_axis = 1 if self.data_format == 'channels_last' else 2 + return backend.max(inputs, axis=steps_axis) class GlobalPooling2D(Layer): |