diff options
author | 2016-01-21 12:04:15 -0800 | |
---|---|---|
committer | 2016-01-21 17:59:51 -0800 | |
commit | c4926d8b7853d19915f5d8176a1a30cca9955285 (patch) | |
tree | cebb914824b66ad2a8fbf3c5c68ec78e7802642e | |
parent | 990121cf1b727896e6d69e9fb15d980273284bf6 (diff) |
Need to check if the op is in a loop context.
Change: 112709520
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 54 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 25 |
2 files changed, 42 insertions, 37 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 7c405e8212..ebb20e0cc0 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -680,7 +680,7 @@ class ControlFlowTest(tf.test.TestCase): r = control_flow_ops.While(c, b, [n]) result = r.eval() - self.assertEqual(10.0, result) + self.assertAllClose(10.0, result) def testWhile_Gpu_1(self): self._testWhile_Gpu_1(use_gpu=False) @@ -696,7 +696,7 @@ class ControlFlowTest(tf.test.TestCase): r = control_flow_ops.While(c, b, [n]) result = r.eval() - self.assertEqual(10.0, result) + self.assertAllClose(10.0, result) def testWhile_Gpu_2(self): self._testWhile_Gpu_1(use_gpu=False) @@ -846,7 +846,7 @@ class ControlFlowTest(tf.test.TestCase): tf.initialize_all_variables().run() self.assertEqual(3, r.eval()) result = select.eval() - self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result) + self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) def testWhileUpdateVariable_2(self): with self.test_session(): @@ -873,9 +873,9 @@ class ControlFlowTest(tf.test.TestCase): tf.initialize_all_variables().run() self.assertEqual(3, r.eval()) result1 = select1.eval() - self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result1) + self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1) result2 = select2.eval() - self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result2) + self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2) def testWhileUpdateVariable_3(self): with self.test_session(): @@ -897,7 +897,7 @@ class ControlFlowTest(tf.test.TestCase): tf.initialize_all_variables().run() result = r[1].eval() self.assertTrue(check_op_order(n.graph)) - self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result) + self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) # b/24814703 def testWhileUpdateVariable_4(self): @@ -1029,10 +1029,11 @@ class ControlFlowTest(tf.test.TestCase): c = lambda v: tf.less(v, 100.0) b = tf.square r = control_flow_ops.While(c, b, [v], parallel_iterations=1) + r = control_flow_ops.cond(tf.less(1, 2), lambda: r, lambda: v) - r = tf.gradients(r, v) - result = r[0].eval() - self.assertEqual(1024.0, result) + r = tf.gradients(r, v)[0] + result = r.eval() + self.assertAllClose(1024.0, result) def _testWhileGrad_Mul(self, use_gpu, p_iters): with self.test_session(use_gpu=use_gpu) as sess: @@ -1045,8 +1046,8 @@ class ControlFlowTest(tf.test.TestCase): grad_a, grad_v = tf.gradients(r, [a, v]) grad_a_val, grad_v_val = sess.run([grad_a, grad_v]) - self.assertEqual(216.0, grad_a_val) - self.assertEqual(81.0, grad_v_val) + self.assertAllClose(216.0, grad_a_val) + self.assertAllClose(81.0, grad_v_val) def testWhileGrad_Mul(self): self._testWhileGrad_Mul(use_gpu=False, p_iters=1) @@ -1065,7 +1066,7 @@ class ControlFlowTest(tf.test.TestCase): r = tf.gradients(r, a) tf.initialize_all_variables().run() result = r[0].eval() - self.assertEqual(216.0, result) + self.assertAllClose(216.0, result) def testWhileGrad_ys_xs(self): with self.test_session(): @@ -1080,13 +1081,13 @@ class ControlFlowTest(tf.test.TestCase): rx, ry = control_flow_ops.While(c, b, [x, y], parallel_iterations=1) r = tf.gradients([rx, ry], x) - self.assertEqual(304.0, r[0].eval()) + self.assertAllClose(304.0, r[0].eval()) r = tf.gradients([rx, ry], y) - self.assertEqual(124.0, r[0].eval()) + self.assertAllClose(124.0, r[0].eval()) r = tf.gradients([rx], x) - self.assertEqual(295.0, r[0].eval()) + self.assertAllClose(295.0, r[0].eval()) r = tf.gradients([rx], y) - self.assertEqual(120.0, r[0].eval()) + self.assertAllClose(120.0, r[0].eval()) def testWhileGrad_Dependency(self): with self.test_session(): @@ -1101,9 +1102,9 @@ class ControlFlowTest(tf.test.TestCase): ri, rx = control_flow_ops.While(c, b, [i, x], parallel_iterations=1) r = tf.gradients([ri, rx], x) - self.assertEqual(1024.0, r[0].eval()) + self.assertAllClose(1024.0, r[0].eval()) r = tf.gradients([rx], x) - self.assertEqual(1024.0, r[0].eval()) + self.assertAllClose(1024.0, r[0].eval()) def testWhileGrad_NoGradient(self): with self.test_session(): @@ -1114,7 +1115,7 @@ class ControlFlowTest(tf.test.TestCase): r = tf.add(r, v) r = tf.gradients(r, v) result = r[0].eval() - self.assertEqual(1.0, result) + self.assertAllClose(1.0, result) def testWhileGrad_SerialTwoLoops(self): with self.test_session(): @@ -1130,7 +1131,7 @@ class ControlFlowTest(tf.test.TestCase): _, rx = control_flow_ops.While(c, b, [i, rx], parallel_iterations=1) r = tf.gradients([rx], x) - self.assertEqual(1024.0, r[0].eval()) + self.assertAllClose(1024.0, r[0].eval()) def testWhileGrad_ParallelTwoLoops(self): with self.test_session(): @@ -1147,7 +1148,7 @@ class ControlFlowTest(tf.test.TestCase): rx = tf.add(r1, r2) r = tf.gradients([rx], x) - self.assertEqual(64.0, r[0].eval()) + self.assertAllClose(64.0, r[0].eval()) def _testNestedWhileGrad_Simple(self, use_gpu): with self.test_session(use_gpu=use_gpu): @@ -1161,7 +1162,7 @@ class ControlFlowTest(tf.test.TestCase): r = control_flow_ops.While(c, b, [v]) r = tf.gradients(r, v)[0] - self.assertEqual(8.0, r.eval()) + self.assertAllClose(8.0, r.eval()) def testNestedWhileGrad_Simple(self): self._testNestedWhileGrad_Simple(use_gpu=False) @@ -1185,7 +1186,7 @@ class ControlFlowTest(tf.test.TestCase): r = control_flow_ops.While(c, b, [v]) r = tf.gradients(r, v)[0] - self.assertEqual(256.0, r.eval()) + self.assertAllClose(256.0, r.eval()) def testNestedWhileGrad_ParallelInner(self): with self.test_session(): @@ -1205,7 +1206,7 @@ class ControlFlowTest(tf.test.TestCase): r = control_flow_ops.While(c, b, [v]) r = tf.gradients(r, v)[0] - self.assertEqual(512.0, r.eval()) + self.assertAllClose(512.0, r.eval()) def _testWhileCondGrad_Simple(self, use_gpu): with self.test_session(use_gpu=use_gpu): @@ -1218,7 +1219,7 @@ class ControlFlowTest(tf.test.TestCase): lambda: tf.sub(x, one)) r = control_flow_ops.While(c, b, [v]) r = tf.gradients(r, v)[0] - self.assertEqual(1024.0, r.eval()) + self.assertAllClose(1024.0, r.eval()) def testWhileCondGrad_Simple(self): self._testWhileCondGrad_Simple(use_gpu=False) @@ -1274,7 +1275,7 @@ class ControlFlowTest(tf.test.TestCase): return tf.reduce_sum(tf.abs(x)) i = tf.cond(tf.equal(d, 2), l2, l1) - self.assertEqual(4.0, i.eval(feed_dict={d: 1})) + self.assertAllClose(4.0, i.eval(feed_dict={d: 1})) self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) def testOneOpCond(self): @@ -1561,6 +1562,5 @@ class TupleTest(tf.test.TestCase): self.assertEquals(1, var.eval()) - if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index abb7a48f54..6ab6125a3d 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -735,11 +735,16 @@ class ControlFlowState(object): def __init__(self): self._map = {} # maps forward loop context to GradLoopState + def _GetGradState(self, op): + forward_ctxt = _GetWhileContext(op) + if forward_ctxt is None: + return None + return self._map.get(forward_ctxt) + def MakeWrapper(self, op): """Make a wrapper for op if it is in a WhileContext.""" - forward_ctxt = _GetWhileContext(op) - if forward_ctxt: - grad_state = self._map.get(forward_ctxt) + grad_state = self._GetGradState(op) + if grad_state: return ControlFlowOpWrapper(op, grad_state) return op @@ -753,16 +758,14 @@ class ControlFlowState(object): def EnterGradWhileContext(self, op): """Enter the WhileContext for gradient computation.""" - forward_ctxt = _GetWhileContext(op) - if forward_ctxt: - grad_state = self._map.get(forward_ctxt) + grad_state = self._GetGradState(op) + if grad_state: grad_state.grad_context.Enter() def ExitGradWhileContext(self, op): """Exit the WhileContext for gradient computation.""" - forward_ctxt = _GetWhileContext(op) - if forward_ctxt: - grad_state = self._map.get(forward_ctxt) + grad_state = self._GetGradState(op) + if grad_state: grad_state.grad_context.Exit() def AddWhileContext(self, op, between_op_list, between_ops): @@ -856,8 +859,10 @@ class ControlFlowState(object): if IsLoopSwitch(op): return None dead_branch = op.type in {"Switch", "RefSwitch"} + forward_ctxt = _GetWhileContext(op) + if forward_ctxt is None: + return array_ops.zeros_like(op.outputs[index]) op_ctxt = op._get_control_flow_context() - forward_ctxt = op_ctxt.GetWhileContext() grad_state = self._map.get(forward_ctxt) val = ops.convert_to_tensor(op.outputs[index], name="tensor") shape = val.get_shape() |