aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/core_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 20:02:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 20:10:18 -0700
commita309e136dcfdd13dc8e8eb7570b6c5945bb6f967 (patch)
treef7d28106e0ffb3523b5f1ea3b557030c10cf4a04 /tensorflow/python/keras/layers/core_test.py
parentacb13e448786838feb500973f51279dc90eeab50 (diff)
Keras Lambda - enhancements to output_shape computation
PiperOrigin-RevId: 214878428
Diffstat (limited to 'tensorflow/python/keras/layers/core_test.py')
-rw-r--r--tensorflow/python/keras/layers/core_test.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 1df1d575b1..f0fea1f65c 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -252,6 +252,51 @@ class CoreLayersTest(test.TestCase):
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual('lambda', l.get_config()['output_shape_type'])
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_autocalculate_multiple_inputs(self):
+
+ def lambda_fn(x):
+ return math_ops.matmul(x[0], x[1])
+
+ l = keras.layers.Lambda(lambda_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual((10, 20), output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_list_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=[(10,), (20,)])
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_tuple_with_none(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=(None, 10))
+ output_shape = l.compute_output_shape((5, 10, 20))
+ # Dimension(None) != Dimension(None), so check
+ # str representations for equality.
+ self.assertAllEqual(('5', '?', '10'), tuple([str(s) for s in output_shape]))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_function_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ def output_shape_fn(input_shape):
+ return input_shape
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=output_shape_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
def test_lambda_config_serialization(self):
with self.cached_session():
# test serialization with output_shape and output_shape_type