aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/keras/_impl/keras/backend.py8
-rw-r--r--tensorflow/python/keras/_impl/keras/backend_test.py9
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/lstm_test.py17
3 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 098ea063f9..3b8023e938 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -2781,6 +2781,7 @@ def rnn(step_function,
ndim = len(inputs.get_shape())
if ndim < 3:
raise ValueError('Input should be at least 3D.')
+ inputs_shape = inputs.get_shape()
axes = [1, 0] + list(range(2, ndim))
inputs = array_ops.transpose(inputs, (axes))
@@ -2965,6 +2966,13 @@ def rnn(step_function,
axes = [1, 0] + list(range(2, len(outputs.get_shape())))
outputs = array_ops.transpose(outputs, axes)
+
+ # Static shape inference: (samples, time, ...)
+ outputs_shape = outputs.get_shape().as_list()
+ outputs_shape[0] = inputs_shape[0]
+ outputs_shape[1] = inputs_shape[1]
+ outputs.set_shape(outputs_shape)
+
last_output._uses_learning_phase = uses_learning_phase
return last_output, outputs, new_states
diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py
index 27833e368d..f29ca49378 100644
--- a/tensorflow/python/keras/_impl/keras/backend_test.py
+++ b/tensorflow/python/keras/_impl/keras/backend_test.py
@@ -915,6 +915,15 @@ class BackendNNOpsTest(test.TestCase):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
**kwargs)
+ # check static shape inference
+ self.assertEquals(last_output.get_shape().as_list(),
+ [num_samples, output_dim])
+ self.assertEquals(outputs.get_shape().as_list(),
+ [num_samples, timesteps, output_dim])
+ for state in new_states:
+ self.assertEquals(state.get_shape().as_list(),
+ [num_samples, output_dim])
+
last_output_list[i].append(keras.backend.eval(last_output))
outputs_list[i].append(keras.backend.eval(outputs))
self.assertEqual(len(new_states), 1)
diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
index 8d359bf17c..deb1d7c0c6 100644
--- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
@@ -39,6 +39,23 @@ class LSTMLayerTest(test.TestCase):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))
+ def test_static_shape_inference_LSTM(self):
+ # Github issue: 15165
+ num_samples = 2
+ timesteps = 3
+ embedding_dim = 4
+ units = 2
+
+ model = keras.models.Sequential()
+ inputs = keras.layers.Dense(embedding_dim,
+ input_shape=(timesteps, embedding_dim))
+ model.add(inputs)
+ layer = keras.layers.LSTM(units, return_sequences=True)
+ model.add(layer)
+ outputs = model.layers[-1].output
+ self.assertEquals(outputs.get_shape().as_list(),
+ [None, timesteps, units])
+
def test_dynamic_behavior_LSTM(self):
num_samples = 2
timesteps = 3