author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-06 14:53:56 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-06 14:58:07 -0700
commit     eb1a0a5294b9b7b209d419b4113fb57d6443b45f (patch)
tree       ed8dbb4d906d26dd3fa32adb4b57cd07cf0f2809 /tensorflow/contrib/nn
parent     bbfef93661ebf8ec23c7b9ad920313be9898bbbc (diff)
(1) Adds broadcasting to scaled_softplus
(2) Adds the ability to clip (so we can get a soft version of relu6)

PiperOrigin-RevId: 171347879
Diffstat (limited to 'tensorflow/contrib/nn')
-rw-r--r--  tensorflow/contrib/nn/python/ops/scaled_softplus.py       | 82
-rw-r--r--  tensorflow/contrib/nn/python/ops/scaled_softplus_test.py  | 23
2 files changed, 77 insertions(+), 28 deletions(-)
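The change generalizes the forward rule to y = alpha * ln(1 + exp(x / alpha)), computed with broadcasting between x and alpha, followed by an optional min(y, clip); as alpha -> 0 this approaches relu(x), and with clip=6 it approaches relu6(x). A minimal NumPy sketch of that rule for reference (the helper name scaled_softplus_ref is illustrative, not part of the patch):

import numpy as np

def scaled_softplus_ref(x, alpha, clip=None):
  # np.logaddexp(0., t) is a numerically stable log(1 + exp(t)); NumPy broadcasting
  # between x, alpha and clip mirrors the broadcasting this commit adds.
  y = alpha * np.logaddexp(0., x / alpha)
  if clip is not None:
    y = np.minimum(y, clip)
  return y

x = np.linspace(-3., 9., 7)
print(scaled_softplus_ref(x, alpha=1e-6))           # ~ relu(x) as alpha -> 0
print(scaled_softplus_ref(x, alpha=1e-6, clip=6.))  # ~ relu6(x)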
diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus.py b/tensorflow/contrib/nn/python/ops/scaled_softplus.py
index 5fc11d8ec6..fcbfbc239c 100644
--- a/tensorflow/contrib/nn/python/ops/scaled_softplus.py
+++ b/tensorflow/contrib/nn/python/ops/scaled_softplus.py
@@ -20,58 +20,96 @@ from __future__ import print_function
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
-def scaled_softplus(x, alpha, name=None):
-  """Returns `alpha * ln(1 + exp(x / alpha))`, for scalar `alpha > 0`.
+def _reduce_and_reshape_grad(g, t):
+  """Returns the gradient, sum-reduced and reshaped to `t`'s shape."""
+  shape = array_ops.shape(t)
+  g_shape = array_ops.shape(g)
+  # pylint: disable=protected-access
+  bcast_dims, _ = gen_array_ops._broadcast_gradient_args(shape, g_shape)
+  # pylint: enable=protected-access
+  return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape)
+
+
+def scaled_softplus(x, alpha, clip=None, name=None):
+  """Returns `y = alpha * ln(1 + exp(x / alpha))` or `min(y, clip)`.
  This can be seen as a softplus applied to the scaled input, with the output
  appropriately scaled. As `alpha` tends to 0, `scaled_softplus(x, alpha)` tends
-  to `relu(x)`.
+  to `relu(x)`. The clipping is optional. As alpha->0, scaled_softplus(x, alpha)
+  tends to relu(x), and scaled_softplus(x, alpha, clip=6) tends to relu6(x).
  Note: the gradient for this operation is defined to depend on the backprop
  inputs as well as the outputs of this operation.
  Args:
    x: A `Tensor` of inputs.
-    alpha: A scalar `Tensor`, indicating the amount of smoothness. The caller
+    alpha: A `Tensor`, indicating the amount of smoothness. The caller
        must ensure that `alpha > 0`.
+    clip: (optional) A `Tensor`, the upper bound to clip the values.
    name: A name for the scope of the operations (optional).
  Returns:
-    A tensor of same size and type as `x`.
+    A tensor of the size and type determined by broadcasting of the inputs.
  """
-  with ops.name_scope(name, 'scaled_softplus', [x, alpha]):
+  clipping = clip is not None
+  with ops.name_scope(name, 'scaled_softplus',
+                      [x, alpha] + ([clip] if clipping else [])):
    x = ops.convert_to_tensor(x, name='x')
    dtype = x.dtype
    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name='alpha')
-    # Verify that alpha is a scalar.
-    alpha.get_shape().assert_has_rank(0)
+    # Compute the forward value.
+    y = alpha * nn.softplus(x / alpha)
+    if clipping:
+      clip = ops.convert_to_tensor(clip, dtype=dtype, name='clip')
+      y = math_ops.minimum(y, clip)
    def _grad(op, g):
-      """Backprop for scaled softplus."""
-      y = op.outputs[0]
-      alpha = op.inputs[1]
-      # Prevent the expensive computations from happening before g is available.
+      """Backprop for scaled softplus, with optional clipping."""
+      y, x, alpha = op.inputs[:3]
+      # Prevent the memory-expensive computations from happening before g is
+      # available.
      with ops.control_dependencies([g]):
-        y /= alpha
+        y = array_ops.identity(y)
+      clip_grad = []
+      if clipping:
+        clip = op.inputs[3]
+        unclipped = math_ops.cast(y < clip, g.dtype)
+        clip_grad = [_reduce_and_reshape_grad(g * (1. - unclipped), clip)]
+        g *= unclipped
+      y /= alpha
      emy = math_ops.exp(-y)
      dy_dx = 1. - emy
      # The eps below avoids log(0). Note that t*log(t) -> 0 as t->0.
      eps = 1e-8
      dy_dalpha = y * emy - dy_dx * math_ops.log(dy_dx + eps)
-      return g * dy_dx, math_ops.reduce_sum(g * dy_dalpha)
+      # Backprop to the actual inputs, but not to the output.
+      return [None,
+              _reduce_and_reshape_grad(g * dy_dx, x),
+              _reduce_and_reshape_grad(g * dy_dalpha, alpha)] + clip_grad
-    @function.Defun(dtype, dtype,
-                    func_name='ScaledSoftplus_%s' % dtype.name,
-                    shape_func=lambda op: [op.inputs[0].get_shape()],
+    if clipping:
+      @function.Defun(dtype, dtype, dtype, dtype,
+                      func_name='ScaledSoftplusHelper_clip_%s' % dtype.name,
+                      shape_func=lambda op: [op.inputs[0].shape],
+                      python_grad_func=_grad)
+      def _forward_helper_clip(y, x, alpha, clip):
+        del x, alpha, clip  # Unused.
+        return y
+      return _forward_helper_clip(y, x, alpha, clip)
+    # No clipping.
+    @function.Defun(dtype, dtype, dtype,
+                    func_name='ScaledSoftplusHelper_%s' % dtype.name,
+                    shape_func=lambda op: [op.inputs[0].shape],
                    python_grad_func=_grad)
-    def _forward(x, alpha):
-      """Forward computation of scaled softplus."""
-      return alpha * nn.softplus(x / alpha)
-
-    return _forward(x, alpha)
+    def _forward_helper(y, x, alpha):
+      del x, alpha  # Unused.
+      return y
+    return _forward_helper(y, x, alpha)
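The rewritten backward pass recomputes everything from the op's inputs. Writing u = y / alpha = softplus(x / alpha), the code above uses dy/dx = 1 - exp(-u), which equals sigmoid(x / alpha), and dy/dalpha = u * exp(-u) - (dy/dx) * log(dy/dx); with clipping it additionally zeroes g wherever the output hit the clip and routes the complementary part of g to the clip input. A quick finite-difference check of the two unclipped formulas, as a NumPy sketch that is not part of the patch (the names f and fd_* are illustrative):

import numpy as np

def f(x, alpha):
  # Reference forward value: alpha * log(1 + exp(x / alpha)).
  return alpha * np.logaddexp(0., x / alpha)

x, alpha, d = 0.7, 0.3, 1e-6

# Closed-form derivatives, mirroring the expressions in _grad.
u = np.logaddexp(0., x / alpha)             # u = y / alpha
dy_dx = 1. - np.exp(-u)                     # = sigmoid(x / alpha)
dy_dalpha = u * np.exp(-u) - dy_dx * np.log(dy_dx)

# Central finite differences for comparison.
fd_dx = (f(x + d, alpha) - f(x - d, alpha)) / (2. * d)
fd_dalpha = (f(x, alpha + d) - f(x, alpha - d)) / (2. * d)

print(abs(dy_dx - fd_dx), abs(dy_dalpha - fd_dalpha))  # both ~1e-9 or smaller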
diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py
index 3a459330ce..b978343c6a 100644
--- a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py
+++ b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py
@@ -33,10 +33,11 @@ class ScaledSoftplusTest(test.TestCase):
    x = np.random.randn(3, 4).astype(np.float32)
    x64 = np.random.randn(3, 4).astype(np.float64)
    alpha = np.random.rand() + 0.01
-    y = alpha * np.log(1. + np.exp(x / alpha))
+    clip = np.float32(0.1)
+    y = np.minimum(alpha * np.log(1. + np.exp(x / alpha)), clip)
    y64 = alpha * np.log(1. + np.exp(x64 / alpha))
    with self.test_session(use_gpu=True) as sess:
-      z = scaled_softplus(constant_op.constant(x), alpha)
+      z = scaled_softplus(constant_op.constant(x), alpha, clip)
      z64 = scaled_softplus(constant_op.constant(x64), alpha)
      z, z64 = sess.run([z, z64])
      eps = 1e-6
@@ -47,18 +48,28 @@ class ScaledSoftplusTest(test.TestCase):
    np.random.seed(1)  # Make it reproducible.
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float32)
-    alpha_np = np.float32(np.random.rand() + 0.01)
+    alpha_np = np.float32(np.random.rand(1, x_shape[1]) + 0.01)
+    clip_np = np.float32(np.random.rand(x_shape[0], 1) * 5.)
    with self.test_session(use_gpu=True):
      x_tf = constant_op.constant(x_np)
      alpha_tf = constant_op.constant(alpha_np)
+      clip_tf = constant_op.constant(clip_np)
      y_tf = scaled_softplus(x_tf, alpha_tf)
+      z_tf = scaled_softplus(x_tf, alpha_tf, clip_tf * 0.1)
      err = gradient_checker.compute_gradient_error([x_tf, alpha_tf],
-                                                    [x_shape, []],
+                                                    [x_shape, alpha_np.shape],
                                                    y_tf, x_shape,
                                                    [x_np, alpha_np],
-                                                    delta=1e-2)
-      eps = 1e-4
+                                                    delta=0.002)
+      err_clip = gradient_checker.compute_gradient_error(
+          [x_tf, alpha_tf, clip_tf],
+          [x_shape, alpha_np.shape, clip_np.shape],
+          z_tf, x_shape,
+          [x_np, alpha_np, clip_np],
+          delta=0.002)
+      eps = 2e-4
      self.assertLess(err, eps)
+      self.assertLess(err_clip, eps)
if __name__ == '__main__':
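The updated gradient test exercises the new broadcasting: x has shape (5, 10), alpha has shape (1, 10) so it broadcasts across rows, and clip has shape (5, 1) so it broadcasts across columns, which is exactly the case where _reduce_and_reshape_grad must sum the incoming gradient over the broadcast axes before reshaping. A NumPy sketch of that reduction rule under the same shapes (reduce_like is an illustrative stand-in, not the TensorFlow helper):

import numpy as np

def reduce_like(g, t):
  # Sum g over the axes along which t was broadcast, then reshape to t's shape.
  t_shape = np.shape(t)
  padded = (1,) * (g.ndim - len(t_shape)) + tuple(t_shape)
  axes = tuple(i for i, (gs, ts) in enumerate(zip(g.shape, padded))
               if ts == 1 and gs > 1)
  return np.sum(g, axis=axes).reshape(t_shape)

g = np.random.randn(5, 10)             # upstream gradient, shaped like x
alpha = np.random.rand(1, 10) + 0.01   # broadcasts over rows
clip = np.random.rand(5, 1) * 5.       # broadcasts over columns

print(reduce_like(g, alpha).shape)  # (1, 10)
print(reduce_like(g, clip).shape)   # (5, 1)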