author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-05-13 15:29:35 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-05-13 16:32:13 -0700
commit     9d2e920a1453ad5247923e2d0cd8f11495bf1e36 (patch)
tree       7ec2d505d59421a05f74a3b825eb6a371aab860a
parent     bcc724f7ac82e536baeaaec8bf09d79ceb63b67c (diff)
Change sigmoid_cross_entropy_with_logits to fix gradients at 0.
Change: 122307487
-rw-r--r--  tensorflow/python/ops/nn.py            12
-rw-r--r--  tensorflow/python/ops/nn_xent_test.py   8
2 files changed, 17 insertions(+), 3 deletions(-)
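
For context, the comments touched by this patch refer to the standard numerically stable form of the logistic loss. A minimal NumPy sketch (illustration only, not part of the commit) checking that the combined form max(x, 0) - x * z + log(1 + exp(-abs(x))) agrees with the naive cross-entropy for moderate logits:

import numpy as np

def naive_loss(x, z):
  # Direct logistic cross-entropy; loses precision for large |x|.
  p = 1.0 / (1.0 + np.exp(-x))
  return -z * np.log(p) - (1.0 - z) * np.log(1.0 - p)

def stable_loss(x, z):
  # Combined form used in nn.py: max(x, 0) - x * z + log(1 + exp(-|x|)).
  return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

x = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
z = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
print(np.allclose(naive_loss(x, z), stable_loss(x, z)))  # True
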
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index d4159c7e02..e3b5d1d5eb 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -306,10 +306,16 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
     # x - x * z + log(1 + exp(-x))
     # For x < 0, a more numerically stable formula is
     # -x * z + log(1 + exp(x))
-    # To avoid branching, we use the combined version
+    # Note that these two expressions can be combined into the following:
     # max(x, 0) - x * z + log(1 + exp(-abs(x)))
-    return math_ops.add(nn_ops.relu(logits) - logits * targets,
-                        math_ops.log(1 + math_ops.exp(-math_ops.abs(logits))),
+    # To allow computing gradients at zero, we define custom versions of max and
+    # abs functions.
+    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
+    cond = (logits >= zeros)
+    relu_logits = math_ops.select(cond, logits, zeros)
+    neg_abs_logits = math_ops.select(cond, -logits, logits)
+    return math_ops.add(relu_logits - logits * targets,
+                        math_ops.log(1 + math_ops.exp(neg_abs_logits)),
                         name=name)
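
Why the select-based rewrite matters: relu(x) and abs(x) both have gradients that are conventionally zero at x == 0, so the old expression backpropagates d(loss)/dx = -z at zero instead of the analytically correct sigmoid(x) - z. Routing the same values through a select on (logits >= 0) makes the chosen branch (x or -x) differentiable there. A rough sketch of the difference, written against today's public TensorFlow 2.x API (tf.where, tf.GradientTape) rather than the internal math_ops.select used in the patch; the old-formulation result shown in the comments assumes relu/abs gradients are zero at the origin:

import tensorflow as tf

x = tf.Variable([0.0, 0.0], dtype=tf.float64)   # logits, both at the problem point
z = tf.constant([0.0, 1.0], dtype=tf.float64)   # targets

with tf.GradientTape(persistent=True) as tape:
  # Old formulation: relu/abs are flat (zero gradient) at x == 0.
  old = tf.nn.relu(x) - x * z + tf.math.log(1 + tf.exp(-tf.abs(x)))
  # New formulation from the patch, using the public tf.where as the select.
  zeros = tf.zeros_like(x)
  cond = x >= zeros
  relu_logits = tf.where(cond, x, zeros)
  neg_abs_logits = tf.where(cond, -x, x)
  new = relu_logits - x * z + tf.math.log(1 + tf.exp(neg_abs_logits))

print(tape.gradient(old, x).numpy())  # expected [0., -1.] under the assumption above
print(tape.gradient(new, x).numpy())  # expected [0.5, -0.5] == sigmoid(0) - z
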
diff --git a/tensorflow/python/ops/nn_xent_test.py b/tensorflow/python/ops/nn_xent_test.py
index 92c67f0903..043f957833 100644
--- a/tensorflow/python/ops/nn_xent_test.py
+++ b/tensorflow/python/ops/nn_xent_test.py
@@ -82,6 +82,14 @@ class SigmoidCrossEntropyWithLogitsTest(tf.test.TestCase):
print("logistic loss gradient err = ", err)
self.assertLess(err, 1e-7)
+ def testGradientAtZero(self):
+ with self.test_session():
+ logits = tf.constant([0.0, 0.0], dtype=tf.float64)
+ targets = tf.constant([0.0, 1.0], dtype=tf.float64)
+ loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
+ grads = tf.gradients(loss, logits)[0].eval()
+ self.assertAllClose(grads, [0.5, -0.5])
+
def testShapeError(self):
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
tf.nn.sigmoid_cross_entropy_with_logits([[2, 1]], [1, 2, 3])
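
For reference, the values asserted in testGradientAtZero follow from the analytic derivative of the logistic loss with respect to the logits, sigmoid(x) - z: at x = 0, sigmoid(0) = 0.5, so the expected gradients are 0.5 - 0.0 = 0.5 and 0.5 - 1.0 = -0.5.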