diff options
Diffstat (limited to 'tensorflow/python/keras/_impl/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/_impl/keras/backend.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index 1fa264660d..a238a3f748 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -2895,6 +2895,7 @@ def rnn(step_function, ndim = len(inputs.get_shape()) if ndim < 3: raise ValueError('Input should be at least 3D.') + inputs_shape = inputs.get_shape() axes = [1, 0] + list(range(2, ndim)) inputs = array_ops.transpose(inputs, (axes)) @@ -3079,6 +3080,13 @@ def rnn(step_function, axes = [1, 0] + list(range(2, len(outputs.get_shape()))) outputs = array_ops.transpose(outputs, axes) + + # Static shape inference: (samples, time, ...) + outputs_shape = outputs.get_shape().as_list() + outputs_shape[0] = inputs_shape[0] + outputs_shape[1] = inputs_shape[1] + outputs.set_shape(outputs_shape) + last_output._uses_learning_phase = uses_learning_phase return last_output, outputs, new_states |