-rw-r--r--  tensorflow/python/ops/nn_impl.py | 10
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 2c83e4e29f..431ea1186a 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -275,9 +275,6 @@ def _swish_shape(op):
return [op.inputs[0].shape]
-# Set noinline=True so that sigmoid(features) is re-computed during
-# backprop, and we can free the sigmoid(features) expression immediately
-# after use during the forward pass.
@function.Defun(shape_func=_swish_shape, func_name="swish_grad", noinline=True)
def _swish_grad(features, grad):
"""Gradient of Swish function defined below."""
@@ -287,6 +284,11 @@ def _swish_grad(features, grad):
return grad * activation_grad
+# Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x) around
+# for backprop, effectively doubling the tensor's memory consumption. We use a
+# @Defun decorator with noinline=True so that sigmoid(features) is re-computed
+# during backprop, and we can free the sigmoid(features) expression immediately
+# after use during the forward pass.
@function.Defun(
grad_func=_swish_grad,
shape_func=_swish_shape,
@@ -296,7 +298,7 @@ def swish(features):
# pylint: disable=g-doc-args
"""Computes the Swish activation function: `x * sigmoid(x)`.
- Source: "Swish: a Self-Gated Activation Function" (Ramachandran et al. 2017)
+ Source: "Searching for Activation Functions" (Ramachandran et al. 2017)
https://arxiv.org/abs/1710.05941
Args:
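
The comment added above describes the memory trade-off: recompute sigmoid(features) in the backward pass instead of keeping it alive from the forward pass. As an illustrative sketch only (not part of this commit), the same idea can be expressed with tf.custom_gradient, the TF2-era mechanism for attaching a hand-written gradient; the gradient expression below mirrors _swish_grad from this file, and the function name swish_recompute is made up for the example.

import tensorflow as tf

# Sketch: recompute sigmoid(features) during backprop so the forward-pass
# intermediate does not have to be kept in memory.
@tf.custom_gradient
def swish_recompute(features):
  def grad(dy):
    # sigmoid(features) is recomputed here rather than captured from the
    # forward pass, so the forward-pass intermediate can be freed early.
    sigmoid_features = tf.sigmoid(features)
    activation_grad = sigmoid_features * (
        1.0 + features * (1.0 - sigmoid_features))
    return dy * activation_grad
  return features * tf.sigmoid(features), grad

# Example usage:
# y = swish_recompute(tf.constant([0.5, -1.0]))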