aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/ops/lstm_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/lstm_ops.py')
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 2e6f2ac05c..c1ec46d763 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -362,7 +362,7 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
@property
def state_size(self):
- return (self._num_units,) * 2
+ return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
@property
def output_size(self):
@@ -401,7 +401,8 @@ class LSTMBlockCell(core_rnn_cell.RNNCell):
forget_bias=self._forget_bias,
use_peephole=self._use_peephole)
- return (h, (cs, h))
+ new_state = core_rnn_cell.LSTMStateTuple(cs, h)
+ return h, new_state
class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
@@ -544,7 +545,9 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
# Input was a list, so return a list
outputs = array_ops.unstack(outputs)
- return outputs, (final_cell_state, final_output)
+ final_state = core_rnn_cell.LSTMStateTuple(final_cell_state,
+ final_output)
+ return outputs, final_state
def _gather_states(self, data, indices, batch_size):
"""Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""