aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-02-05 12:40:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-05 12:45:07 -0800
commitd0904cbe01c88332acb4faa8bede21adb5fa1de7 (patch)
treef3b3d23303faa193c0b0a48c88b7da24c1bd4fd4 /tensorflow/contrib/rnn
parentfc8d9c38692d3dddf474dba7e43c666105c08a3d (diff)
contrib/rnn: Fix #16703
(Bug introduced in https://github.com/tensorflow/tensorflow/commit/3f579020bab8f00e4621e9c7c740cbf13136a809) Kudos to @akhti for pointing this out. PiperOrigin-RevId: 184570448
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 8adf5dce6e..eb8fd0c1cd 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -2285,7 +2285,7 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
else:
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
self._output_size = num_units
- self._linear1 = None
+ self._linear1 = [None] * number_of_groups
self._linear2 = None
@property
@@ -2359,9 +2359,11 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
self._group_shape[0])
],
axis=1)
- if self._linear1 is None:
- self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False)
- R_k = self._linear1(x_g_id) # pylint: disable=invalid-name
+ linear = self._linear1[group_id]
+ if linear is None:
+ linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
+ self._linear1[group_id] = linear
+ R_k = linear(x_g_id) # pylint: disable=invalid-name
i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)
i_parts.append(i_k)