diff options
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index f481726d54..85a6a2233c 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer): for each `s` in `self.batch_size`. """ + def __init__(self, trainable=True, name=None, dtype=None, **kwargs): + super(RNNCell, self).__init__( + trainable=trainable, name=name, dtype=dtype, **kwargs) + # Attribute that indicates whether the cell is a TF RNN cell, due the slight + # difference between TF and Keras RNN cell. + self._is_tf_rnn_cell = True + def __call__(self, inputs, state, scope=None): """Run this RNN cell on inputs, starting from the given state. @@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell): def get_config(self): config = { "num_units": self._num_units, - "initializer": initializers.serialize(self._initializer), "kernel_initializer": initializers.serialize(self._kernel_initializer), + "bias_initializer": initializers.serialize(self._bias_initializer), "activation": activations.serialize(self._activation), "reuse": self._reuse, } |