aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-05 17:34:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 17:39:58 -0700
commit213d76a6ed77a696883502c53a3a4f81d2ee4042 (patch)
treed701196115c416f23b6861621ce4df79eaee5262 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent4831740f90eaf266a99d3ffa7d390d54325b689f (diff)
Simply the logic for bubbling captured tensors when building cond_v2 grad.
The current logic tries to bubble the forward pass tensor to the outermost graph. That might not always be do-able e.g. when the cond is inside a while loop it will need to know accumulator logic for while_loop. So instead, the cond_grad now captures tensors from the forward If op's graph. When the grad If op is built these tensors will be appropriately captured by the surrounding FuncGraph. PiperOrigin-RevId: 215993009
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 7fae5249aa..baea5c0f6d 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -661,8 +661,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondGrad_1(self):
- graph = ops.Graph()
- with graph.as_default():
+ with self.cached_session():
x = constant_op.constant(10.0, name="x")
pred = math_ops.less(1, 2)
fn1 = lambda: array_ops.identity(x)
@@ -670,8 +669,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
grad = gradients_impl.gradients(r, [x])[0]
- with self.cached_session():
- self.assertAllEqual(1.0, grad.eval())
+ self.assertAllEqual(1.0, grad.eval())
def testCondGrad_2(self):
with self.cached_session():