author    Alexandre Passos <apassos@google.com>  2017-02-07 12:58:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-02-07 13:06:59 -0800
commit eca81d4d31dc083de5df2c722cd36662c561211b (patch)
tree   31df118bc7624b66c6a0af07d09d35d962f08699 /tensorflow/python
parent 79d495177e33d6621b53959504edc8abd90cd228 (diff)
Better error message when trying to differentiate PreventGradient.
Also fixes a small lint error in gradients_impl.py.
Change: 146822514
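For context, a minimal graph-mode sketch of the behavior this change targets (assuming the TF 1.x-era API, where tf.prevent_gradient and tf.gradients are available): differentiating through a PreventGradient op now raises a LookupError carrying the op's "message" attr instead of the generic "No gradient defined" text.

import tensorflow as tf  # assumes a TF 1.x-era, graph-mode API

with tf.Graph().as_default():
  inp = tf.constant(1.0, shape=[4, 4], name="in")
  # Attach a human-readable reason to the blocked gradient path.
  out = tf.prevent_gradient(inp, message="second derivative is unsupported")
  try:
    tf.gradients(out, inp)
  except LookupError as e:
    # With this change the error reads, e.g.:
    # "Gradient explicitly disabled. Reason: second derivative is unsupported"
    print(e)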
Diffstat (limited to 'tensorflow/python')
 tensorflow/python/kernel_tests/ctc_loss_op_test.py    | 2 +-
 tensorflow/python/kernel_tests/sparse_xent_op_test.py | 2 +-
 tensorflow/python/kernel_tests/xent_op_test.py        | 2 +-
 tensorflow/python/ops/array_grad.py                   | 6 ++++++
 tensorflow/python/ops/ctc_ops.py                      | 5 ++++-
 tensorflow/python/ops/gradients_test.py               | 2 +-
 tensorflow/python/ops/nn_grad.py                      | 9 +++++++--
 7 files changed, 21 insertions(+), 7 deletions(-)
diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
index 434faefafa..5b93f90a79 100644
--- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py
+++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
@@ -257,7 +257,7 @@ class CTCLossTest(test.TestCase):
# Taking the second gradient should fail, since it is not
# yet supported.
with self.assertRaisesRegexp(LookupError,
- ".*No gradient defined.*PreventGradient.*"):
+ "explicitly disabled"):
_ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index bff9ac4ec0..cd5b711a0e 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -214,7 +214,7 @@ class SparseXentTest(test.TestCase):
# Taking the second gradient should fail, since it is not
# yet supported.
with self.assertRaisesRegexp(LookupError,
- ".*No gradient defined.*PreventGradient.*"):
+ "explicitly disabled"):
_ = gradients_impl.hessians(loss, [weights])
def _testHighDim(self, features, labels):
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index c8c4e2bd1e..d037ceac61 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -191,7 +191,7 @@ class XentTest(test.TestCase):
# Taking the second gradient should fail, since it is not
# yet supported.
with self.assertRaisesRegexp(LookupError,
- ".*No gradient defined.*PreventGradient.*"):
+ "explicitly disabled"):
_ = gradients_impl.hessians(loss, [f])
def testWrapper(self):
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 50aab4fad5..cb545f9fca 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -343,6 +343,12 @@ def _FillGrad(_, grad):
ops.NotDifferentiable("ZerosLike")
+@ops.RegisterGradient("PreventGradient")
+def _PreventGradientGrad(op, _):
+ raise LookupError(
+ "Gradient explicitly disabled. Reason: %s" % op.get_attr("message"))
+
+
@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
"""Gradient for Gather op."""
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 1ce5597e13..b0a1fc3dd1 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -165,7 +165,10 @@ def _CTCLossGrad(op, grad_loss, _):
# due to the fused implementation's interaction with tf.gradients(),
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
- grad_without_gradient = array_ops.prevent_gradient(op.outputs[1])
+ grad_without_gradient = array_ops.prevent_gradient(
+ op.outputs[1], message="Currently there is no way to take the second "
+ "derivative of ctc_loss due to the fused implementation's interaction "
+ "with tf.gradients()")
# Return gradient for inputs and None for
# labels_indices, labels_values and sequence_length
return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
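One detail worth flagging in these multi-line messages: Python joins adjacent string literals at compile time, so each continuation literal must supply exactly one separating space; a trailing space on one literal plus a leading space on the next would render as a double space in the final message. A quick self-contained check:

# Adjacent string literals concatenate at compile time; a trailing space
# on one literal plus a leading space on the next yields a double space.
msg = ("Currently there is no way to take the second "
       "derivative of ctc_loss due to the fused implementation's "
       "interaction with tf.gradients()")
assert "  " not in msg  # no accidental double spaces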
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 453313b4ac..cfd463283d 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -418,7 +418,7 @@ class PreventGradientTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default():
inp = constant(1.0, shape=[100, 32], name="in")
out = array_ops.prevent_gradient(inp)
- with self.assertRaisesRegexp(LookupError, "No gradient defined"):
+ with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
_ = gradients.gradients(out, inp)
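The shortened patterns in these tests work because unittest's assertRaisesRegexp matches with re.search, so a bare substring matches anywhere in the error text and no ".*" wrapping is needed. A minimal illustration:

import re

# assertRaisesRegexp applies re.search to the exception's string form,
# so a plain substring pattern matches anywhere within it.
error_text = "Gradient explicitly disabled. Reason: fused implementation"
assert re.search("explicitly disabled", error_text) is not None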
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 6147cdf221..397c522dbe 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -331,7 +331,10 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
# due to the fused implementation's interaction with tf.gradients(),
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
- softmax_grad_without_gradient = array_ops.prevent_gradient(op.outputs[1])
+ softmax_grad_without_gradient = array_ops.prevent_gradient(
+ op.outputs[1], message="Currently there is no way to take the second "
+ "derivative of softmax_cross_entropy_with_logits due to the fused "
+ " implementation's interaction with tf.gradients()")
return _BroadcastMul(grad_0, softmax_grad_without_gradient), None
@@ -347,7 +350,9 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
- op.outputs[1])
+ op.outputs[1], message="Currently there is no way to take the second "
+ "derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
+ "implementation's interaction with tf.gradients()")
return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
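Putting the pieces together, a hedged end-to-end sketch (TF 1.x-era graph mode assumed) of how a second-derivative request reaches the blocked path: the first tf.gradients call through the fused softmax_cross_entropy_with_logits gradient inserts the PreventGradient op, and a second call then raises with the message above.

import tensorflow as tf  # assumes a TF 1.x-era, graph-mode API

with tf.Graph().as_default():
  logits = tf.constant([[1.0, 2.0, 3.0]])
  labels = tf.constant([[0.0, 0.0, 1.0]])
  loss = tf.reduce_sum(
      tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
  grad = tf.gradients(loss, logits)[0]  # first derivative: fine
  try:
    tf.gradients(grad, logits)  # second derivative: blocked
  except LookupError as e:
    print(e)  # "... no way to take the second derivative of ..."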