aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 20:02:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 20:10:18 -0700
commita309e136dcfdd13dc8e8eb7570b6c5945bb6f967 (patch)
treef7d28106e0ffb3523b5f1ea3b557030c10cf4a04 /tensorflow/python/keras
parentacb13e448786838feb500973f51279dc90eeab50 (diff)
Keras Lambda - enhancements to output_shape computation
PiperOrigin-RevId: 214878428
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/layers/core.py51
-rw-r--r--tensorflow/python/keras/layers/core_test.py45
2 files changed, 82 insertions, 14 deletions
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 4032202986..efa21955e6 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -671,22 +671,34 @@ class Lambda(Layer):
if mask is not None:
self.supports_masking = True
self.mask = mask
- if output_shape is None:
- self._output_shape = None
- elif isinstance(output_shape, (tuple, list)):
- self._output_shape = tuple(output_shape)
- else:
- if not callable(output_shape):
- raise TypeError('In Lambda, `output_shape` '
- 'must be a list, a tuple, or a function.')
- self._output_shape = output_shape
+ if (output_shape is not None and not isinstance(output_shape,
+ (tuple, list)) and
+ not callable(output_shape)):
+ raise TypeError('In Lambda, `output_shape` '
+ 'must be a list, a tuple, or a function.')
+ # Convert a list representing a single shape into a tuple.
+ if (isinstance(output_shape, list) and isinstance(output_shape[0],
+ (int, type(None)))):
+ output_shape = tuple(output_shape)
+ self._output_shape = output_shape
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self._output_shape is None:
if context.executing_eagerly():
- raise NotImplementedError
- x = K.placeholder(shape=input_shape)
+ # Make use of existing autocomputation for Eager mode but provide
+ # Lambda-specific error message.
+ try:
+ return super(Lambda, self).compute_output_shape(input_shape)
+ except NotImplementedError:
+ raise NotImplementedError('We could not automatically infer '
+ 'the static shape of the Lambda\'s output.'
+ ' Please specify the `output_shape` for'
+ ' this Lambda.')
+ if isinstance(input_shape, list):
+ x = [K.placeholder(shape=shape) for shape in input_shape]
+ else:
+ x = K.placeholder(shape=input_shape)
x = self.call(x)
if isinstance(x, list):
return [tensor_shape.TensorShape(K.int_shape(x_elem)) for x_elem in x]
@@ -697,16 +709,27 @@ class Lambda(Layer):
num_samples = input_shape[0][0]
else:
num_samples = input_shape[0] if input_shape else None
- return tensor_shape.TensorShape((num_samples,) +
- tuple(self._output_shape))
+ # List here represents multiple outputs.
+ if isinstance(self._output_shape, list):
+ return [
+ tensor_shape.TensorShape((num_samples,) + tuple(single_shape))
+ for single_shape in self._output_shape
+ ]
+ return tensor_shape.TensorShape((num_samples,) + self._output_shape)
else:
shape = self._output_shape(input_shape)
if not isinstance(shape, (list, tuple)):
raise ValueError(
'`output_shape` function must return a tuple or a list of tuples.')
+ # List here can represent multiple outputs or single output.
if isinstance(shape, list):
- if isinstance(shape[0], int) or shape[0] is None:
+ # Convert list representing single output into a tuple.
+ if isinstance(shape[0], (int, type(None))):
shape = tuple(shape)
+ else:
+ return [
+ tensor_shape.TensorShape(single_shape) for single_shape in shape
+ ]
return tensor_shape.TensorShape(shape)
def call(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 1df1d575b1..f0fea1f65c 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -252,6 +252,51 @@ class CoreLayersTest(test.TestCase):
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual('lambda', l.get_config()['output_shape_type'])
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_autocalculate_multiple_inputs(self):
+
+ def lambda_fn(x):
+ return math_ops.matmul(x[0], x[1])
+
+ l = keras.layers.Lambda(lambda_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual((10, 20), output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_list_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=[(10,), (20,)])
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_tuple_with_none(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=(None, 10))
+ output_shape = l.compute_output_shape((5, 10, 20))
+ # Dimension(None) != Dimension(None), so check
+ # str representations for equality.
+ self.assertAllEqual(('5', '?', '10'), tuple([str(s) for s in output_shape]))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_function_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ def output_shape_fn(input_shape):
+ return input_shape
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=output_shape_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
def test_lambda_config_serialization(self):
with self.cached_session():
# test serialization with output_shape and output_shape_type