aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/recurrent.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r--tensorflow/python/keras/layers/recurrent.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 66c68e2085..12c82a53f6 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -670,6 +670,8 @@ class RNN(Layer):
if generic_utils.has_arg(self.cell.call, 'training'):
kwargs['training'] = training
+ # TF RNN cells expect single tensor as state instead of list wrapped tensor.
+ is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
if constants:
if not generic_utils.has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants')
@@ -677,11 +679,21 @@ class RNN(Layer):
def step(inputs, states):
constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
- return self.cell.call(inputs, states, constants=constants, **kwargs)
+
+ states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+ output, new_states = self.cell.call(
+ inputs, states, constants=constants, **kwargs)
+ if not nest.is_sequence(new_states):
+ new_states = [new_states]
+ return output, new_states
else:
def step(inputs, states):
- return self.cell.call(inputs, states, **kwargs)
+ states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
+ output, new_states = self.cell.call(inputs, states, **kwargs)
+ if not nest.is_sequence(new_states):
+ new_states = [new_states]
+ return output, new_states
last_output, outputs, states = K.rnn(
step,