 tensorflow/python/keras/layers/core.py      | 51 +++++++++++++-----
 tensorflow/python/keras/layers/core_test.py | 45 ++++++++++++++++++
 2 files changed, 82 insertions(+), 14 deletions(-)
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 4032202986..efa21955e6 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -671,22 +671,34 @@ class Lambda(Layer):
     if mask is not None:
       self.supports_masking = True
     self.mask = mask
-    if output_shape is None:
-      self._output_shape = None
-    elif isinstance(output_shape, (tuple, list)):
-      self._output_shape = tuple(output_shape)
-    else:
-      if not callable(output_shape):
-        raise TypeError('In Lambda, `output_shape` '
-                        'must be a list, a tuple, or a function.')
-      self._output_shape = output_shape
+    if (output_shape is not None and not isinstance(output_shape,
+                                                    (tuple, list)) and
+        not callable(output_shape)):
+      raise TypeError('In Lambda, `output_shape` '
+                      'must be a list, a tuple, or a function.')
+    # Convert a list representing a single shape into a tuple.
+    if (isinstance(output_shape, list) and isinstance(output_shape[0],
+                                                      (int, type(None)))):
+      output_shape = tuple(output_shape)
+    self._output_shape = output_shape
 
   @tf_utils.shape_type_conversion
   def compute_output_shape(self, input_shape):
     if self._output_shape is None:
       if context.executing_eagerly():
-        raise NotImplementedError
-      x = K.placeholder(shape=input_shape)
+        # Make use of existing autocomputation for Eager mode but provide
+        # Lambda-specific error message.
+        try:
+          return super(Lambda, self).compute_output_shape(input_shape)
+        except NotImplementedError:
+          raise NotImplementedError('We could not automatically infer '
+                                    'the static shape of the Lambda\'s output.'
+                                    ' Please specify the `output_shape` for'
+                                    ' this Lambda.')
+      if isinstance(input_shape, list):
+        x = [K.placeholder(shape=shape) for shape in input_shape]
+      else:
+        x = K.placeholder(shape=input_shape)
       x = self.call(x)
       if isinstance(x, list):
         return [tensor_shape.TensorShape(K.int_shape(x_elem)) for x_elem in x]
@@ -697,16 +709,27 @@ class Lambda(Layer):
         num_samples = input_shape[0][0]
       else:
         num_samples = input_shape[0] if input_shape else None
-      return tensor_shape.TensorShape((num_samples,) +
-                                      tuple(self._output_shape))
+      # List here represents multiple outputs.
+      if isinstance(self._output_shape, list):
+        return [
+            tensor_shape.TensorShape((num_samples,) + tuple(single_shape))
+            for single_shape in self._output_shape
+        ]
+      return tensor_shape.TensorShape((num_samples,) + self._output_shape)
     else:
       shape = self._output_shape(input_shape)
       if not isinstance(shape, (list, tuple)):
         raise ValueError(
            '`output_shape` function must return a tuple or a list of tuples.')
+      # List here can represent multiple outputs or single output.
       if isinstance(shape, list):
-        if isinstance(shape[0], int) or shape[0] is None:
+        # Convert list representing single output into a tuple.
+        if isinstance(shape[0], (int, type(None))):
           shape = tuple(shape)
+        else:
+          return [
+              tensor_shape.TensorShape(single_shape) for single_shape in shape
+          ]
       return tensor_shape.TensorShape(shape)
 
   def call(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 1df1d575b1..f0fea1f65c 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -252,6 +252,51 @@ class CoreLayersTest(test.TestCase):
       l(keras.backend.variable(np.ones((1, 1))))
       self.assertEqual('lambda', l.get_config()['output_shape_type'])
 
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_lambda_output_shape_autocalculate_multiple_inputs(self):
+
+    def lambda_fn(x):
+      return math_ops.matmul(x[0], x[1])
+
+    l = keras.layers.Lambda(lambda_fn)
+    output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+    self.assertAllEqual((10, 20), output_shape)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_lambda_output_shape_list_multiple_outputs(self):
+
+    def lambda_fn(x):
+      return x
+
+    l = keras.layers.Lambda(lambda_fn, output_shape=[(10,), (20,)])
+    output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+    self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_lambda_output_shape_tuple_with_none(self):
+
+    def lambda_fn(x):
+      return x
+
+    l = keras.layers.Lambda(lambda_fn, output_shape=(None, 10))
+    output_shape = l.compute_output_shape((5, 10, 20))
+    # Dimension(None) != Dimension(None), so check
+    # str representations for equality.
+    self.assertAllEqual(('5', '?', '10'), tuple([str(s) for s in output_shape]))
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_lambda_output_shape_function_multiple_outputs(self):
+
+    def lambda_fn(x):
+      return x
+
+    def output_shape_fn(input_shape):
+      return input_shape
+
+    l = keras.layers.Lambda(lambda_fn, output_shape=output_shape_fn)
+    output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+    self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
   def test_lambda_config_serialization(self):
     with self.cached_session():
       # test serialization with output_shape and output_shape_type
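
For context, a minimal usage sketch of the behavior this change enables. This is not part of the commit; it assumes the tf.keras API at this revision, and the variable names are illustrative only:

    import tensorflow as tf

    # A list of shape tuples in `output_shape` now describes one shape per
    # output; each inferred shape is the batch dimension plus one tuple.
    split = tf.keras.layers.Lambda(
        lambda x: [x[:, :5], x[:, 5:]],
        output_shape=[(5,), (5,)])
    print(split.compute_output_shape((None, 10)))
    # Expected per this diff: [TensorShape([None, 5]), TensorShape([None, 5])]

    # With `output_shape` left unset, the static shape is inferred
    # automatically, including for multi-input lambdas (see the
    # test_lambda_output_shape_autocalculate_multiple_inputs test above).
    matmul = tf.keras.layers.Lambda(lambda x: tf.matmul(x[0], x[1]))
    print(matmul.compute_output_shape([(10, 10), (10, 20)]))
    # Expected per this diff: (10, 20)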