diff options
author | Francois Chollet <fchollet@google.com> | 2018-09-20 15:08:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 15:13:52 -0700 |
commit | 1d1ec99bd3b322ea35a2d3d0eb754589ec2fd512 (patch) | |
tree | 20ed07fc3a996e36d89a13c05971297018b145c4 /tensorflow/python/keras | |
parent | 424f0556ad8acde8f912a67e46421957a71dcef2 (diff) |
Add more specific ReLU implementation tests.
PiperOrigin-RevId: 213890403
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/layers/advanced_activations.py | 4 | ||||
-rw-r--r-- | tensorflow/python/keras/layers/advanced_activations_test.py | 8 |
2 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index 4ab786a184..a2385dfdbb 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -314,7 +314,9 @@ class ReLU(Layer): 'cannot be negative value: ' + str(negative_slope)) self.support_masking = True - self.max_value = K.cast_to_floatx(max_value) + if max_value is not None: + max_value = K.cast_to_floatx(max_value) + self.max_value = max_value self.negative_slope = K.cast_to_floatx(negative_slope) self.threshold = K.cast_to_floatx(threshold) diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py index b020b6e730..c41087be0a 100644 --- a/tensorflow/python/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/layers/advanced_activations_test.py @@ -67,6 +67,14 @@ class AdvancedActivationsTest(test.TestCase): testing_utils.layer_test(keras.layers.ReLU, kwargs={'max_value': 10}, input_shape=(2, 3, 4)) + x = keras.backend.ones((3, 4)) + # Test that we use `leaky_relu` when appropriate in graph mode. + self.assertTrue( + 'LeakyRelu' in keras.layers.ReLU(negative_slope=0.2)(x).name) + # Test that we use `relu` when appropriate in graph mode. + self.assertTrue('Relu' in keras.layers.ReLU()(x).name) + # Test that we use `relu6` when appropriate in graph mode. + self.assertTrue('Relu6' in keras.layers.ReLU(max_value=6)(x).name) def test_relu_with_invalid_arg(self): with self.assertRaisesRegexp( |