author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-12-30 13:40:23 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-12-30 13:47:36 -0800
commit    1243fbee608ac89299a69fd12fc338325116c219 (patch)
tree      cdd1402bb876bc4b4e365621761b7ac99b273d6f
parent    7455e254e901e33acc1367e3aa011b1548f0b145 (diff)
Deal with the case where _SwitchGrad() is not called the first time for a while loop (i.e. non-differentiable outputs)
Change: 143264614
-rw-r--r--  tensorflow/python/BUILD                       4
-rw-r--r--  tensorflow/python/ops/control_flow_grad.py    9
-rw-r--r--  tensorflow/python/ops/gradients_impl.py       5
-rw-r--r--  tensorflow/python/ops/gradients_test.py      25
4 files changed, 37 insertions, 6 deletions
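
For context, the scenario this change handles can be sketched at the user level. Below is a minimal, hypothetical repro using the TF 1.x-era public API, mirroring the test added in gradients_test.py further down:

# Minimal sketch (hypothetical, TF 1.x-era public API) of the failing case:
# the loop's int32 state depends on `v` only through a non-differentiable
# cast, so tf.gradients should yield None instead of failing in _SwitchGrad.
import tensorflow as tf

v = tf.placeholder(tf.float32, [])

def _step(i, a, ta):
  a += tf.cast(v, tf.int32)            # gradient cannot flow through the cast
  return i + 1, a, ta.write(i, a)

n = 4
i, _, ta = tf.while_loop(lambda i, *_: i < n, _step,
                         [0, 0, tf.TensorArray(tf.int32, size=n)])
target = ta.read(i - 1)
grad, = tf.gradients(target, v)        # expected to be None after this change
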
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e8ba3435e9..6dd34c6be5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1763,6 +1763,8 @@ cuda_py_test(
     additional_deps = [
         ":array_grad",
         ":array_ops",
+        ":control_flow_grad",
+        ":control_flow_ops",
         ":data_flow_grad",
         ":data_flow_ops",
         ":framework_for_generated_wrappers",
@@ -1775,6 +1777,8 @@ cuda_py_test(
         ":nn_ops",
         ":platform_test",
         ":state_grad",
+        ":tensor_array_grad",
+        ":tensor_array_ops",
         ":test_ops",
         "//third_party/py/numpy",
     ],
diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index d74a5ded3c..af84022795 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -55,14 +55,19 @@ def _SwitchGrad(op, *grad):
       control_flow_ops._AddNextAndBackEdge(merge_grad, grad[1])
       # pylint: enable=protected-access
       return None, None
-    else:
-      # This is the first time this Switch is visited. It always comes from
+    elif grad[0] is not None:
+      # This is the first time this Switch is visited. It comes from
       # the Exit branch, which is grad[0]. grad[1] is empty at this point.
       # Use grad[0] for both inputs to merge for now, but update the second
       # input of merge when we see this Switch the second time.
       merge_grad = merge([grad[0], grad[0]], name="b_switch")[0]
       grad_ctxt.grad_state.switch_map[op] = merge_grad
       return merge_grad, None
+    else:
+      # This is the first time this Switch is visited. It comes from the
+      # Identity branch. Such a Switch has `None` gradient for the Exit branch,
+      # meaning the output is not differentiable.
+      return None, None
   elif isinstance(op_ctxt, CondContext):
     good_grad = grad[op_ctxt.branch]
     zero_grad = grad[1 - op_ctxt.branch]
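
To read the new while-loop branch of `_SwitchGrad` as one piece, here is a toy, self-contained restatement of its three-way dispatch. The names `switch_map`, `merge`, and the back-edge hook are stubbed stand-ins rather than the real graph machinery; only the branching mirrors the code above:

# Toy restatement of the while-loop dispatch in _SwitchGrad after this change.
def _switch_grad_for_while(op, grad, switch_map, merge, add_back_edge):
  merge_grad = switch_map.get(op)
  if merge_grad is not None:
    # Second visit: wire the NextIteration gradient into the existing Merge.
    add_back_edge(merge_grad, grad[1])
    return None, None
  elif grad[0] is not None:
    # First visit via the Exit branch: seed a Merge with grad[0] and remember
    # it so the second visit can complete it.
    merge_grad = merge([grad[0], grad[0]])
    switch_map[op] = merge_grad
    return merge_grad, None
  else:
    # First visit via the Identity branch with no Exit gradient: the output
    # is not differentiable, so propagate None (the branch added above).
    return None, None

# Before this change the all-None input fell into the Exit-branch case and
# failed; now it simply returns (None, None).
print(_switch_grad_for_while("sw", (None, None), {}, list, lambda m, g: None))
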
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 546031d737..7017640b7a 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -773,8 +773,9 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
             array_ops.concat_v2([x.values for x in out_grad], 0),
             array_ops.concat_v2([x.indices for x in out_grad], 0),
             out_grad[0].dense_shape)
-    else:
-      out_grads[i] = []
+    else:  # not out_grad
+      # out_grads[i] is [], thus its aggregation is simply None.
+      out_grads[i] = None
   return out_grads
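
The effect of this hunk is that an op output receiving no gradients now aggregates to `None` instead of staying an empty list, which is exactly the `grad[0] is None` case the new `_SwitchGrad` branch above handles. A toy illustration of that aggregation rule (a hypothetical helper, not the real `_AggregatedGrads`, which also handles `IndexedSlices` and loop state):

# Toy illustration of how the gradients flowing into one op output aggregate.
def _aggregate(out_grad):
  if not out_grad:
    return None          # was [] before this change
  if len(out_grad) == 1:
    return out_grad[0]   # a single gradient needs no addition
  return sum(out_grad)   # stand-in for math_ops.add_n over many gradients

assert _aggregate([]) is None    # non-differentiable output: simply None
assert _aggregate([3.0]) == 3.0
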
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index eac37e6bfb..cf48bdda68 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -31,6 +31,8 @@ from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
@@ -40,6 +42,8 @@ from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import state_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest
@@ -202,7 +206,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
     # Test that we don't differentiate 'x'. The gradient function for 'x' is
     # set explicitly to None so we will get an exception if the gradient code
     # tries to differentiate 'x'.
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       c = constant(1.0)
       x = array_ops.identity(c)
       y = x + 1.0
@@ -290,6 +294,23 @@ class GradientsTest(test_util.TensorFlowTestCase):
       # tf.IndexedSlices.
       self.assertEqual(dx, dy)

+  def testNonDifferentiableSwitchInWhileLoop(self):
+    with ops.Graph().as_default():
+      v = array_ops.placeholder(dtypes.float32, [])
+
+      def _Step(i, a, ta):
+        a += math_ops.cast(v, dtypes.int32)
+        return (i + 1, a, ta.write(i, a))
+
+      n = 4
+      i, _, ta = control_flow_ops.while_loop(
+          lambda i, *_: i < n,
+          _Step, [0, 0, tensor_array_ops.TensorArray(
+              dtypes.int32, size=n)])
+      target = ta.read(i - 1)
+      grad, = gradients.gradients(target, v)
+      self.assertIsNone(grad)
+

class FunctionGradientsTest(test_util.TensorFlowTestCase):
@@ -422,7 +443,7 @@ class HessianTest(test_util.TensorFlowTestCase):
   def testHessian1D(self):
     # Manually compute the Hessian explicitly for a low-dimensional problem
-    # and check that `hessian` matches. Specifically, the Hessian of
+    # and check that `hessian` matches. Specifically, the Hessian of
     # f(x) = x^T A x is H = A + A^T.
     m = 4
     rng = np.random.RandomState([1, 2, 3])