diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/lstm_ops.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 2e6f2ac05c..c1ec46d763 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -362,7 +362,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell): @property def state_size(self): - return (self._num_units,) * 2 + return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units) @property def output_size(self): @@ -401,7 +401,8 @@ class LSTMBlockCell(core_rnn_cell.RNNCell): forget_bias=self._forget_bias, use_peephole=self._use_peephole) - return (h, (cs, h)) + new_state = core_rnn_cell.LSTMStateTuple(cs, h) + return h, new_state class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell): @@ -544,7 +545,9 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell): # Input was a list, so return a list outputs = array_ops.unstack(outputs) - return outputs, (final_cell_state, final_output) + final_state = core_rnn_cell.LSTMStateTuple(final_cell_state, + final_output) + return outputs, final_state def _gather_states(self, data, indices, batch_size): """Produce `out`, s.t. out(i, j) = data(indices(i), i, j).""" |