aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn_cell_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index f481726d54..85a6a2233c 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -193,6 +193,13 @@ class RNNCell(base_layer.Layer):
for each `s` in `self.batch_size`.
"""
+ def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
+ super(RNNCell, self).__init__(
+ trainable=trainable, name=name, dtype=dtype, **kwargs)
+ # Attribute that indicates whether the cell is a TF RNN cell, due the slight
+ # difference between TF and Keras RNN cell.
+ self._is_tf_rnn_cell = True
+
def __call__(self, inputs, state, scope=None):
"""Run this RNN cell on inputs, starting from the given state.
@@ -524,8 +531,8 @@ class GRUCell(LayerRNNCell):
def get_config(self):
config = {
"num_units": self._num_units,
- "initializer": initializers.serialize(self._initializer),
"kernel_initializer": initializers.serialize(self._kernel_initializer),
+ "bias_initializer": initializers.serialize(self._bias_initializer),
"activation": activations.serialize(self._activation),
"reuse": self._reuse,
}