author    | 2016-12-30 13:40:23 -0800
committer | 2016-12-30 13:47:36 -0800
commit    | 1243fbee608ac89299a69fd12fc338325116c219 (patch)
tree      | cdd1402bb876bc4b4e365621761b7ac99b273d6f
parent    | 7455e254e901e33acc1367e3aa011b1548f0b145 (diff)
Deal with the case where _SwitchGrad() is not called the first time for a while loop (i.e. non-differentiable outputs)
Change: 143264614
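For context, here is a minimal reproduction of the scenario this commit fixes, as a sketch assuming the TF 1.x graph-mode API of this era (it mirrors the `testNonDifferentiableSwitchInWhileLoop` test added in `gradients_test.py` below): a while loop whose outputs are not differentiable with respect to a float input, because the float is cast to int32 inside the loop body.

```python
import tensorflow as tf  # sketch: assumes the TF 1.x graph-mode API of this era

v = tf.placeholder(tf.float32, [])

def _step(i, a, ta):
  # Casting `v` to int32 severs the differentiable path from `v` to the
  # loop outputs, so the loop is non-differentiable w.r.t. `v`.
  a += tf.cast(v, tf.int32)
  return i + 1, a, ta.write(i, a)

n = 4
i, _, ta = tf.while_loop(lambda i, *_: i < n, _step,
                         [0, 0, tf.TensorArray(tf.int32, size=n)])
target = ta.read(i - 1)

# Before this commit, _SwitchGrad() assumed its first visit always carried a
# gradient from the Exit branch; here it does not, and backprop broke. With
# the fix, the gradient is simply reported as None.
print(tf.gradients(target, v))  # [None]
```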
-rw-r--r-- | tensorflow/python/BUILD                    |  4
-rw-r--r-- | tensorflow/python/ops/control_flow_grad.py |  9
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py    |  5
-rw-r--r-- | tensorflow/python/ops/gradients_test.py    | 25
4 files changed, 37 insertions, 6 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e8ba3435e9..6dd34c6be5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1763,6 +1763,8 @@ cuda_py_test(
     additional_deps = [
         ":array_grad",
         ":array_ops",
+        ":control_flow_grad",
+        ":control_flow_ops",
         ":data_flow_grad",
         ":data_flow_ops",
         ":framework_for_generated_wrappers",
@@ -1775,6 +1777,8 @@ cuda_py_test(
         ":nn_ops",
         ":platform_test",
         ":state_grad",
+        ":tensor_array_grad",
+        ":tensor_array_ops",
         ":test_ops",
         "//third_party/py/numpy",
     ],
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index d74a5ded3c..af84022795 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -55,14 +55,19 @@ def _SwitchGrad(op, *grad):
       control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
       # pylint: enable=protected-access
       return None, None
-    else:
-      # This is the first time this Switch is visited. It always comes from
+    elif grad[0] is not None:
+      # This is the first time this Switch is visited. It comes from
       # the Exit branch, which is grad[0]. grad[1] is empty at this point.
       # Use grad[0] for both inputs to merge for now, but update the second
       # input of merge when we see this Switch the second time.
       merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
       grad_ctxt.grad_state.switch_map[op] = merge_grad
       return merge_grad, None
+    else:
+      # This is the first time this Switch is visited. It comes from the
+      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
+      # meaning the output is not differentiable.
+      return None, None
   elif isinstance(op_ctxt, CondContext):
     good_grad = grad[op_ctxt.branch]
     zero_grad = grad[1 - op_ctxt.branch]
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 546031d737..7017640b7a 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -773,8 +773,9 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
             array_ops.concat_v2([x.values for x in out_grad], 0),
             array_ops.concat_v2([x.indices for x in out_grad], 0),
             out_grad[0].dense_shape)
-    else:
-      out_grads[i] = []
+    else:  # not out_grad
+      # out_grads[i] is [], thus its aggregation is simply None.
+      out_grads[i] = None
   return out_grads
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index eac37e6bfb..cf48bdda68 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -31,6 +31,8 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.framework.constant_op import constant
 from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
@@ -40,6 +42,8 @@ from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops.nn_ops import bias_add
 from tensorflow.python.platform import googletest
 
@@ -202,7 +206,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
     # Test that we don't differentiate 'x'. The gradient function for 'x' is
     # set explicitly to None so we will get an exception if the gradient code
     # tries to differentiate 'x'.
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       c = constant(1.0)
       x = array_ops.identity(c)
       y = x + 1.0
@@ -290,6 +294,23 @@ class GradientsTest(test_util.TensorFlowTestCase):
     #   tf.IndexedSlices.
     self.assertEqual(dx, dy)
 
+  def testNonDifferentiableSwitchInWhileLoop(self):
+    with ops.Graph().as_default():
+      v = array_ops.placeholder(dtypes.float32, [])
+
+      def _Step(i, a, ta):
+        a += math_ops.cast(v, dtypes.int32)
+        return (i + 1, a, ta.write(i, a))
+
+      n = 4
+      i, _, ta = control_flow_ops.while_loop(
+          lambda i, *_: i < n,
+          _Step, [0, 0, tensor_array_ops.TensorArray(
+              dtypes.int32, size=n)])
+      target = ta.read(i - 1)
+      grad, = gradients.gradients(target, v)
+      self.assertIsNone(grad)
+
 
 class FunctionGradientsTest(test_util.TensorFlowTestCase):
 
@@ -422,7 +443,7 @@ class HessianTest(test_util.TensorFlowTestCase):
 
   def testHessian1D(self):
     # Manually compute the Hessian explicitly for a low-dimensional problem
-    # and check that `hessian` matches. Specifically, the Hessian of
+    # and check that `hessian` matches.  Specifically, the Hessian of
     # f(x) = x^T A x is H = A + A^T.
     m = 4
     rng = np.random.RandomState([1, 2, 3])
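The `gradients_impl.py` hunk encodes a simple aggregation rule: an op output that received no partial gradients now aggregates to `None` (marking it non-differentiable) instead of to `[]`, which is exactly the `None` that the new `elif grad[0] is not None:` branch in `_SwitchGrad()` tests for. A minimal sketch of that rule in plain Python (`_aggregate` is a hypothetical stand-in, not the real `_AggregatedGrads`):

```python
def _aggregate(out_grad):
  """Sketch: aggregate the list of partial gradients for one op output."""
  if out_grad:
    # The real code sums with add_n, or concatenates IndexedSlices.
    return sum(out_grad)
  # No gradient flowed into this output: report None, not [], so that
  # downstream gradient functions such as _SwitchGrad() can detect it.
  return None
```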