diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/lstm_ops.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index f700717394..4eb4fbcd92 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer): def _gather_states(self, data, indices, batch_size): """Produce `out`, s.t. out(i, j) = data(indices(i), i, j).""" - mod_indices = indices * batch_size + math_ops.range(batch_size) - return array_ops.gather( - array_ops.reshape(data, [-1, self.num_units]), mod_indices) + return array_ops.gather_nd( + data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1)) class LSTMBlockFusedCell(LSTMBlockWrapper): |