diff options
author | 2018-09-27 20:02:51 -0700 | |
---|---|---|
committer | 2018-09-27 20:10:18 -0700 | |
commit | a309e136dcfdd13dc8e8eb7570b6c5945bb6f967 (patch) | |
tree | f7d28106e0ffb3523b5f1ea3b557030c10cf4a04 /tensorflow/python/keras/layers/core_test.py | |
parent | acb13e448786838feb500973f51279dc90eeab50 (diff) |
Keras Lambda - enhancements to output_shape computation
PiperOrigin-RevId: 214878428
Diffstat (limited to 'tensorflow/python/keras/layers/core_test.py')
-rw-r--r-- | tensorflow/python/keras/layers/core_test.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index 1df1d575b1..f0fea1f65c 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -252,6 +252,51 @@ class CoreLayersTest(test.TestCase): l(keras.backend.variable(np.ones((1, 1)))) self.assertEqual('lambda', l.get_config()['output_shape_type']) + @tf_test_util.run_in_graph_and_eager_modes + def test_lambda_output_shape_autocalculate_multiple_inputs(self): + + def lambda_fn(x): + return math_ops.matmul(x[0], x[1]) + + l = keras.layers.Lambda(lambda_fn) + output_shape = l.compute_output_shape([(10, 10), (10, 20)]) + self.assertAllEqual((10, 20), output_shape) + + @tf_test_util.run_in_graph_and_eager_modes + def test_lambda_output_shape_list_multiple_outputs(self): + + def lambda_fn(x): + return x + + l = keras.layers.Lambda(lambda_fn, output_shape=[(10,), (20,)]) + output_shape = l.compute_output_shape([(10, 10), (10, 20)]) + self.assertAllEqual([(10, 10), (10, 20)], output_shape) + + @tf_test_util.run_in_graph_and_eager_modes + def test_lambda_output_shape_tuple_with_none(self): + + def lambda_fn(x): + return x + + l = keras.layers.Lambda(lambda_fn, output_shape=(None, 10)) + output_shape = l.compute_output_shape((5, 10, 20)) + # Dimension(None) != Dimension(None), so check + # str representations for equality. + self.assertAllEqual(('5', '?', '10'), tuple([str(s) for s in output_shape])) + + @tf_test_util.run_in_graph_and_eager_modes + def test_lambda_output_shape_function_multiple_outputs(self): + + def lambda_fn(x): + return x + + def output_shape_fn(input_shape): + return input_shape + + l = keras.layers.Lambda(lambda_fn, output_shape=output_shape_fn) + output_shape = l.compute_output_shape([(10, 10), (10, 20)]) + self.assertAllEqual([(10, 10), (10, 20)], output_shape) + def test_lambda_config_serialization(self): with self.cached_session(): # test serialization with output_shape and output_shape_type |