aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/_impl/keras/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/_impl/keras/backend.py')
-rw-r--r--tensorflow/python/keras/_impl/keras/backend.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 1fa264660d..a238a3f748 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -2895,6 +2895,7 @@ def rnn(step_function,
ndim = len(inputs.get_shape())
if ndim < 3:
raise ValueError('Input should be at least 3D.')
+ inputs_shape = inputs.get_shape()
axes = [1, 0] + list(range(2, ndim))
inputs = array_ops.transpose(inputs, (axes))
@@ -3079,6 +3080,13 @@ def rnn(step_function,
axes = [1, 0] + list(range(2, len(outputs.get_shape())))
outputs = array_ops.transpose(outputs, axes)
+
+ # Static shape inference: (samples, time, ...)
+ outputs_shape = outputs.get_shape().as_list()
+ outputs_shape[0] = inputs_shape[0]
+ outputs_shape[1] = inputs_shape[1]
+ outputs.set_shape(outputs_shape)
+
last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, new_states