diff options
author | Asim Shankar <ashankar@google.com> | 2018-02-05 12:40:51 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-05 12:45:07 -0800 |
commit | d0904cbe01c88332acb4faa8bede21adb5fa1de7 (patch) | |
tree | f3b3d23303faa193c0b0a48c88b7da24c1bd4fd4 | |
parent | fc8d9c38692d3dddf474dba7e43c666105c08a3d (diff) |
contrib/rnn: Fix #16703
(Bug introduced in
https://github.com/tensorflow/tensorflow/commit/3f579020bab8f00e4621e9c7c740cbf13136a809)
Kudos to @akhti for pointing this out.
PiperOrigin-RevId: 184570448
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 8adf5dce6e..eb8fd0c1cd 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -2285,7 +2285,7 @@ class GLSTMCell(rnn_cell_impl.RNNCell): else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units - self._linear1 = None + self._linear1 = [None] * number_of_groups self._linear2 = None @property @@ -2359,9 +2359,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell): self._group_shape[0]) ], axis=1) - if self._linear1 is None: - self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False) - R_k = self._linear1(x_g_id) # pylint: disable=invalid-name + linear = self._linear1[group_id] + if linear is None: + linear = _Linear(x_g_id, 4 * self._group_shape[1], False) + self._linear1[group_id] = linear + R_k = linear(x_g_id) # pylint: disable=invalid-name i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) i_parts.append(i_k) |