author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-04 06:09:42 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-04 06:14:08 -0700
commit    82ea80b979768c7fe1daa4b50cf054e5a0968f31 (patch)
tree      11d36ac98c1f6b3e4d3e8335188d33cf4a32161d /tensorflow/python
parent    2c9369c8d878c913b5dfcd3c27849bcd3d6af6c9 (diff)
Add option in tf.gradients() to return zero tensors for unconnected gradients.
tf.gradients currently returns [None] when the gradient of an unconnected variable is requested. This backward-compatible change adds the option to instead return zero tensors that match the shape of the input tensors.

PiperOrigin-RevId: 215725488
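A minimal sketch of the new behavior, adapted from the docstring example this change adds (assumes TensorFlow 1.x graph mode; the names `a`, `b`, `g1`, `g2` are illustrative):

```python
import tensorflow as tf

a = tf.ones([1, 2])
b = tf.ones([3, 1])  # b does not depend on a

with tf.Session() as sess:
    # Default behavior: an unconnected input yields None.
    g1 = tf.gradients([b], [a], unconnected_gradients='none')
    print(g1)  # [None]

    # New option: an unconnected input yields zeros shaped like the input.
    g2 = tf.gradients([b], [a], unconnected_gradients='zero')
    print(sess.run(g2))  # [array([[0., 0.]], dtype=float32)]
```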
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/BUILD                  |  4
-rw-r--r--  tensorflow/python/ops/gradients.py       |  1
-rw-r--r--  tensorflow/python/ops/gradients_impl.py  | 67
-rw-r--r--  tensorflow/python/ops/gradients_test.py  | 34
4 files changed, 101 insertions, 5 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index fe81254ef7..da3c56db92 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2152,6 +2152,7 @@ py_library(
":array_grad",
":array_ops",
":bitwise_ops",
+ ":check_ops",
":cond_v2_impl",
":control_flow_grad",
":control_flow_ops",
@@ -2172,8 +2173,11 @@ py_library(
":random_grad",
":resource_variable_ops",
":spectral_grad",
+ ":tensor_array_ops",
+ ":tensor_util",
":util",
":variable_scope",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 1dc666e78b..794465b10e 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -25,4 +25,5 @@ from tensorflow.python.ops.custom_gradient import custom_gradient
from tensorflow.python.ops.gradients_impl import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
+from tensorflow.python.ops.gradients_impl import UnconnectedGradients
# pylint: enable=unused-import
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 056015d6b6..aac95037dc 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import contextlib
+import enum # pylint: disable=g-bad-import-order
import sys
import warnings
@@ -537,6 +538,26 @@ def _Consumers(t, func_graphs):
return consumers
+@tf_export("UnconnectedGradients")
+class UnconnectedGradients(enum.Enum):
+ """Controls how gradient computation behaves when y does not depend on x.
+
+ The gradient of y with respect to x can be zero in two different ways: there
+ could be no differentiable path in the graph connecting x to y (and so we can
+ statically prove that the gradient is zero) or it could be that runtime values
+ of tensors in a particular execution lead to a gradient of zero (say, if a
+ relu unit happens to not be activated). To allow you to distinguish between
+ these two cases you can choose what value gets returned for the gradient when
+ there is no path in the graph from x to y:
+
+ * `NONE`: Indicates that [None] will be returned if there is no path from x
+ to y.
+ * `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
+ """
+ NONE = "none"
+ ZERO = "zero"
+
+
@tf_export("gradients")
def gradients(ys,
xs,
@@ -545,7 +566,8 @@ def gradients(ys,
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
- stop_gradients=None):
+ stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE):
"""Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
`ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
@@ -596,6 +618,23 @@ def gradients(ys,
All integer tensors are considered constant with respect to all `xs`, as if
they were included in `stop_gradients`.
+ `unconnected_gradients` determines the value returned for each x in xs if it
+ is unconnected in the graph to ys. By default this is None to safeguard
+ against errors. Mathematically these gradients are zero, which can be requested
+ using the `'zero'` option. `tf.UnconnectedGradients` provides the
+ following options and behaviors:
+
+ ```python
+ a = tf.ones([1, 2])
+ b = tf.ones([3, 1])
+ g1 = tf.gradients([b], [a], unconnected_gradients='none')
+ sess.run(g1) # [None]
+
+ g2 = tf.gradients([b], [a], unconnected_gradients='zero')
+ sess.run(g2) # [array([[0., 0.]], dtype=float32)]
+ ```
+
+
Args:
ys: A `Tensor` or list of tensors to be differentiated.
xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -611,6 +650,10 @@ def gradients(ys,
Accepted values are constants defined in the class `AggregationMethod`.
stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
through.
+ unconnected_gradients: Optional. Specifies the gradient value returned when
+ the given input tensors are unconnected. Accepted values are constants
+ defined in the class `tf.UnconnectedGradients` and the default value is
+ `none`.
Returns:
A list of `sum(dy/dx)` for each x in `xs`.
@@ -627,7 +670,8 @@ def gradients(ys,
# mutating new ops.
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
- gate_gradients, aggregation_method, stop_gradients)
+ gate_gradients, aggregation_method, stop_gradients,
+ unconnected_gradients)
def _GradientsHelper(ys,
@@ -638,6 +682,7 @@ def _GradientsHelper(ys,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE,
src_graph=None):
"""Implementation of gradients()."""
if context.executing_eagerly():
@@ -645,6 +690,11 @@ def _GradientsHelper(ys,
"is enabled. Use tf.GradientTape instead.")
if src_graph is None:
src_graph = ops.get_default_graph()
+ try:
+ unconnected_gradients = UnconnectedGradients(unconnected_gradients)
+ except ValueError:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
# If src_graph is a _FuncGraph (i.e. a function body), gather it and all
# ancestor graphs. This is necessary for correctly handling captured values.
@@ -856,7 +906,7 @@ def _GradientsHelper(ys,
if loop_state:
loop_state.PostProcessing()
- return [_GetGrad(grads, x) for x in xs]
+ return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
def _HasAnyNotNoneGrads(grads, op):
@@ -924,12 +974,19 @@ def _SetGrad(grads, t, grad):
op_grads[t.value_index] = grad
-def _GetGrad(grads, t):
+def _GetGrad(grads, t, unconnected_gradients):
"""Gets gradient for tensor "t"."""
op = t.op
op_grads = grads.get(op)
if not op_grads:
- return None
+ if unconnected_gradients == UnconnectedGradients.ZERO:
+ return array_ops.zeros_like(t)
+ elif unconnected_gradients == UnconnectedGradients.NONE:
+ return None
+ else:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
+
t_grad = op_grads[t.value_index]
assert not isinstance(
t_grad, list), ("gradients list should have been aggregated by now.")
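The `UnconnectedGradients(unconnected_gradients)` coercion in `_GradientsHelper` above works because `enum.Enum` supports lookup by value, so callers may pass either the enum member or its string value. A minimal standalone demonstration of that pattern (plain Python, not part of this diff):

```python
import enum

class UnconnectedGradients(enum.Enum):
    NONE = "none"
    ZERO = "zero"

# Lookup by value: the string "zero" resolves to the ZERO member.
print(UnconnectedGradients("zero"))  # UnconnectedGradients.ZERO

# Passing a member through the constructor is idempotent.
print(UnconnectedGradients(UnconnectedGradients.NONE))  # UnconnectedGradients.NONE

# Unknown values raise ValueError, which _GradientsHelper catches and
# re-raises with a friendlier message.
try:
    UnconnectedGradients("nonsense")
except ValueError as err:
    print(err)  # 'nonsense' is not a valid UnconnectedGradients
```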
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 3c9b7a01c7..c93e2493ee 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -350,6 +350,40 @@ class GradientsTest(test_util.TensorFlowTestCase):
for a, b in zip(npgrad1, npgrad2):
np.testing.assert_allclose(a, b)
+ def testUnconnectedGradientsNoneUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="none")
+ self.assertIsNone(grad[0])
+
+ def testUnconnectedGradientsZerosUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grads = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], sess.run(grads)[0])
+
+ def testUnconnectedGradientsZeroConnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = x * 3.0
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertEquals(3.0, sess.run(grad)[0])
+
+ def testUnknownUnconnectedGradientsValueGiven(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = constant(1.0)
+ with self.assertRaisesRegexp(
+ ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
+ gradients.gradients([y], [x], unconnected_gradients="nonsense")
+
class FunctionGradientsTest(test_util.TensorFlowTestCase):