aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Zhenyu Tan <tanzheny@google.com>2018-06-26 18:18:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 18:20:50 -0700
commit41731b13598c50a31432e769f4cb9d9fc355cf7a (patch)
treeb26eb49bbcf02a4ae22288d4c32a6256d457334e /tensorflow/python/keras/backend.py
parentb9752f52426f397c3bee42e7c0c6aa3227b0e1ca (diff)
Fix shape mismatch in `rnn()` of keras backend
PiperOrigin-RevId: 202231273
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index fed779650e..11f99c030f 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3161,10 +3161,16 @@ def rnn(step_function,
array_ops.stack(
[1, array_ops.shape(output)[1]]))
output = array_ops.where(tiled_mask_t, output, states[0])
- new_states = [
- array_ops.where(tiled_mask_t, new_states[i], states[i])
- for i in range(len(states))
- ]
+
+ masked_states = []
+ for i in range(len(states)):
+ states_dim = array_ops.shape(new_states[i])[1]
+ stacked_states_dim = array_ops.stack([1, states_dim])
+ tiled_mask = array_ops.tile(mask_t, stacked_states_dim)
+ masked_state = array_ops.where(tiled_mask, new_states[i], states[i])
+ masked_states.append(masked_state)
+ new_states = masked_states
+
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)
else: