path: root/tensorflow/python/keras
author     Francois Chollet <fchollet@google.com>           2018-09-20 15:08:59 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-20 15:13:52 -0700
commit     1d1ec99bd3b322ea35a2d3d0eb754589ec2fd512 (patch)
tree       20ed07fc3a996e36d89a13c05971297018b145c4 /tensorflow/python/keras
parent     424f0556ad8acde8f912a67e46421957a71dcef2 (diff)
Add more specific ReLU implementation tests.
PiperOrigin-RevId: 213890403
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--  tensorflow/python/keras/layers/advanced_activations.py       | 4
-rw-r--r--  tensorflow/python/keras/layers/advanced_activations_test.py  | 8
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 4ab786a184..a2385dfdbb 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -314,7 +314,9 @@ class ReLU(Layer):
                         'cannot be negative value: ' + str(negative_slope))
 
     self.support_masking = True
-    self.max_value = K.cast_to_floatx(max_value)
+    if max_value is not None:
+      max_value = K.cast_to_floatx(max_value)
+    self.max_value = max_value
     self.negative_slope = K.cast_to_floatx(negative_slope)
     self.threshold = K.cast_to_floatx(threshold)
 
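Why the guard is needed: max_value=None is the default and means "no upper bound", but K.cast_to_floatx converts its argument to a NumPy value of the backend float type, which raises a TypeError for None. The patch therefore only casts when a real value was supplied. A minimal standalone sketch of the same guard pattern, using NumPy directly in place of the Keras backend (the helper name _cast_or_none is illustrative, not part of the Keras API):

    import numpy as np

    def _cast_or_none(value, floatx='float32'):
      # Mirrors the patched constructor: leave None (no upper bound) untouched,
      # cast anything else to the backend float type.
      if value is None:
        return None
      return np.asarray(value, dtype=floatx)

    print(_cast_or_none(None))  # None -> the layer applies no upper bound
    print(_cast_or_none(6))     # 6.0 (float32)
    # np.asarray(None, dtype='float32') raises TypeError, which is what an
    # unconditional cast would hit for the default max_value=None.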
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index b020b6e730..c41087be0a 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -67,6 +67,14 @@ class AdvancedActivationsTest(test.TestCase):
     testing_utils.layer_test(keras.layers.ReLU,
                              kwargs={'max_value': 10},
                              input_shape=(2, 3, 4))
+    x = keras.backend.ones((3, 4))
+    # Test that we use `leaky_relu` when appropriate in graph mode.
+    self.assertTrue(
+        'LeakyRelu' in keras.layers.ReLU(negative_slope=0.2)(x).name)
+    # Test that we use `relu` when appropriate in graph mode.
+    self.assertTrue('Relu' in keras.layers.ReLU()(x).name)
+    # Test that we use `relu6` when appropriate in graph mode.
+    self.assertTrue('Relu6' in keras.layers.ReLU(max_value=6)(x).name)
 
   def test_relu_with_invalid_arg(self):
     with self.assertRaisesRegexp(
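The new assertions work because ReLU.call dispatches to different backend ops depending on its arguments (leaky_relu for a nonzero negative_slope, relu6 for max_value=6, plain relu otherwise), and in graph mode the chosen op's type is embedded in the output tensor's name. A standalone sketch of the same check outside the test harness, assuming TensorFlow 1.x graph mode (contemporary with this commit) and the public tf.keras API; the printed names are illustrative:

    import tensorflow as tf

    # In graph mode, layer outputs are symbolic tensors whose names embed the
    # type of the op that produced them.
    x = tf.keras.backend.ones((3, 4))
    print(tf.keras.layers.ReLU()(x).name)                    # e.g. 're_lu/Relu:0'
    print(tf.keras.layers.ReLU(max_value=6)(x).name)         # e.g. 're_lu_1/Relu6:0'
    print(tf.keras.layers.ReLU(negative_slope=0.2)(x).name)  # e.g. 're_lu_2/LeakyRelu:0'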