aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-24 17:07:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-24 17:11:24 -0700
commit0cf2c612e5e6ff8c5026011e8186056801def747 (patch)
tree12793c95ad2aaa21bcddb466904ebd936feb326b /tensorflow/python/keras/backend.py
parent4c161d7306eb934232e3fe65de2c31c3bb7cf875 (diff)
Keras ReLU Consolidation
Consolidate functionality of ThresholdedReLU and LeakyReLU layers into ReLU layer PiperOrigin-RevId: 205917439
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py32
1 files changed, 27 insertions, 5 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 333f927d2f..38794f1612 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3372,26 +3372,48 @@ def in_test_phase(x, alt, training=None):
@tf_export('keras.backend.relu')
-def relu(x, alpha=0., max_value=None):
+def relu(x, alpha=0., max_value=None, threshold=0):
"""Rectified linear unit.
With default values, it returns element-wise `max(x, 0)`.
+ Otherwise, it follows:
+ `f(x) = max_value` for `x >= max_value`,
+ `f(x) = x` for `threshold <= x < max_value`,
+ `f(x) = alpha * (x - threshold)` otherwise.
+
Arguments:
x: A tensor or variable.
alpha: A scalar, slope of negative section (default=`0.`).
- max_value: Saturation threshold.
+ max_value: float. Saturation threshold.
+ threshold: float. Threshold value for thresholded activation.
Returns:
A tensor.
"""
+ clip_max = max_value is not None
+
if alpha != 0.:
- negative_part = nn.relu(-x)
- x = nn.relu(x)
- if max_value is not None:
+ if threshold != 0:
+ negative_part = nn.relu(-x + threshold)
+ else:
+ negative_part = nn.relu(-x)
+
+ if threshold != 0:
+ # computes x for x > threshold else 0
+ x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
+ elif max_value == 6:
+ # if no threshold, then can use nn.relu6 native TF op for performance
+ x = nn.relu6(x)
+ clip_max = False
+ else:
+ x = nn.relu(x)
+
+ if clip_max:
max_value = _to_tensor(max_value, x.dtype.base_dtype)
zero = _to_tensor(0., x.dtype.base_dtype)
x = clip_ops.clip_by_value(x, zero, max_value)
+
if alpha != 0.:
alpha = _to_tensor(alpha, x.dtype.base_dtype)
x -= alpha * negative_part