diff options
author | Zhenyu Tan <tanzheny@google.com> | 2018-06-26 18:18:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-26 18:20:50 -0700 |
commit | 41731b13598c50a31432e769f4cb9d9fc355cf7a (patch) | |
tree | b26eb49bbcf02a4ae22288d4c32a6256d457334e /tensorflow/python/keras/backend.py | |
parent | b9752f52426f397c3bee42e7c0c6aa3227b0e1ca (diff) |
Fix shape mismatch in `rnn()` of keras backend
PiperOrigin-RevId: 202231273
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/backend.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index fed779650e..11f99c030f 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -3161,10 +3161,16 @@ def rnn(step_function, array_ops.stack( [1, array_ops.shape(output)[1]])) output = array_ops.where(tiled_mask_t, output, states[0]) - new_states = [ - array_ops.where(tiled_mask_t, new_states[i], states[i]) - for i in range(len(states)) - ] + + masked_states = [] + for i in range(len(states)): + states_dim = array_ops.shape(new_states[i])[1] + stacked_states_dim = array_ops.stack([1, states_dim]) + tiled_mask = array_ops.tile(mask_t, stacked_states_dim) + masked_state = array_ops.where(tiled_mask, new_states[i], states[i]) + masked_states.append(masked_state) + new_states = masked_states + output_ta_t = output_ta_t.write(time, output) return (time + 1, output_ta_t) + tuple(new_states) else: |