diff options
author | 2018-04-06 11:24:20 -0700 | |
---|---|---|
committer | 2018-04-06 11:31:24 -0700 | |
commit | 98b8b786036172d33c85b6b5f81347440d0594df (patch) | |
tree | 0bcfc56b315709294e7cc21a497a27dbadc481f9 | |
parent | beda9ebd36bbf6964459c7ee2209975d62cb01e6 (diff) |
Update tf.keras to keras 2.1.5 version.
PiperOrigin-RevId: 191914904
11 files changed, 562 insertions, 251 deletions
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py index ad96b53a45..12775fccec 100644 --- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py @@ -84,11 +84,13 @@ from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import BatchNormalization from tensorflow.python.keras._impl.keras.layers import Conv2D +from tensorflow.python.keras._impl.keras.layers import DepthwiseConv2D from tensorflow.python.keras._impl.keras.layers import Dropout from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.layers import Reshape +from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.keras._impl.keras.utils.data_utils import get_file @@ -116,195 +118,6 @@ def preprocess_input(x): return imagenet_utils.preprocess_input(x, mode='tf') -class DepthwiseConv2D(Conv2D): - """Depthwise separable 2D convolution. - - Depthwise Separable convolutions consists in performing - just the first step in a depthwise spatial convolution - (which acts on each input channel separately). - The `depth_multiplier` argument controls how many - output channels are generated per input channel in the depthwise step. - - Arguments: - kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: one of `'valid'` or `'same'` (case-insensitive). - depth_multiplier: The number of depthwise convolution output channels - for each input channel. - The total number of depthwise convolution output - channels will be equal to `filters_in * depth_multiplier`. - data_format: A string, - one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, height, width)`. - It defaults to the `image_data_format` value found in your - Keras config file at `~/.keras/keras.json`. - If you never set it, then it will be 'channels_last'. - activation: Activation function to use. - If you don't specify anything, no activation is applied - (ie. 'linear' activation: `a(x) = x`). - use_bias: Boolean, whether the layer uses a bias vector. - depthwise_initializer: Initializer for the depthwise kernel matrix. - bias_initializer: Initializer for the bias vector. - depthwise_regularizer: Regularizer function applied to - the depthwise kernel matrix. - bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its 'activation').. - depthwise_constraint: Constraint function applied to - the depthwise kernel matrix. - bias_constraint: Constraint function applied to the bias vector. - - Input shape: - 4D tensor with shape: - `[batch, channels, rows, cols]` if data_format='channels_first' - or 4D tensor with shape: - `[batch, rows, cols, channels]` if data_format='channels_last'. - - Output shape: - 4D tensor with shape: - `[batch, filters, new_rows, new_cols]` if data_format='channels_first' - or 4D tensor with shape: - `[batch, new_rows, new_cols, filters]` if data_format='channels_last'. - `rows` and `cols` values might have changed due to padding. - """ - - def __init__(self, - kernel_size, - strides=(1, 1), - padding='valid', - depth_multiplier=1, - data_format=None, - activation=None, - use_bias=True, - depthwise_initializer='glorot_uniform', - bias_initializer='zeros', - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - **kwargs): - super(DepthwiseConv2D, self).__init__( - filters=None, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - activation=activation, - use_bias=use_bias, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - bias_constraint=bias_constraint, - **kwargs) - self.depth_multiplier = depth_multiplier - self.depthwise_initializer = initializers.get(depthwise_initializer) - self.depthwise_regularizer = regularizers.get(depthwise_regularizer) - self.depthwise_constraint = constraints.get(depthwise_constraint) - self.bias_initializer = initializers.get(bias_initializer) - - @shape_type_conversion - def build(self, input_shape): - if len(input_shape) < 4: - raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. ' - 'Received input shape:', str(input_shape)) - if self.data_format == 'channels_first': - channel_axis = 1 - else: - channel_axis = 3 - if input_shape[channel_axis] is None: - raise ValueError('The channel dimension of the inputs to ' - '`DepthwiseConv2D` ' - 'should be defined. Found `None`.') - input_dim = int(input_shape[channel_axis]) - depthwise_kernel_shape = (self.kernel_size[0], self.kernel_size[1], - input_dim, self.depth_multiplier) - - self.depthwise_kernel = self.add_weight( - shape=depthwise_kernel_shape, - initializer=self.depthwise_initializer, - name='depthwise_kernel', - regularizer=self.depthwise_regularizer, - constraint=self.depthwise_constraint) - - if self.use_bias: - self.bias = self.add_weight( - shape=(input_dim * self.depth_multiplier,), - initializer=self.bias_initializer, - name='bias', - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) - else: - self.bias = None - # Set input spec. - self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim}) - self.built = True - - def call(self, inputs, training=None): - outputs = K.depthwise_conv2d( - inputs, - self.depthwise_kernel, - strides=self.strides, - padding=self.padding, - dilation_rate=self.dilation_rate, - data_format=self.data_format) - - if self.bias: - outputs = K.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @shape_type_conversion - def compute_output_shape(self, input_shape): - if self.data_format == 'channels_first': - rows = input_shape[2] - cols = input_shape[3] - out_filters = input_shape[1] * self.depth_multiplier - elif self.data_format == 'channels_last': - rows = input_shape[1] - cols = input_shape[2] - out_filters = input_shape[3] * self.depth_multiplier - - rows = conv_utils.conv_output_length(rows, self.kernel_size[0], - self.padding, self.strides[0]) - cols = conv_utils.conv_output_length(cols, self.kernel_size[1], - self.padding, self.strides[1]) - - if self.data_format == 'channels_first': - return (input_shape[0], out_filters, rows, cols) - elif self.data_format == 'channels_last': - return (input_shape[0], rows, cols, out_filters) - - def get_config(self): - config = super(DepthwiseConv2D, self).get_config() - config.pop('filters') - config.pop('kernel_initializer') - config.pop('kernel_regularizer') - config.pop('kernel_constraint') - config['depth_multiplier'] = self.depth_multiplier - config['depthwise_initializer'] = initializers.serialize( - self.depthwise_initializer) - config['depthwise_regularizer'] = regularizers.serialize( - self.depthwise_regularizer) - config['depthwise_constraint'] = constraints.serialize( - self.depthwise_constraint) - return config - - @tf_export('keras.applications.MobileNet', 'keras.applications.mobilenet.MobileNet') def MobileNet(input_shape=None, @@ -318,18 +131,11 @@ def MobileNet(input_shape=None, classes=1000): """Instantiates the MobileNet architecture. - Note that only TensorFlow is supported for now, - therefore it only works with the data format - `image_data_format='channels_last'` in your Keras config - at `~/.keras/keras.json`. - To load a MobileNet model via `load_model`, import the custom - objects `relu6` and `DepthwiseConv2D` and pass them to the - `custom_objects` parameter. + objects `relu6` and pass them to the `custom_objects` parameter. E.g. model = load_model('mobilenet.h5', custom_objects={ - 'relu6': mobilenet.relu6, - 'DepthwiseConv2D': mobilenet.DepthwiseConv2D}) + 'relu6': mobilenet.relu6}) Arguments: input_shape: optional shape tuple, only to be specified @@ -383,11 +189,6 @@ def MobileNet(input_shape=None, backend that does not support separable convolutions. """ - if K.backend() != 'tensorflow': - raise RuntimeError('Only TensorFlow backend is currently supported, ' - 'as other backends do not support ' - 'depthwise convolution.') - if not (weights in {'imagenet', None} or os.path.exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' @@ -522,7 +323,7 @@ def MobileNet(input_shape=None, # load weights if weights == 'imagenet': if K.image_data_format() == 'channels_first': - raise ValueError('Weights for "channels_last" format ' + raise ValueError('Weights for "channels_first" format ' 'are not available.') if alpha == 1.0: alpha_text = '1_0' @@ -598,14 +399,14 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): """ channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 filters = int(filters * alpha) + x = ZeroPadding2D(padding=(1, 1), name='conv1_pad')(inputs) x = Conv2D( filters, kernel, - padding='same', + padding='valid', use_bias=False, strides=strides, - name='conv1')( - inputs) + name='conv1')(x) x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x) return Activation(relu6, name='conv1_relu')(x) @@ -665,15 +466,14 @@ def _depthwise_conv_block(inputs, """ channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 pointwise_conv_filters = int(pointwise_conv_filters * alpha) - + x = ZeroPadding2D(padding=(1, 1), name='conv_pad_%d' % block_id)(inputs) x = DepthwiseConv2D( # pylint: disable=not-callable (3, 3), - padding='same', + padding='valid', depth_multiplier=depth_multiplier, strides=strides, use_bias=False, - name='conv_dw_%d' % block_id)( - inputs) + name='conv_dw_%d' % block_id)(x) x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x) x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x) diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py index 46c0e63557..f8c6aff4f2 100644 --- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py +++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py @@ -45,6 +45,7 @@ from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.layers import MaxPooling2D +from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.utils import layer_utils from tensorflow.python.keras._impl.keras.utils.data_utils import get_file @@ -236,9 +237,9 @@ def ResNet50(include_top=True, else: bn_axis = 1 + x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input) x = Conv2D( - 64, (7, 7), strides=(2, 2), padding='same', name='conv1')( - img_input) + 64, (7, 7), strides=(2, 2), padding='valid', name='conv1')(x) x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) x = Activation('relu')(x) x = MaxPooling2D((3, 3), strides=(2, 2))(x) diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py index 162ae6c28f..7cdebc6aa4 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py @@ -27,6 +27,7 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion # imports for backwards namespace compatibility # pylint: disable=unused-import from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D @@ -1024,6 +1025,200 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): return dict(list(base_config.items()) + list(config.items())) +@tf_export('keras.layers.DepthwiseConv2D') +class DepthwiseConv2D(Conv2D): + """Depthwise separable 2D convolution. + + Depthwise Separable convolutions consists in performing + just the first step in a depthwise spatial convolution + (which acts on each input channel separately). + The `depth_multiplier` argument controls how many + output channels are generated per input channel in the depthwise step. + + Arguments: + kernel_size: An integer or tuple/list of 2 integers, specifying the + width and height of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the width and height. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: one of `'valid'` or `'same'` (case-insensitive). + depth_multiplier: The number of depthwise convolution output channels + for each input channel. + The total number of depthwise convolution output + channels will be equal to `filters_in * depth_multiplier`. + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, height, width)`. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be 'channels_last'. + activation: Activation function to use. + If you don't specify anything, no activation is applied + (ie. 'linear' activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + depthwise_initializer: Initializer for the depthwise kernel matrix. + bias_initializer: Initializer for the bias vector. + depthwise_regularizer: Regularizer function applied to + the depthwise kernel matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to + the output of the layer (its 'activation'). + depthwise_constraint: Constraint function applied to + the depthwise kernel matrix. + bias_constraint: Constraint function applied to the bias vector. + + Input shape: + 4D tensor with shape: + `[batch, channels, rows, cols]` if data_format='channels_first' + or 4D tensor with shape: + `[batch, rows, cols, channels]` if data_format='channels_last'. + + Output shape: + 4D tensor with shape: + `[batch, filters, new_rows, new_cols]` if data_format='channels_first' + or 4D tensor with shape: + `[batch, new_rows, new_cols, filters]` if data_format='channels_last'. + `rows` and `cols` values might have changed due to padding. + """ + + def __init__(self, + kernel_size, + strides=(1, 1), + padding='valid', + depth_multiplier=1, + data_format=None, + activation=None, + use_bias=True, + depthwise_initializer='glorot_uniform', + bias_initializer='zeros', + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs): + super(DepthwiseConv2D, self).__init__( + filters=None, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + bias_constraint=bias_constraint, + **kwargs) + self.depth_multiplier = depth_multiplier + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.bias_initializer = initializers.get(bias_initializer) + + def build(self, input_shape): + if len(input_shape) < 4: + raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. ' + 'Received input shape:', str(input_shape)) + if self.data_format == 'channels_first': + channel_axis = 1 + else: + channel_axis = 3 + if input_shape[channel_axis] is None: + raise ValueError('The channel dimension of the inputs to ' + '`DepthwiseConv2D` ' + 'should be defined. Found `None`.') + input_dim = int(input_shape[channel_axis]) + depthwise_kernel_shape = (self.kernel_size[0], + self.kernel_size[1], + input_dim, + self.depth_multiplier) + + self.depthwise_kernel = self.add_weight( + shape=depthwise_kernel_shape, + initializer=self.depthwise_initializer, + name='depthwise_kernel', + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint) + + if self.use_bias: + self.bias = self.add_weight(shape=(input_dim * self.depth_multiplier,), + initializer=self.bias_initializer, + name='bias', + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + else: + self.bias = None + # Set input spec. + self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim}) + self.built = True + + def call(self, inputs, training=None): + outputs = K.depthwise_conv2d( + inputs, + self.depthwise_kernel, + strides=self.strides, + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format) + + if self.bias: + outputs = K.bias_add( + outputs, + self.bias, + data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == 'channels_first': + rows = input_shape[2] + cols = input_shape[3] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == 'channels_last': + rows = input_shape[1] + cols = input_shape[2] + out_filters = input_shape[3] * self.depth_multiplier + + rows = conv_utils.conv_output_length(rows, self.kernel_size[0], + self.padding, + self.strides[0]) + cols = conv_utils.conv_output_length(cols, self.kernel_size[1], + self.padding, + self.strides[1]) + if self.data_format == 'channels_first': + return (input_shape[0], out_filters, rows, cols) + elif self.data_format == 'channels_last': + return (input_shape[0], rows, cols, out_filters) + + def get_config(self): + config = super(DepthwiseConv2D, self).get_config() + config.pop('filters') + config.pop('kernel_initializer') + config.pop('kernel_regularizer') + config.pop('kernel_constraint') + config['depth_multiplier'] = self.depth_multiplier + config['depthwise_initializer'] = initializers.serialize( + self.depthwise_initializer) + config['depthwise_regularizer'] = regularizers.serialize( + self.depthwise_regularizer) + config['depthwise_constraint'] = constraints.serialize( + self.depthwise_constraint) + return config + + @tf_export('keras.layers.UpSampling1D') class UpSampling1D(Layer): """Upsampling layer for 1D inputs. diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py index f4a134b96c..12b4267675 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py @@ -961,5 +961,43 @@ class CroppingTest(test.TestCase): keras.layers.Cropping3D(cropping=None) +class DepthwiseConv2DTest(test.TestCase): + + def _run_test(self, kwargs, arg, values): + num_samples = 2 + stack_size = 3 + num_row = 7 + num_col = 6 + + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.DepthwiseConv2D, + kwargs=test_kwargs, + input_shape=(num_samples, num_row, num_col, stack_size)) + + def test_depthwise_conv2d(self): + kwargs = {'kernel_size': (3, 3)} + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2)]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) + self._run_test(kwargs, 'depth_multiplier', [1, 2]) + + kwargs = {'kernel_size': 3, + 'padding': 'valid', + 'data_format': 'channels_first', + 'activation': None, + 'depthwise_regularizer': 'l2', + 'bias_regularizer': 'l2', + 'activity_regularizer': 'l2', + 'depthwise_constraint': 'unit_norm', + 'strides': (2, 2), + } + self._run_test(kwargs, 'depth_multiplier', [1]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 7f9f77c296..f53db987ff 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -251,7 +251,7 @@ class RNN(Layer): It is also possible for `cell` to be a list of RNN cell instances, in which cases the cells get stacked on after the other in the RNN, implementing an efficient stacked RNN. - return_sequences: Boolean. Whether to return the last output. + return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state in addition to the output. @@ -797,10 +797,10 @@ class RNN(Layer): @property def losses(self): - losses = [] + layer_losses = super(RNN, self).losses if isinstance(self.cell, Layer): - losses += self.cell.losses - return losses + self._losses + return self.cell.losses + layer_losses + return layer_losses @property def updates(self): @@ -1017,7 +1017,7 @@ class SimpleRNN(RNN): recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. - return_sequences: Boolean. Whether to return the last output. + return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state in addition to the output. @@ -1237,6 +1237,9 @@ class GRUCell(Layer): batch them into fewer, larger operations. These modes will have different performance profiles on different hardware and for different applications. + reset_after: GRU convention (whether to apply reset gate after or + before matrix multiplication). False = "before" (default), + True = "after" (CuDNN compatible). """ def __init__(self, @@ -1256,6 +1259,7 @@ class GRUCell(Layer): dropout=0., recurrent_dropout=0., implementation=1, + reset_after=False, **kwargs): super(GRUCell, self).__init__(**kwargs) self.units = units @@ -1278,6 +1282,7 @@ class GRUCell(Layer): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.implementation = implementation + self.reset_after = reset_after self.state_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -1299,12 +1304,25 @@ class GRUCell(Layer): constraint=self.recurrent_constraint) if self.use_bias: - self.bias = self.add_weight( - shape=(self.units * 3,), - name='bias', - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) + if not self.reset_after: + bias_shape = (3 * self.units,) + else: + # separate biases for input and recurrent kernels + # Note: the shape is intentionally different from CuDNNGRU biases + # `(2 * 3 * self.units,)`, so that we can distinguish the classes + # when loading and converting saved weights. + bias_shape = (2, 3 * self.units) + self.bias = self.add_weight(shape=bias_shape, + name='bias', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + if not self.reset_after: + self.input_bias, self.recurrent_bias = self.bias, None + else: + self.input_bias = K.flatten(self.bias[0]) + self.recurrent_bias = K.flatten(self.bias[1]) + else: self.bias = None self.built = True @@ -1340,13 +1358,15 @@ class GRUCell(Layer): inputs_z = inputs inputs_r = inputs inputs_h = inputs + x_z = K.dot(inputs_z, self.kernel[:, :self.units]) x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:]) + if self.use_bias: - x_z = K.bias_add(x_z, self.bias[:self.units]) - x_r = K.bias_add(x_r, self.bias[self.units:self.units * 2]) - x_h = K.bias_add(x_h, self.bias[self.units * 2:]) + x_z = K.bias_add(x_z, self.input_bias[:self.units]) + x_r = K.bias_add(x_r, self.input_bias[self.units: self.units * 2]) + x_h = K.bias_add(x_h, self.input_bias[self.units * 2:]) if 0. < self.recurrent_dropout < 1.: h_tm1_z = h_tm1 * rec_dp_mask[0] @@ -1356,42 +1376,70 @@ class GRUCell(Layer): h_tm1_z = h_tm1 h_tm1_r = h_tm1 h_tm1_h = h_tm1 - z = self.recurrent_activation( - x_z + K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])) - r = self.recurrent_activation( - x_r + K.dot(h_tm1_r, self.recurrent_kernel[:, self.units: - self.units * 2])) - - hh = self.activation(x_h + K.dot(r * h_tm1_h, - self.recurrent_kernel[:, - self.units * 2:])) + + recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]) + recurrent_r = K.dot(h_tm1_r, + self.recurrent_kernel[:, self.units:self.units * 2]) + if self.reset_after and self.use_bias: + recurrent_z = K.bias_add(recurrent_z, self.recurrent_bias[:self.units]) + recurrent_r = K.bias_add(recurrent_r, + self.recurrent_bias[self.units: + self.units * 2]) + + z = self.recurrent_activation(x_z + recurrent_z) + r = self.recurrent_activation(x_r + recurrent_r) + + # reset gate applied after/before matrix multiplication + if self.reset_after: + recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) + if self.use_bias: + recurrent_h = K.bias_add(recurrent_h, + self.recurrent_bias[self.units * 2:]) + recurrent_h = r * recurrent_h + else: + recurrent_h = K.dot(r * h_tm1_h, + self.recurrent_kernel[:, self.units * 2:]) + + hh = self.activation(x_h + recurrent_h) else: if 0. < self.dropout < 1.: inputs *= dp_mask[0] + + # inputs projected by all gate matrices at once matrix_x = K.dot(inputs, self.kernel) if self.use_bias: - matrix_x = K.bias_add(matrix_x, self.bias) + # biases: bias_z_i, bias_r_i, bias_h_i + matrix_x = K.bias_add(matrix_x, self.input_bias) + + x_z = matrix_x[:, :self.units] + x_r = matrix_x[:, self.units: 2 * self.units] + x_h = matrix_x[:, 2 * self.units:] + if 0. < self.recurrent_dropout < 1.: h_tm1 *= rec_dp_mask[0] matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units]) - x_z = matrix_x[:, :self.units] - x_r = matrix_x[:, self.units:2 * self.units] recurrent_z = matrix_inner[:, :self.units] recurrent_r = matrix_inner[:, self.units:2 * self.units] z = self.recurrent_activation(x_z + recurrent_z) r = self.recurrent_activation(x_r + recurrent_r) - x_h = matrix_x[:, 2 * self.units:] - recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:]) + if self.reset_after: + recurrent_h = r * matrix_inner[:, 2 * self.units:] + else: + recurrent_h = K.dot(r * h_tm1, + self.recurrent_kernel[:, 2 * self.units:]) + hh = self.activation(x_h + recurrent_h) + # previous and candidate state mixed by update gate h = z * h_tm1 + (1 - z) * hh if 0 < self.dropout + self.recurrent_dropout: if training is None and not context.executing_eagerly(): # This would be harmless to set in eager mode, but eager tensors # disallow setting arbitrary attributes. h._uses_learning_phase = True + return h, [h] def get_config(self): @@ -1415,7 +1463,8 @@ class GRUCell(Layer): 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, 'recurrent_dropout': self.recurrent_dropout, - 'implementation': self.implementation + 'implementation': self.implementation, + 'reset_after': self.reset_after } base_config = super(GRUCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -1423,9 +1472,16 @@ class GRUCell(Layer): @tf_export('keras.layers.GRU') class GRU(RNN): - """Gated Recurrent Unit - Cho et al. + """Gated Recurrent Unit - Cho et al. 2014. - 2014. + There are two variants. The default one is based on 1406.1078v3 and + has reset gate applied to hidden state before matrix multiplication. The + other one is based on original 1406.1078v1 and has the order reversed. + + The second variant is compatible with CuDNNGRU (GPU-only) and allows + inference on CPU. Thus it has separate biases for `kernel` and + `recurrent_kernel`. Use `'reset_after'=True` and + `recurrent_activation='sigmoid'`. Arguments: units: Positive integer, dimensionality of the output space. @@ -1469,7 +1525,7 @@ class GRU(RNN): batch them into fewer, larger operations. These modes will have different performance profiles on different hardware and for different applications. - return_sequences: Boolean. Whether to return the last output. + return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state in addition to the output. @@ -1485,6 +1541,9 @@ class GRU(RNN): Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. + reset_after: GRU convention (whether to apply reset gate after or + before matrix multiplication). False = "before" (default), + True = "after" (CuDNN compatible). """ @@ -1511,6 +1570,7 @@ class GRU(RNN): go_backwards=False, stateful=False, unroll=False, + reset_after=False, **kwargs): if implementation == 0: logging.warning('`implementation=0` has been deprecated, ' @@ -1532,7 +1592,8 @@ class GRU(RNN): bias_constraint=bias_constraint, dropout=dropout, recurrent_dropout=recurrent_dropout, - implementation=implementation) + implementation=implementation, + reset_after=reset_after) super(GRU, self).__init__( cell, return_sequences=return_sequences, @@ -1613,6 +1674,10 @@ class GRU(RNN): def implementation(self): return self.cell.implementation + @property + def reset_after(self): + return self.cell.reset_after + def get_config(self): config = { 'units': @@ -1648,7 +1713,9 @@ class GRU(RNN): 'recurrent_dropout': self.recurrent_dropout, 'implementation': - self.implementation + self.implementation, + 'reset_after': + self.reset_after } base_config = super(GRU, self).get_config() del base_config['cell'] @@ -1929,7 +1996,7 @@ class LSTMCell(Layer): @tf_export('keras.layers.LSTM') class LSTM(RNN): - """Long-Short Term Memory layer - Hochreiter 1997. + """Long Short-Term Memory layer - Hochreiter 1997. Arguments: units: Positive integer, dimensionality of the output space. diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py index fb743b617f..641b563a25 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py @@ -232,6 +232,7 @@ class RNNTest(test.TestCase): cell = RNNCellWithConstants(32) layer = keras.layers.RNN(cell) y = layer(x, constants=c) + model = keras.models.Model([x, c], y) model.compile(optimizer='rmsprop', loss='mse') model.train_on_batch( @@ -280,6 +281,20 @@ class RNNTest(test.TestCase): ) with self.test_session(): + # Test GRUCell reset_after property. + x = keras.Input((None, 5)) + c = keras.Input((3,)) + cells = [keras.layers.recurrent.GRUCell(32, reset_after=True)] + layer = keras.layers.recurrent.RNN(cells) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + with self.test_session(): # Test stacked RNN serialization x_np = np.random.random((6, 5, 5)) c_np = np.random.random((6, 3)) @@ -541,6 +556,5 @@ class RNNTest(test.TestCase): [tuple(o.as_list()) for o in output_shape], expected_output_shape) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 84ee5040dc..b45cafed31 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -49,6 +49,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution1D from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras._impl.keras.layers.convolutional import DepthwiseConv2D # Image processing layers. from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt new file mode 100644 index 0000000000..b38716aa2c --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -0,0 +1,187 @@ +path: "tensorflow.keras.layers.DepthwiseConv2D" +tf_class { + is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.DepthwiseConv2D\'>" + is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv2D\'>" + is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>" + is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>" + is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>" + is_instance: "<class \'tensorflow.python.layers.base.Layer\'>" + is_instance: "<class \'tensorflow.python.training.checkpointable.CheckpointableBase\'>" + is_instance: "<type \'object\'>" + member { + name: "activity_regularizer" + mtype: "<type \'property\'>" + } + member { + name: "dtype" + mtype: "<type \'property\'>" + } + member { + name: "graph" + mtype: "<type \'property\'>" + } + member { + name: "inbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "input" + mtype: "<type \'property\'>" + } + member { + name: "input_mask" + mtype: "<type \'property\'>" + } + member { + name: "input_shape" + mtype: "<type \'property\'>" + } + member { + name: "losses" + mtype: "<type \'property\'>" + } + member { + name: "name" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "outbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "output" + mtype: "<type \'property\'>" + } + member { + name: "output_mask" + mtype: "<type \'property\'>" + } + member { + name: "output_shape" + mtype: "<type \'property\'>" + } + member { + name: "scope_name" + mtype: "<type \'property\'>" + } + member { + name: "trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "updates" + mtype: "<type \'property\'>" + } + member { + name: "variables" + mtype: "<type \'property\'>" + } + member { + name: "weights" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'kernel_size\', \'strides\', \'padding\', \'depth_multiplier\', \'data_format\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'1\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt index 1fd3febad2..4274b8d425 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -91,7 +91,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt index f5f41d879d..8d9f06083c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt @@ -123,6 +123,10 @@ tf_class { mtype: "<type \'property\'>" } member { + name: "reset_after" + mtype: "<type \'property\'>" + } + member { name: "scope_name" mtype: "<type \'property\'>" } @@ -160,7 +164,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt index 088c8e88e2..affc9bd09b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt @@ -117,6 +117,10 @@ tf_module { mtype: "<type \'type\'>" } member { + name: "DepthwiseConv2D" + mtype: "<type \'type\'>" + } + member { name: "Dot" mtype: "<type \'type\'>" } |