aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-04-06 11:24:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 11:31:24 -0700
commit98b8b786036172d33c85b6b5f81347440d0594df (patch)
tree0bcfc56b315709294e7cc21a497a27dbadc481f9
parentbeda9ebd36bbf6964459c7ee2209975d62cb01e6 (diff)
Update tf.keras to keras 2.1.5 version.
PiperOrigin-RevId: 191914904
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/mobilenet.py222
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/resnet50.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional.py195
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_test.py38
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent.py137
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent_test.py16
-rw-r--r--tensorflow/python/keras/layers/__init__.py1
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt187
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt4
11 files changed, 562 insertions, 251 deletions
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index ad96b53a45..12775fccec 100644
--- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -84,11 +84,13 @@ from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import DepthwiseConv2D
from tensorflow.python.keras._impl.keras.layers import Dropout
from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import Reshape
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
@@ -116,195 +118,6 @@ def preprocess_input(x):
return imagenet_utils.preprocess_input(x, mode='tf')
-class DepthwiseConv2D(Conv2D):
- """Depthwise separable 2D convolution.
-
- Depthwise Separable convolutions consists in performing
- just the first step in a depthwise spatial convolution
- (which acts on each input channel separately).
- The `depth_multiplier` argument controls how many
- output channels are generated per input channel in the depthwise step.
-
- Arguments:
- kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: one of `'valid'` or `'same'` (case-insensitive).
- depth_multiplier: The number of depthwise convolution output channels
- for each input channel.
- The total number of depthwise convolution output
- channels will be equal to `filters_in * depth_multiplier`.
- data_format: A string,
- one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, height, width, channels)` while `channels_first`
- corresponds to inputs with shape
- `(batch, channels, height, width)`.
- It defaults to the `image_data_format` value found in your
- Keras config file at `~/.keras/keras.json`.
- If you never set it, then it will be 'channels_last'.
- activation: Activation function to use.
- If you don't specify anything, no activation is applied
- (ie. 'linear' activation: `a(x) = x`).
- use_bias: Boolean, whether the layer uses a bias vector.
- depthwise_initializer: Initializer for the depthwise kernel matrix.
- bias_initializer: Initializer for the bias vector.
- depthwise_regularizer: Regularizer function applied to
- the depthwise kernel matrix.
- bias_regularizer: Regularizer function applied to the bias vector.
- activity_regularizer: Regularizer function applied to
- the output of the layer (its 'activation')..
- depthwise_constraint: Constraint function applied to
- the depthwise kernel matrix.
- bias_constraint: Constraint function applied to the bias vector.
-
- Input shape:
- 4D tensor with shape:
- `[batch, channels, rows, cols]` if data_format='channels_first'
- or 4D tensor with shape:
- `[batch, rows, cols, channels]` if data_format='channels_last'.
-
- Output shape:
- 4D tensor with shape:
- `[batch, filters, new_rows, new_cols]` if data_format='channels_first'
- or 4D tensor with shape:
- `[batch, new_rows, new_cols, filters]` if data_format='channels_last'.
- `rows` and `cols` values might have changed due to padding.
- """
-
- def __init__(self,
- kernel_size,
- strides=(1, 1),
- padding='valid',
- depth_multiplier=1,
- data_format=None,
- activation=None,
- use_bias=True,
- depthwise_initializer='glorot_uniform',
- bias_initializer='zeros',
- depthwise_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- depthwise_constraint=None,
- bias_constraint=None,
- **kwargs):
- super(DepthwiseConv2D, self).__init__(
- filters=None,
- kernel_size=kernel_size,
- strides=strides,
- padding=padding,
- data_format=data_format,
- activation=activation,
- use_bias=use_bias,
- bias_regularizer=bias_regularizer,
- activity_regularizer=activity_regularizer,
- bias_constraint=bias_constraint,
- **kwargs)
- self.depth_multiplier = depth_multiplier
- self.depthwise_initializer = initializers.get(depthwise_initializer)
- self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
- self.depthwise_constraint = constraints.get(depthwise_constraint)
- self.bias_initializer = initializers.get(bias_initializer)
-
- @shape_type_conversion
- def build(self, input_shape):
- if len(input_shape) < 4:
- raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
- 'Received input shape:', str(input_shape))
- if self.data_format == 'channels_first':
- channel_axis = 1
- else:
- channel_axis = 3
- if input_shape[channel_axis] is None:
- raise ValueError('The channel dimension of the inputs to '
- '`DepthwiseConv2D` '
- 'should be defined. Found `None`.')
- input_dim = int(input_shape[channel_axis])
- depthwise_kernel_shape = (self.kernel_size[0], self.kernel_size[1],
- input_dim, self.depth_multiplier)
-
- self.depthwise_kernel = self.add_weight(
- shape=depthwise_kernel_shape,
- initializer=self.depthwise_initializer,
- name='depthwise_kernel',
- regularizer=self.depthwise_regularizer,
- constraint=self.depthwise_constraint)
-
- if self.use_bias:
- self.bias = self.add_weight(
- shape=(input_dim * self.depth_multiplier,),
- initializer=self.bias_initializer,
- name='bias',
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint)
- else:
- self.bias = None
- # Set input spec.
- self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
- self.built = True
-
- def call(self, inputs, training=None):
- outputs = K.depthwise_conv2d(
- inputs,
- self.depthwise_kernel,
- strides=self.strides,
- padding=self.padding,
- dilation_rate=self.dilation_rate,
- data_format=self.data_format)
-
- if self.bias:
- outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
-
- if self.activation is not None:
- return self.activation(outputs)
-
- return outputs
-
- @shape_type_conversion
- def compute_output_shape(self, input_shape):
- if self.data_format == 'channels_first':
- rows = input_shape[2]
- cols = input_shape[3]
- out_filters = input_shape[1] * self.depth_multiplier
- elif self.data_format == 'channels_last':
- rows = input_shape[1]
- cols = input_shape[2]
- out_filters = input_shape[3] * self.depth_multiplier
-
- rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
- self.padding, self.strides[0])
- cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
- self.padding, self.strides[1])
-
- if self.data_format == 'channels_first':
- return (input_shape[0], out_filters, rows, cols)
- elif self.data_format == 'channels_last':
- return (input_shape[0], rows, cols, out_filters)
-
- def get_config(self):
- config = super(DepthwiseConv2D, self).get_config()
- config.pop('filters')
- config.pop('kernel_initializer')
- config.pop('kernel_regularizer')
- config.pop('kernel_constraint')
- config['depth_multiplier'] = self.depth_multiplier
- config['depthwise_initializer'] = initializers.serialize(
- self.depthwise_initializer)
- config['depthwise_regularizer'] = regularizers.serialize(
- self.depthwise_regularizer)
- config['depthwise_constraint'] = constraints.serialize(
- self.depthwise_constraint)
- return config
-
-
@tf_export('keras.applications.MobileNet',
'keras.applications.mobilenet.MobileNet')
def MobileNet(input_shape=None,
@@ -318,18 +131,11 @@ def MobileNet(input_shape=None,
classes=1000):
"""Instantiates the MobileNet architecture.
- Note that only TensorFlow is supported for now,
- therefore it only works with the data format
- `image_data_format='channels_last'` in your Keras config
- at `~/.keras/keras.json`.
-
To load a MobileNet model via `load_model`, import the custom
- objects `relu6` and `DepthwiseConv2D` and pass them to the
- `custom_objects` parameter.
+ objects `relu6` and pass them to the `custom_objects` parameter.
E.g.
model = load_model('mobilenet.h5', custom_objects={
- 'relu6': mobilenet.relu6,
- 'DepthwiseConv2D': mobilenet.DepthwiseConv2D})
+ 'relu6': mobilenet.relu6})
Arguments:
input_shape: optional shape tuple, only to be specified
@@ -383,11 +189,6 @@ def MobileNet(input_shape=None,
backend that does not support separable convolutions.
"""
- if K.backend() != 'tensorflow':
- raise RuntimeError('Only TensorFlow backend is currently supported, '
- 'as other backends do not support '
- 'depthwise convolution.')
-
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
'`None` (random initialization), `imagenet` '
@@ -522,7 +323,7 @@ def MobileNet(input_shape=None,
# load weights
if weights == 'imagenet':
if K.image_data_format() == 'channels_first':
- raise ValueError('Weights for "channels_last" format '
+ raise ValueError('Weights for "channels_first" format '
'are not available.')
if alpha == 1.0:
alpha_text = '1_0'
@@ -598,14 +399,14 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
"""
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
filters = int(filters * alpha)
+ x = ZeroPadding2D(padding=(1, 1), name='conv1_pad')(inputs)
x = Conv2D(
filters,
kernel,
- padding='same',
+ padding='valid',
use_bias=False,
strides=strides,
- name='conv1')(
- inputs)
+ name='conv1')(x)
x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
return Activation(relu6, name='conv1_relu')(x)
@@ -665,15 +466,14 @@ def _depthwise_conv_block(inputs,
"""
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
pointwise_conv_filters = int(pointwise_conv_filters * alpha)
-
+ x = ZeroPadding2D(padding=(1, 1), name='conv_pad_%d' % block_id)(inputs)
x = DepthwiseConv2D( # pylint: disable=not-callable
(3, 3),
- padding='same',
+ padding='valid',
depth_multiplier=depth_multiplier,
strides=strides,
use_bias=False,
- name='conv_dw_%d' % block_id)(
- inputs)
+ name='conv_dw_%d' % block_id)(x)
x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
index 46c0e63557..f8c6aff4f2 100644
--- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
@@ -45,6 +45,7 @@ from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
@@ -236,9 +237,9 @@ def ResNet50(include_top=True,
else:
bn_axis = 1
+ x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = Conv2D(
- 64, (7, 7), strides=(2, 2), padding='same', name='conv1')(
- img_input)
+ 64, (7, 7), strides=(2, 2), padding='valid', name='conv1')(x)
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2))(x)
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
index 162ae6c28f..7cdebc6aa4 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -27,6 +27,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
# imports for backwards namespace compatibility
# pylint: disable=unused-import
from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D
@@ -1024,6 +1025,200 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer):
return dict(list(base_config.items()) + list(config.items()))
+@tf_export('keras.layers.DepthwiseConv2D')
+class DepthwiseConv2D(Conv2D):
+ """Depthwise separable 2D convolution.
+
+ Depthwise Separable convolutions consists in performing
+ just the first step in a depthwise spatial convolution
+ (which acts on each input channel separately).
+ The `depth_multiplier` argument controls how many
+ output channels are generated per input channel in the depthwise step.
+
+ Arguments:
+ kernel_size: An integer or tuple/list of 2 integers, specifying the
+ width and height of the 2D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers,
+ specifying the strides of the convolution along the width and height.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: one of `'valid'` or `'same'` (case-insensitive).
+ depth_multiplier: The number of depthwise convolution output channels
+ for each input channel.
+ The total number of depthwise convolution output
+ channels will be equal to `filters_in * depth_multiplier`.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, channels, height, width)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be 'channels_last'.
+ activation: Activation function to use.
+ If you don't specify anything, no activation is applied
+ (ie. 'linear' activation: `a(x) = x`).
+ use_bias: Boolean, whether the layer uses a bias vector.
+ depthwise_initializer: Initializer for the depthwise kernel matrix.
+ bias_initializer: Initializer for the bias vector.
+ depthwise_regularizer: Regularizer function applied to
+ the depthwise kernel matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ activity_regularizer: Regularizer function applied to
+ the output of the layer (its 'activation').
+ depthwise_constraint: Constraint function applied to
+ the depthwise kernel matrix.
+ bias_constraint: Constraint function applied to the bias vector.
+
+ Input shape:
+ 4D tensor with shape:
+ `[batch, channels, rows, cols]` if data_format='channels_first'
+ or 4D tensor with shape:
+ `[batch, rows, cols, channels]` if data_format='channels_last'.
+
+ Output shape:
+ 4D tensor with shape:
+ `[batch, filters, new_rows, new_cols]` if data_format='channels_first'
+ or 4D tensor with shape:
+ `[batch, new_rows, new_cols, filters]` if data_format='channels_last'.
+ `rows` and `cols` values might have changed due to padding.
+ """
+
+ def __init__(self,
+ kernel_size,
+ strides=(1, 1),
+ padding='valid',
+ depth_multiplier=1,
+ data_format=None,
+ activation=None,
+ use_bias=True,
+ depthwise_initializer='glorot_uniform',
+ bias_initializer='zeros',
+ depthwise_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ depthwise_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(DepthwiseConv2D, self).__init__(
+ filters=None,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation=activation,
+ use_bias=use_bias,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ bias_constraint=bias_constraint,
+ **kwargs)
+ self.depth_multiplier = depth_multiplier
+ self.depthwise_initializer = initializers.get(depthwise_initializer)
+ self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
+ self.depthwise_constraint = constraints.get(depthwise_constraint)
+ self.bias_initializer = initializers.get(bias_initializer)
+
+ def build(self, input_shape):
+ if len(input_shape) < 4:
+ raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
+ 'Received input shape:', str(input_shape))
+ if self.data_format == 'channels_first':
+ channel_axis = 1
+ else:
+ channel_axis = 3
+ if input_shape[channel_axis] is None:
+ raise ValueError('The channel dimension of the inputs to '
+ '`DepthwiseConv2D` '
+ 'should be defined. Found `None`.')
+ input_dim = int(input_shape[channel_axis])
+ depthwise_kernel_shape = (self.kernel_size[0],
+ self.kernel_size[1],
+ input_dim,
+ self.depth_multiplier)
+
+ self.depthwise_kernel = self.add_weight(
+ shape=depthwise_kernel_shape,
+ initializer=self.depthwise_initializer,
+ name='depthwise_kernel',
+ regularizer=self.depthwise_regularizer,
+ constraint=self.depthwise_constraint)
+
+ if self.use_bias:
+ self.bias = self.add_weight(shape=(input_dim * self.depth_multiplier,),
+ initializer=self.bias_initializer,
+ name='bias',
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint)
+ else:
+ self.bias = None
+ # Set input spec.
+ self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
+ self.built = True
+
+ def call(self, inputs, training=None):
+ outputs = K.depthwise_conv2d(
+ inputs,
+ self.depthwise_kernel,
+ strides=self.strides,
+ padding=self.padding,
+ dilation_rate=self.dilation_rate,
+ data_format=self.data_format)
+
+ if self.bias:
+ outputs = K.bias_add(
+ outputs,
+ self.bias,
+ data_format=self.data_format)
+
+ if self.activation is not None:
+ return self.activation(outputs)
+
+ return outputs
+
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ if self.data_format == 'channels_first':
+ rows = input_shape[2]
+ cols = input_shape[3]
+ out_filters = input_shape[1] * self.depth_multiplier
+ elif self.data_format == 'channels_last':
+ rows = input_shape[1]
+ cols = input_shape[2]
+ out_filters = input_shape[3] * self.depth_multiplier
+
+ rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
+ self.padding,
+ self.strides[0])
+ cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
+ self.padding,
+ self.strides[1])
+ if self.data_format == 'channels_first':
+ return (input_shape[0], out_filters, rows, cols)
+ elif self.data_format == 'channels_last':
+ return (input_shape[0], rows, cols, out_filters)
+
+ def get_config(self):
+ config = super(DepthwiseConv2D, self).get_config()
+ config.pop('filters')
+ config.pop('kernel_initializer')
+ config.pop('kernel_regularizer')
+ config.pop('kernel_constraint')
+ config['depth_multiplier'] = self.depth_multiplier
+ config['depthwise_initializer'] = initializers.serialize(
+ self.depthwise_initializer)
+ config['depthwise_regularizer'] = regularizers.serialize(
+ self.depthwise_regularizer)
+ config['depthwise_constraint'] = constraints.serialize(
+ self.depthwise_constraint)
+ return config
+
+
@tf_export('keras.layers.UpSampling1D')
class UpSampling1D(Layer):
"""Upsampling layer for 1D inputs.
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
index f4a134b96c..12b4267675 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
@@ -961,5 +961,43 @@ class CroppingTest(test.TestCase):
keras.layers.Cropping3D(cropping=None)
+class DepthwiseConv2DTest(test.TestCase):
+
+ def _run_test(self, kwargs, arg, values):
+ num_samples = 2
+ stack_size = 3
+ num_row = 7
+ num_col = 6
+
+ test_kwargs = copy.copy(kwargs)
+ for value in values:
+ test_kwargs[arg] = value
+ with self.test_session(use_gpu=True):
+ testing_utils.layer_test(
+ keras.layers.DepthwiseConv2D,
+ kwargs=test_kwargs,
+ input_shape=(num_samples, num_row, num_col, stack_size))
+
+ def test_depthwise_conv2d(self):
+ kwargs = {'kernel_size': (3, 3)}
+
+ self._run_test(kwargs, 'padding', ['valid', 'same'])
+ self._run_test(kwargs, 'strides', [(2, 2)])
+ if test.is_gpu_available(cuda_only=True):
+ self._run_test(kwargs, 'data_format', ['channels_first'])
+ self._run_test(kwargs, 'depth_multiplier', [1, 2])
+
+ kwargs = {'kernel_size': 3,
+ 'padding': 'valid',
+ 'data_format': 'channels_first',
+ 'activation': None,
+ 'depthwise_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'depthwise_constraint': 'unit_norm',
+ 'strides': (2, 2),
+ }
+ self._run_test(kwargs, 'depth_multiplier', [1])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 7f9f77c296..f53db987ff 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -251,7 +251,7 @@ class RNN(Layer):
It is also possible for `cell` to be a list of RNN cell instances,
in which cases the cells get stacked on after the other in the RNN,
implementing an efficient stacked RNN.
- return_sequences: Boolean. Whether to return the last output.
+ return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -797,10 +797,10 @@ class RNN(Layer):
@property
def losses(self):
- losses = []
+ layer_losses = super(RNN, self).losses
if isinstance(self.cell, Layer):
- losses += self.cell.losses
- return losses + self._losses
+ return self.cell.losses + layer_losses
+ return layer_losses
@property
def updates(self):
@@ -1017,7 +1017,7 @@ class SimpleRNN(RNN):
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
- return_sequences: Boolean. Whether to return the last output.
+ return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -1237,6 +1237,9 @@ class GRUCell(Layer):
batch them into fewer, larger operations. These modes will
have different performance profiles on different hardware and
for different applications.
+ reset_after: GRU convention (whether to apply reset gate after or
+ before matrix multiplication). False = "before" (default),
+ True = "after" (CuDNN compatible).
"""
def __init__(self,
@@ -1256,6 +1259,7 @@ class GRUCell(Layer):
dropout=0.,
recurrent_dropout=0.,
implementation=1,
+ reset_after=False,
**kwargs):
super(GRUCell, self).__init__(**kwargs)
self.units = units
@@ -1278,6 +1282,7 @@ class GRUCell(Layer):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.implementation = implementation
+ self.reset_after = reset_after
self.state_size = self.units
self._dropout_mask = None
self._recurrent_dropout_mask = None
@@ -1299,12 +1304,25 @@ class GRUCell(Layer):
constraint=self.recurrent_constraint)
if self.use_bias:
- self.bias = self.add_weight(
- shape=(self.units * 3,),
- name='bias',
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint)
+ if not self.reset_after:
+ bias_shape = (3 * self.units,)
+ else:
+ # separate biases for input and recurrent kernels
+ # Note: the shape is intentionally different from CuDNNGRU biases
+ # `(2 * 3 * self.units,)`, so that we can distinguish the classes
+ # when loading and converting saved weights.
+ bias_shape = (2, 3 * self.units)
+ self.bias = self.add_weight(shape=bias_shape,
+ name='bias',
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint)
+ if not self.reset_after:
+ self.input_bias, self.recurrent_bias = self.bias, None
+ else:
+ self.input_bias = K.flatten(self.bias[0])
+ self.recurrent_bias = K.flatten(self.bias[1])
+
else:
self.bias = None
self.built = True
@@ -1340,13 +1358,15 @@ class GRUCell(Layer):
inputs_z = inputs
inputs_r = inputs
inputs_h = inputs
+
x_z = K.dot(inputs_z, self.kernel[:, :self.units])
x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])
+
if self.use_bias:
- x_z = K.bias_add(x_z, self.bias[:self.units])
- x_r = K.bias_add(x_r, self.bias[self.units:self.units * 2])
- x_h = K.bias_add(x_h, self.bias[self.units * 2:])
+ x_z = K.bias_add(x_z, self.input_bias[:self.units])
+ x_r = K.bias_add(x_r, self.input_bias[self.units: self.units * 2])
+ x_h = K.bias_add(x_h, self.input_bias[self.units * 2:])
if 0. < self.recurrent_dropout < 1.:
h_tm1_z = h_tm1 * rec_dp_mask[0]
@@ -1356,42 +1376,70 @@ class GRUCell(Layer):
h_tm1_z = h_tm1
h_tm1_r = h_tm1
h_tm1_h = h_tm1
- z = self.recurrent_activation(
- x_z + K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]))
- r = self.recurrent_activation(
- x_r + K.dot(h_tm1_r, self.recurrent_kernel[:, self.units:
- self.units * 2]))
-
- hh = self.activation(x_h + K.dot(r * h_tm1_h,
- self.recurrent_kernel[:,
- self.units * 2:]))
+
+ recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
+ recurrent_r = K.dot(h_tm1_r,
+ self.recurrent_kernel[:, self.units:self.units * 2])
+ if self.reset_after and self.use_bias:
+ recurrent_z = K.bias_add(recurrent_z, self.recurrent_bias[:self.units])
+ recurrent_r = K.bias_add(recurrent_r,
+ self.recurrent_bias[self.units:
+ self.units * 2])
+
+ z = self.recurrent_activation(x_z + recurrent_z)
+ r = self.recurrent_activation(x_r + recurrent_r)
+
+ # reset gate applied after/before matrix multiplication
+ if self.reset_after:
+ recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
+ if self.use_bias:
+ recurrent_h = K.bias_add(recurrent_h,
+ self.recurrent_bias[self.units * 2:])
+ recurrent_h = r * recurrent_h
+ else:
+ recurrent_h = K.dot(r * h_tm1_h,
+ self.recurrent_kernel[:, self.units * 2:])
+
+ hh = self.activation(x_h + recurrent_h)
else:
if 0. < self.dropout < 1.:
inputs *= dp_mask[0]
+
+ # inputs projected by all gate matrices at once
matrix_x = K.dot(inputs, self.kernel)
if self.use_bias:
- matrix_x = K.bias_add(matrix_x, self.bias)
+ # biases: bias_z_i, bias_r_i, bias_h_i
+ matrix_x = K.bias_add(matrix_x, self.input_bias)
+
+ x_z = matrix_x[:, :self.units]
+ x_r = matrix_x[:, self.units: 2 * self.units]
+ x_h = matrix_x[:, 2 * self.units:]
+
if 0. < self.recurrent_dropout < 1.:
h_tm1 *= rec_dp_mask[0]
matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
- x_z = matrix_x[:, :self.units]
- x_r = matrix_x[:, self.units:2 * self.units]
recurrent_z = matrix_inner[:, :self.units]
recurrent_r = matrix_inner[:, self.units:2 * self.units]
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
- x_h = matrix_x[:, 2 * self.units:]
- recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
+ if self.reset_after:
+ recurrent_h = r * matrix_inner[:, 2 * self.units:]
+ else:
+ recurrent_h = K.dot(r * h_tm1,
+ self.recurrent_kernel[:, 2 * self.units:])
+
hh = self.activation(x_h + recurrent_h)
+ # previous and candidate state mixed by update gate
h = z * h_tm1 + (1 - z) * hh
if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes.
h._uses_learning_phase = True
+
return h, [h]
def get_config(self):
@@ -1415,7 +1463,8 @@ class GRUCell(Layer):
'bias_constraint': constraints.serialize(self.bias_constraint),
'dropout': self.dropout,
'recurrent_dropout': self.recurrent_dropout,
- 'implementation': self.implementation
+ 'implementation': self.implementation,
+ 'reset_after': self.reset_after
}
base_config = super(GRUCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -1423,9 +1472,16 @@ class GRUCell(Layer):
@tf_export('keras.layers.GRU')
class GRU(RNN):
- """Gated Recurrent Unit - Cho et al.
+ """Gated Recurrent Unit - Cho et al. 2014.
- 2014.
+ There are two variants. The default one is based on 1406.1078v3 and
+ has reset gate applied to hidden state before matrix multiplication. The
+ other one is based on original 1406.1078v1 and has the order reversed.
+
+ The second variant is compatible with CuDNNGRU (GPU-only) and allows
+ inference on CPU. Thus it has separate biases for `kernel` and
+ `recurrent_kernel`. Use `'reset_after'=True` and
+ `recurrent_activation='sigmoid'`.
Arguments:
units: Positive integer, dimensionality of the output space.
@@ -1469,7 +1525,7 @@ class GRU(RNN):
batch them into fewer, larger operations. These modes will
have different performance profiles on different hardware and
for different applications.
- return_sequences: Boolean. Whether to return the last output.
+ return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -1485,6 +1541,9 @@ class GRU(RNN):
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
+ reset_after: GRU convention (whether to apply reset gate after or
+ before matrix multiplication). False = "before" (default),
+ True = "after" (CuDNN compatible).
"""
@@ -1511,6 +1570,7 @@ class GRU(RNN):
go_backwards=False,
stateful=False,
unroll=False,
+ reset_after=False,
**kwargs):
if implementation == 0:
logging.warning('`implementation=0` has been deprecated, '
@@ -1532,7 +1592,8 @@ class GRU(RNN):
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
- implementation=implementation)
+ implementation=implementation,
+ reset_after=reset_after)
super(GRU, self).__init__(
cell,
return_sequences=return_sequences,
@@ -1613,6 +1674,10 @@ class GRU(RNN):
def implementation(self):
return self.cell.implementation
+ @property
+ def reset_after(self):
+ return self.cell.reset_after
+
def get_config(self):
config = {
'units':
@@ -1648,7 +1713,9 @@ class GRU(RNN):
'recurrent_dropout':
self.recurrent_dropout,
'implementation':
- self.implementation
+ self.implementation,
+ 'reset_after':
+ self.reset_after
}
base_config = super(GRU, self).get_config()
del base_config['cell']
@@ -1929,7 +1996,7 @@ class LSTMCell(Layer):
@tf_export('keras.layers.LSTM')
class LSTM(RNN):
- """Long-Short Term Memory layer - Hochreiter 1997.
+ """Long Short-Term Memory layer - Hochreiter 1997.
Arguments:
units: Positive integer, dimensionality of the output space.
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
index fb743b617f..641b563a25 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
@@ -232,6 +232,7 @@ class RNNTest(test.TestCase):
cell = RNNCellWithConstants(32)
layer = keras.layers.RNN(cell)
y = layer(x, constants=c)
+
model = keras.models.Model([x, c], y)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(
@@ -280,6 +281,20 @@ class RNNTest(test.TestCase):
)
with self.test_session():
+ # Test GRUCell reset_after property.
+ x = keras.Input((None, 5))
+ c = keras.Input((3,))
+ cells = [keras.layers.recurrent.GRUCell(32, reset_after=True)]
+ layer = keras.layers.recurrent.RNN(cells)
+ y = layer(x, constants=c)
+ model = keras.models.Model([x, c], y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ [np.zeros((6, 5, 5)), np.zeros((6, 3))],
+ np.zeros((6, 32))
+ )
+
+ with self.test_session():
# Test stacked RNN serialization
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
@@ -541,6 +556,5 @@ class RNNTest(test.TestCase):
[tuple(o.as_list()) for o in output_shape],
expected_output_shape)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index 84ee5040dc..b45cafed31 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -49,6 +49,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution1D
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import DepthwiseConv2D
# Image processing layers.
from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
new file mode 100644
index 0000000000..b38716aa2c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -0,0 +1,187 @@
+path: "tensorflow.keras.layers.DepthwiseConv2D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.DepthwiseConv2D\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv2D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.CheckpointableBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'kernel_size\', \'strides\', \'padding\', \'depth_multiplier\', \'data_format\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'1\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 1fd3febad2..4274b8d425 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -91,7 +91,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index f5f41d879d..8d9f06083c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -123,6 +123,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "reset_after"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "scope_name"
mtype: "<type \'property\'>"
}
@@ -160,7 +164,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
index 088c8e88e2..affc9bd09b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -117,6 +117,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "DepthwiseConv2D"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Dot"
mtype: "<type \'type\'>"
}