aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/ops/lstm_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/lstm_ops.py')
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index f700717394..4eb4fbcd92 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer):
def _gather_states(self, data, indices, batch_size):
"""Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
- mod_indices = indices * batch_size + math_ops.range(batch_size)
- return array_ops.gather(
- array_ops.reshape(data, [-1, self.num_units]), mod_indices)
+ return array_ops.gather_nd(
+ data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))
class LSTMBlockFusedCell(LSTMBlockWrapper):