path: root/tensorflow/contrib/nn
author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-09-28 11:53:24 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-09-28 11:59:04 -0700
commit     0254d0d31337724db911c89609336afd60e8192d (patch)
tree       fb167b3647a9b2030173387831b94eece9a59fdd /tensorflow/contrib/nn
parent     996a85d436a0f45d5bfdaad2946cef12f70883eb (diff)
Adds tf.contrib.nn.scaled_softplus(x, alpha) = alpha * softplus(x/alpha). This can be thought of as a smoothed version of ReLU. On ImageNet, alpha=0.3 gives a 0.6-1% improvement in validation accuracy over ReLU by reducing the generalization gap.
PiperOrigin-RevId: 170376244
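
For intuition, here is a minimal NumPy sketch (not part of the commit) of the same function and of its ReLU limit as alpha shrinks; scaled_softplus_np is a hypothetical helper name used only for this illustration.

import numpy as np

def scaled_softplus_np(x, alpha):
    # alpha * softplus(x / alpha), written in an overflow-safe form.
    u = x / alpha
    return alpha * (np.maximum(u, 0.) + np.log1p(np.exp(-np.abs(u))))

x = np.array([-2., -0.5, 0., 0.5, 2.])
for alpha in (1., 0.3, 0.01):
    print(alpha, scaled_softplus_np(x, alpha))
print('relu ', np.maximum(x, 0.))
# As alpha shrinks the output approaches relu(x); alpha controls how
# smoothly the kink at x = 0 is rounded off.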
Diffstat (limited to 'tensorflow/contrib/nn')
-rw-r--r--  tensorflow/contrib/nn/BUILD                               | 26
-rw-r--r--  tensorflow/contrib/nn/__init__.py                         |  3
-rw-r--r--  tensorflow/contrib/nn/python/ops/scaled_softplus.py       | 77
-rw-r--r--  tensorflow/contrib/nn/python/ops/scaled_softplus_test.py  | 67
4 files changed, 167 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD
index 4b7288e235..0ed7e52159 100644
--- a/tensorflow/contrib/nn/BUILD
+++ b/tensorflow/contrib/nn/BUILD
@@ -18,6 +18,7 @@ py_library(
"python/ops/alpha_dropout.py",
"python/ops/cross_entropy.py",
"python/ops/sampling_ops.py",
+ "python/ops/scaled_softplus.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
@@ -26,6 +27,7 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:random_ops",
@@ -36,6 +38,23 @@ py_library(
)
py_test(
+ name = "alpha_dropout_test",
+ size = "small",
+ srcs = ["python/ops/alpha_dropout_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":nn_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:random_ops",
+ ],
+)
+
+py_test(
name = "sampling_ops_test",
size = "small",
srcs = ["python/ops/sampling_ops_test.py"],
@@ -51,19 +70,16 @@ py_test(
)
py_test(
- name = "alpha_dropout_test",
+ name = "scaled_softplus_test",
size = "small",
- srcs = ["python/ops/alpha_dropout_test.py"],
+ srcs = ["python/ops/scaled_softplus_test.py"],
srcs_version = "PY2AND3",
deps = [
":nn_py",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:nn",
- "//tensorflow/python:random_ops",
],
)
diff --git a/tensorflow/contrib/nn/__init__.py b/tensorflow/contrib/nn/__init__.py
index 2cfeaa955d..be0957f473 100644
--- a/tensorflow/contrib/nn/__init__.py
+++ b/tensorflow/contrib/nn/__init__.py
@@ -26,9 +26,10 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.nn.python.ops.alpha_dropout import *
from tensorflow.contrib.nn.python.ops.cross_entropy import *
from tensorflow.contrib.nn.python.ops.sampling_ops import *
-from tensorflow.contrib.nn.python.ops.alpha_dropout import *
+from tensorflow.contrib.nn.python.ops.scaled_softplus import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus.py b/tensorflow/contrib/nn/python/ops/scaled_softplus.py
new file mode 100644
index 0000000000..5fc11d8ec6
--- /dev/null
+++ b/tensorflow/contrib/nn/python/ops/scaled_softplus.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for scaled softplus, a smoothed version of ReLU."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+
+
+def scaled_softplus(x, alpha, name=None):
+ """Returns `alpha * ln(1 + exp(x / alpha))`, for scalar `alpha > 0`.
+
+ This can be seen as a softplus applied to the scaled input, with the output
+ appropriately scaled. As `alpha` tends to 0, `scaled_softplus(x, alpha)` tends
+ to `relu(x)`.
+
+ Note: the gradient for this operation is defined to depend on the backprop
+ inputs as well as the outputs of this operation.
+
+ Args:
+ x: A `Tensor` of inputs.
+ alpha: A scalar `Tensor`, indicating the amount of smoothness. The caller
+ must ensure that `alpha > 0`.
+ name: A name for the scope of the operations (optional).
+
+ Returns:
+ A tensor of same size and type as `x`.
+
+ """
+ with ops.name_scope(name, 'scaled_softplus', [x, alpha]):
+ x = ops.convert_to_tensor(x, name='x')
+ dtype = x.dtype
+ alpha = ops.convert_to_tensor(alpha, dtype=dtype, name='alpha')
+ # Verify that alpha is a scalar.
+ alpha.get_shape().assert_has_rank(0)
+
+ def _grad(op, g):
+ """Backprop for scaled softplus."""
+ y = op.outputs[0]
+ alpha = op.inputs[1]
+ # Prevent the expensive computations from happening before g is available.
+ with ops.control_dependencies([g]):
+ y /= alpha
+ emy = math_ops.exp(-y)
+ dy_dx = 1. - emy
+ # The eps below avoids log(0). Note that t*log(t) -> 0 as t->0.
+ eps = 1e-8
+ dy_dalpha = y * emy - dy_dx * math_ops.log(dy_dx + eps)
+ return g * dy_dx, math_ops.reduce_sum(g * dy_dalpha)
+
+ @function.Defun(dtype, dtype,
+ func_name='ScaledSoftplus_%s' % dtype.name,
+ shape_func=lambda op: [op.inputs[0].get_shape()],
+ python_grad_func=_grad)
+ def _forward(x, alpha):
+ """Forward computation of scaled softplus."""
+ return alpha * nn.softplus(x / alpha)
+
+ return _forward(x, alpha)
+
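
As a sanity check on the hand-written gradient in _grad above, the following NumPy sketch (not part of the commit) recomputes dy_dx and dy_dalpha exactly the way _grad does and compares them against central finite differences of the forward function; scaled_softplus_np and grads_closed_form are hypothetical helper names used only here.

import numpy as np

np.random.seed(1)

def scaled_softplus_np(x, alpha):
    # Forward pass: alpha * softplus(x / alpha), as in _forward above.
    return alpha * np.log1p(np.exp(x / alpha))

def grads_closed_form(x, alpha):
    # Mirrors _grad: with y = softplus(x/alpha) and u = x/alpha,
    #   dy_dx     = 1 - exp(-y)                   (equals sigmoid(u))
    #   dy_dalpha = y*exp(-y) - dy_dx*log(dy_dx)  (equals softplus(u) - u*sigmoid(u))
    y = np.log1p(np.exp(x / alpha))
    emy = np.exp(-y)
    dy_dx = 1. - emy
    dy_dalpha = y * emy - dy_dx * np.log(dy_dx + 1e-8)
    return dy_dx, dy_dalpha

x = np.random.randn(3, 4)
alpha = 0.3
dy_dx, dy_dalpha = grads_closed_form(x, alpha)

# Central finite differences of the forward pass, per element.
h = 1e-5
fd_dx = (scaled_softplus_np(x + h, alpha) - scaled_softplus_np(x - h, alpha)) / (2. * h)
fd_dalpha = (scaled_softplus_np(x, alpha + h) - scaled_softplus_np(x, alpha - h)) / (2. * h)
print(np.max(np.abs(dy_dx - fd_dx)))        # should be close to zero
print(np.max(np.abs(dy_dalpha - fd_dalpha)))  # should be close to zero

Note that _grad returns reduce_sum(g * dy_dalpha) for the alpha input because alpha is a scalar broadcast across all elements of x.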
diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py
new file mode 100644
index 0000000000..3a459330ce
--- /dev/null
+++ b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py
@@ -0,0 +1,67 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for scaled_softplus.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.nn.python.ops.scaled_softplus import scaled_softplus
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.platform import test
+
+
+class ScaledSoftplusTest(test.TestCase):
+
+ def test(self):
+ np.random.seed(1) # Make it reproducible.
+ x = np.random.randn(3, 4).astype(np.float32)
+ x64 = np.random.randn(3, 4).astype(np.float64)
+ alpha = np.random.rand() + 0.01
+ y = alpha * np.log(1. + np.exp(x / alpha))
+ y64 = alpha * np.log(1. + np.exp(x64 / alpha))
+ with self.test_session(use_gpu=True) as sess:
+ z = scaled_softplus(constant_op.constant(x), alpha)
+ z64 = scaled_softplus(constant_op.constant(x64), alpha)
+ z, z64 = sess.run([z, z64])
+ eps = 1e-6
+ self.assertAllClose(y, z, eps)
+ self.assertAllClose(y64, z64, eps)
+
+ def testGradient(self):
+ np.random.seed(1) # Make it reproducible.
+ x_shape = [5, 10]
+ x_np = np.random.randn(*x_shape).astype(np.float32)
+ alpha_np = np.float32(np.random.rand() + 0.01)
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np)
+ alpha_tf = constant_op.constant(alpha_np)
+ y_tf = scaled_softplus(x_tf, alpha_tf)
+ err = gradient_checker.compute_gradient_error([x_tf, alpha_tf],
+ [x_shape, []],
+ y_tf, x_shape,
+ [x_np, alpha_np],
+ delta=1e-2)
+ eps = 1e-4
+ self.assertLess(err, eps)
+
+
+if __name__ == '__main__':
+ test.main()
+
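
For context, a minimal usage sketch (not part of the commit), assuming a TF 1.x environment where tf.contrib is available and the wildcard import added to __init__.py exposes scaled_softplus under tf.contrib.nn:

import tensorflow as tf  # TF 1.x, where tf.contrib is available

x = tf.placeholder(tf.float32, [None, 128], name='features')
w = tf.get_variable('w', shape=[128, 64])
b = tf.get_variable('b', shape=[64])
# Drop-in replacement for tf.nn.relu; alpha=0.3 is the setting the commit
# message reports as improving ImageNet validation accuracy.
h = tf.contrib.nn.scaled_softplus(tf.matmul(x, w) + b, alpha=0.3)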
+