author    A. Unique TensorFlower <nobody@tensorflow.org> 2016-01-21 12:04:15 -0800
committer Vijay Vasudevan <vrv@google.com> 2016-01-21 17:59:51 -0800
commit    c4926d8b7853d19915f5d8176a1a30cca9955285 (patch)
tree      cebb914824b66ad2a8fbf3c5c68ec78e7802642e
parent    990121cf1b727896e6d69e9fb15d980273284bf6 (diff)
Need to check if the op is in a loop context.
Change: 112709520
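
The crash this fixes: gradients() can reach ops whose only control-flow context
is a cond, and ControlFlowState assumed every such op had an enclosing
WhileContext. The change factors the repeated "forward context -> grad state"
lookup into a _GetGradState helper that returns None when the op is not in a
loop. A minimal, self-contained sketch of the guard pattern (illustrative
names, not the real TensorFlow classes):

    class ControlFlowStateSketch(object):
      def __init__(self):
        self._map = {}  # maps a forward while-context to its grad loop state

      def _get_grad_state(self, op, while_ctxt_of):
        # while_ctxt_of mimics _GetWhileContext: it yields None for ops that
        # live only in a cond context (or in no control-flow context at all).
        forward_ctxt = while_ctxt_of(op)
        if forward_ctxt is None:
          return None  # the fix: no loop means no grad state to enter/exit
        return self._map.get(forward_ctxt)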
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 54
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py                  | 25
2 files changed, 42 insertions(+), 37 deletions(-)
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 7c405e8212..ebb20e0cc0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -680,7 +680,7 @@ class ControlFlowTest(tf.test.TestCase):
r = control_flow_ops.While(c, b, [n])
result = r.eval()
- self.assertEqual(10.0, result)
+ self.assertAllClose(10.0, result)
def testWhile_Gpu_1(self):
self._testWhile_Gpu_1(use_gpu=False)
@@ -696,7 +696,7 @@ class ControlFlowTest(tf.test.TestCase):
r = control_flow_ops.While(c, b, [n])
result = r.eval()
- self.assertEqual(10.0, result)
+ self.assertAllClose(10.0, result)
def testWhile_Gpu_2(self):
self._testWhile_Gpu_1(use_gpu=False)
@@ -846,7 +846,7 @@ class ControlFlowTest(tf.test.TestCase):
tf.initialize_all_variables().run()
self.assertEqual(3, r.eval())
result = select.eval()
- self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result)
+ self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
def testWhileUpdateVariable_2(self):
with self.test_session():
@@ -873,9 +873,9 @@ class ControlFlowTest(tf.test.TestCase):
tf.initialize_all_variables().run()
self.assertEqual(3, r.eval())
result1 = select1.eval()
- self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result1)
+ self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1)
result2 = select2.eval()
- self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result2)
+ self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
def testWhileUpdateVariable_3(self):
with self.test_session():
@@ -897,7 +897,7 @@ class ControlFlowTest(tf.test.TestCase):
tf.initialize_all_variables().run()
result = r[1].eval()
self.assertTrue(check_op_order(n.graph))
- self.assertAllEqual(np.array([10.0, 10.0, 10.0]), result)
+ self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
# b/24814703
def testWhileUpdateVariable_4(self):
@@ -1029,10 +1029,11 @@ class ControlFlowTest(tf.test.TestCase):
c = lambda v: tf.less(v, 100.0)
b = tf.square
r = control_flow_ops.While(c, b, [v], parallel_iterations=1)
+ r = control_flow_ops.cond(tf.less(1, 2), lambda: r, lambda: v)
- r = tf.gradients(r, v)
- result = r[0].eval()
- self.assertEqual(1024.0, result)
+ r = tf.gradients(r, v)[0]
+ result = r.eval()
+ self.assertAllClose(1024.0, result)
def _testWhileGrad_Mul(self, use_gpu, p_iters):
with self.test_session(use_gpu=use_gpu) as sess:
@@ -1045,8 +1046,8 @@ class ControlFlowTest(tf.test.TestCase):
grad_a, grad_v = tf.gradients(r, [a, v])
grad_a_val, grad_v_val = sess.run([grad_a, grad_v])
- self.assertEqual(216.0, grad_a_val)
- self.assertEqual(81.0, grad_v_val)
+ self.assertAllClose(216.0, grad_a_val)
+ self.assertAllClose(81.0, grad_v_val)
def testWhileGrad_Mul(self):
self._testWhileGrad_Mul(use_gpu=False, p_iters=1)
@@ -1065,7 +1066,7 @@ class ControlFlowTest(tf.test.TestCase):
r = tf.gradients(r, a)
tf.initialize_all_variables().run()
result = r[0].eval()
- self.assertEqual(216.0, result)
+ self.assertAllClose(216.0, result)
def testWhileGrad_ys_xs(self):
with self.test_session():
@@ -1080,13 +1081,13 @@ class ControlFlowTest(tf.test.TestCase):
rx, ry = control_flow_ops.While(c, b, [x, y], parallel_iterations=1)
r = tf.gradients([rx, ry], x)
- self.assertEqual(304.0, r[0].eval())
+ self.assertAllClose(304.0, r[0].eval())
r = tf.gradients([rx, ry], y)
- self.assertEqual(124.0, r[0].eval())
+ self.assertAllClose(124.0, r[0].eval())
r = tf.gradients([rx], x)
- self.assertEqual(295.0, r[0].eval())
+ self.assertAllClose(295.0, r[0].eval())
r = tf.gradients([rx], y)
- self.assertEqual(120.0, r[0].eval())
+ self.assertAllClose(120.0, r[0].eval())
def testWhileGrad_Dependency(self):
with self.test_session():
@@ -1101,9 +1102,9 @@ class ControlFlowTest(tf.test.TestCase):
ri, rx = control_flow_ops.While(c, b, [i, x], parallel_iterations=1)
r = tf.gradients([ri, rx], x)
- self.assertEqual(1024.0, r[0].eval())
+ self.assertAllClose(1024.0, r[0].eval())
r = tf.gradients([rx], x)
- self.assertEqual(1024.0, r[0].eval())
+ self.assertAllClose(1024.0, r[0].eval())
def testWhileGrad_NoGradient(self):
with self.test_session():
@@ -1114,7 +1115,7 @@ class ControlFlowTest(tf.test.TestCase):
r = tf.add(r, v)
r = tf.gradients(r, v)
result = r[0].eval()
- self.assertEqual(1.0, result)
+ self.assertAllClose(1.0, result)
def testWhileGrad_SerialTwoLoops(self):
with self.test_session():
@@ -1130,7 +1131,7 @@ class ControlFlowTest(tf.test.TestCase):
_, rx = control_flow_ops.While(c, b, [i, rx], parallel_iterations=1)
r = tf.gradients([rx], x)
- self.assertEqual(1024.0, r[0].eval())
+ self.assertAllClose(1024.0, r[0].eval())
def testWhileGrad_ParallelTwoLoops(self):
with self.test_session():
@@ -1147,7 +1148,7 @@ class ControlFlowTest(tf.test.TestCase):
rx = tf.add(r1, r2)
r = tf.gradients([rx], x)
- self.assertEqual(64.0, r[0].eval())
+ self.assertAllClose(64.0, r[0].eval())
def _testNestedWhileGrad_Simple(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
@@ -1161,7 +1162,7 @@ class ControlFlowTest(tf.test.TestCase):
r = control_flow_ops.While(c, b, [v])
r = tf.gradients(r, v)[0]
- self.assertEqual(8.0, r.eval())
+ self.assertAllClose(8.0, r.eval())
def testNestedWhileGrad_Simple(self):
self._testNestedWhileGrad_Simple(use_gpu=False)
@@ -1185,7 +1186,7 @@ class ControlFlowTest(tf.test.TestCase):
r = control_flow_ops.While(c, b, [v])
r = tf.gradients(r, v)[0]
- self.assertEqual(256.0, r.eval())
+ self.assertAllClose(256.0, r.eval())
def testNestedWhileGrad_ParallelInner(self):
with self.test_session():
@@ -1205,7 +1206,7 @@ class ControlFlowTest(tf.test.TestCase):
r = control_flow_ops.While(c, b, [v])
r = tf.gradients(r, v)[0]
- self.assertEqual(512.0, r.eval())
+ self.assertAllClose(512.0, r.eval())
def _testWhileCondGrad_Simple(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
@@ -1218,7 +1219,7 @@ class ControlFlowTest(tf.test.TestCase):
lambda: tf.sub(x, one))
r = control_flow_ops.While(c, b, [v])
r = tf.gradients(r, v)[0]
- self.assertEqual(1024.0, r.eval())
+ self.assertAllClose(1024.0, r.eval())
def testWhileCondGrad_Simple(self):
self._testWhileCondGrad_Simple(use_gpu=False)
@@ -1274,7 +1275,7 @@ class ControlFlowTest(tf.test.TestCase):
return tf.reduce_sum(tf.abs(x))
i = tf.cond(tf.equal(d, 2), l2, l1)
- self.assertEqual(4.0, i.eval(feed_dict={d: 1}))
+ self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
def testOneOpCond(self):
@@ -1561,6 +1562,5 @@ class TupleTest(tf.test.TestCase):
self.assertEquals(1, var.eval())
-
if __name__ == "__main__":
tf.test.main()
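
Two things to note in the test changes above. The cond added around the While
result is the regression case: differentiating through it sends gradients()
past ops whose only context is the cond, which previously hit the unguarded
GetWhileContext() lookup. The assertEqual-to-assertAllClose swaps make the
float comparisons tolerance-based instead of bit-exact; a rough,
self-contained equivalent in plain numpy (default tolerances assumed):

    import numpy as np

    def all_close(expected, actual, rtol=1e-6, atol=1e-6):
      # Pass when |expected - actual| <= atol + rtol * |actual|, elementwise.
      return np.allclose(expected, actual, rtol=rtol, atol=atol)

    assert all_close(10.0, 10.0 + 1e-9)   # tolerant comparison: passes
    assert not (10.0 == 10.0 + 1e-9)      # bit-exact comparison: brittle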
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index abb7a48f54..6ab6125a3d 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -735,11 +735,16 @@ class ControlFlowState(object):
def __init__(self):
self._map = {} # maps forward loop context to GradLoopState
+ def _GetGradState(self, op):
+ forward_ctxt = _GetWhileContext(op)
+ if forward_ctxt is None:
+ return None
+ return self._map.get(forward_ctxt)
+
def MakeWrapper(self, op):
"""Make a wrapper for op if it is in a WhileContext."""
- forward_ctxt = _GetWhileContext(op)
- if forward_ctxt:
- grad_state = self._map.get(forward_ctxt)
+ grad_state = self._GetGradState(op)
+ if grad_state:
return ControlFlowOpWrapper(op, grad_state)
return op
@@ -753,16 +758,14 @@ class ControlFlowState(object):
def EnterGradWhileContext(self, op):
"""Enter the WhileContext for gradient computation."""
- forward_ctxt = _GetWhileContext(op)
- if forward_ctxt:
- grad_state = self._map.get(forward_ctxt)
+ grad_state = self._GetGradState(op)
+ if grad_state:
grad_state.grad_context.Enter()
def ExitGradWhileContext(self, op):
"""Exit the WhileContext for gradient computation."""
- forward_ctxt = _GetWhileContext(op)
- if forward_ctxt:
- grad_state = self._map.get(forward_ctxt)
+ grad_state = self._GetGradState(op)
+ if grad_state:
grad_state.grad_context.Exit()
def AddWhileContext(self, op, between_op_list, between_ops):
@@ -856,8 +859,10 @@ class ControlFlowState(object):
if IsLoopSwitch(op): return None
dead_branch = op.type in {"Switch", "RefSwitch"}
+ forward_ctxt = _GetWhileContext(op)
+ if forward_ctxt is None:
+ return array_ops.zeros_like(op.outputs[index])
op_ctxt = op._get_control_flow_context()
- forward_ctxt = op_ctxt.GetWhileContext()
grad_state = self._map.get(forward_ctxt)
val = ops.convert_to_tensor(op.outputs[index], name="tensor")
shape = val.get_shape()
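
The ZerosLike hunk applies the same principle when building zero gradients: if
the op has no enclosing while context, none of the per-iteration bookkeeping is
needed, so a zeros tensor shaped like the output is returned directly. A hedged
sketch of that early-out (illustrative function, not the real method; the
loop-aware path is elided):

    import numpy as np

    def zeros_like_for_grad(op_output, forward_while_ctxt):
      if forward_while_ctxt is None:
        # Not in a loop: a zeros tensor shaped like the output suffices,
        # mirroring the new array_ops.zeros_like early return.
        return np.zeros_like(op_output)
      raise NotImplementedError("loop-aware gradient state path elided")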