aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-02-22 14:24:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-22 14:29:27 -0800
commitdce9a49c19f406ba45919e8c94474e55dc5ccd54 (patch)
tree928db8a52603e00aef76985cda16b8bceb9debb2 /tensorflow/contrib/rnn
parentcb7e1963c625fd9713e7475d85621f95be6762f1 (diff)
Merge changes from github.
PiperOrigin-RevId: 186674197
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py5
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py5
2 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index f700717394..4eb4fbcd92 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer):
def _gather_states(self, data, indices, batch_size):
"""Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
- mod_indices = indices * batch_size + math_ops.range(batch_size)
- return array_ops.gather(
- array_ops.reshape(data, [-1, self.num_units]), mod_indices)
+ return array_ops.gather_nd(
+ data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))
class LSTMBlockFusedCell(LSTMBlockWrapper):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index dce71c393a..a6c2d9cdbb 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
"W_O_diag", shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
- m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units],
- dtype)
+ m_prev_freq = array_ops.zeros(
+ [inputs.shape[0].value or inputs.get_shape()[0], self._num_units],
+ dtype)
for fq in range(len(freq_inputs)):
c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
[-1, self._num_units])