diff options
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r-- | tensorflow/python/keras/layers/recurrent.py | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 66c68e2085..12c82a53f6 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -670,6 +670,8 @@ class RNN(Layer): if generic_utils.has_arg(self.cell.call, 'training'): kwargs['training'] = training + # TF RNN cells expect single tensor as state instead of list wrapped tensor. + is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None if constants: if not generic_utils.has_arg(self.cell.call, 'constants'): raise ValueError('RNN cell does not support constants') @@ -677,11 +679,21 @@ class RNN(Layer): def step(inputs, states): constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type - return self.cell.call(inputs, states, constants=constants, **kwargs) + + states = states[0] if len(states) == 1 and is_tf_rnn_cell else states + output, new_states = self.cell.call( + inputs, states, constants=constants, **kwargs) + if not nest.is_sequence(new_states): + new_states = [new_states] + return output, new_states else: def step(inputs, states): - return self.cell.call(inputs, states, **kwargs) + states = states[0] if len(states) == 1 and is_tf_rnn_cell else states + output, new_states = self.cell.call(inputs, states, **kwargs) + if not nest.is_sequence(new_states): + new_states = [new_states] + return output, new_states last_output, outputs, states = K.rnn( step, |