diff options
author | Yifei Feng <yifeif@google.com> | 2018-02-22 14:24:57 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-22 14:29:27 -0800 |
commit | dce9a49c19f406ba45919e8c94474e55dc5ccd54 (patch) | |
tree | 928db8a52603e00aef76985cda16b8bceb9debb2 /tensorflow/contrib/rnn | |
parent | cb7e1963c625fd9713e7475d85621f95be6762f1 (diff) |
Merge changes from github.
PiperOrigin-RevId: 186674197
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 5 |
2 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index f700717394..4eb4fbcd92 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer): def _gather_states(self, data, indices, batch_size): """Produce `out`, s.t. out(i, j) = data(indices(i), i, j).""" - mod_indices = indices * batch_size + math_ops.range(batch_size) - return array_ops.gather( - array_ops.reshape(data, [-1, self.num_units]), mod_indices) + return array_ops.gather_nd( + data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1)) class LSTMBlockFusedCell(LSTMBlockWrapper): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index dce71c393a..a6c2d9cdbb 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): "W_O_diag", shape=[self._num_units], dtype=dtype) # initialize the first freq state to be zero - m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units], - dtype) + m_prev_freq = array_ops.zeros( + [inputs.shape[0].value or inputs.get_shape()[0], self._num_units], + dtype) for fq in range(len(freq_inputs)): c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], [-1, self._num_units]) |