diff options
-rw-r--r-- | tensorflow/python/keras/_impl/keras/backend.py | 8 | ||||
-rw-r--r-- | tensorflow/python/keras/_impl/keras/backend_test.py | 9 | ||||
-rw-r--r-- | tensorflow/python/keras/_impl/keras/layers/lstm_test.py | 17 |
3 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 098ea063f9..3b8023e938 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -2781,6 +2781,7 @@ def rnn(step_function,
   ndim = len(inputs.get_shape())
   if ndim < 3:
     raise ValueError('Input should be at least 3D.')
+  inputs_shape = inputs.get_shape()
   axes = [1, 0] + list(range(2, ndim))
   inputs = array_ops.transpose(inputs, (axes))
 
@@ -2965,6 +2966,13 @@ def rnn(step_function,
 
   axes = [1, 0] + list(range(2, len(outputs.get_shape())))
   outputs = array_ops.transpose(outputs, axes)
+
+  # Static shape inference: (samples, time, ...)
+  outputs_shape = outputs.get_shape().as_list()
+  outputs_shape[0] = inputs_shape[0]
+  outputs_shape[1] = inputs_shape[1]
+  outputs.set_shape(outputs_shape)
+
   last_output._uses_learning_phase = uses_learning_phase
   return last_output, outputs, new_states
 
diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py
index 27833e368d..f29ca49378 100644
--- a/tensorflow/python/keras/_impl/keras/backend_test.py
+++ b/tensorflow/python/keras/_impl/keras/backend_test.py
@@ -915,6 +915,15 @@ class BackendNNOpsTest(test.TestCase):
       last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
                                                            initial_states,
                                                            **kwargs)
+      # check static shape inference
+      self.assertEquals(last_output.get_shape().as_list(),
+                        [num_samples, output_dim])
+      self.assertEquals(outputs.get_shape().as_list(),
+                        [num_samples, timesteps, output_dim])
+      for state in new_states:
+        self.assertEquals(state.get_shape().as_list(),
+                          [num_samples, output_dim])
+
       last_output_list[i].append(keras.backend.eval(last_output))
       outputs_list[i].append(keras.backend.eval(outputs))
       self.assertEqual(len(new_states), 1)
diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
index 8d359bf17c..deb1d7c0c6 100644
--- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
@@ -39,6 +39,23 @@ class LSTMLayerTest(test.TestCase):
                 'return_sequences': True},
         input_shape=(num_samples, timesteps, embedding_dim))
 
+  def test_static_shape_inference_LSTM(self):
+    # Github issue: 15165
+    num_samples = 2
+    timesteps = 3
+    embedding_dim = 4
+    units = 2
+
+    model = keras.models.Sequential()
+    inputs = keras.layers.Dense(embedding_dim,
+                                input_shape=(timesteps, embedding_dim))
+    model.add(inputs)
+    layer = keras.layers.LSTM(units, return_sequences=True)
+    model.add(layer)
+    outputs = model.layers[-1].output
+    self.assertEquals(outputs.get_shape().as_list(),
+                      [None, timesteps, units])
+
   def test_dynamic_behavior_LSTM(self):
     num_samples = 2
     timesteps = 3