author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-04-26 12:42:54 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-04-26 12:45:22 -0700
commit     f63750645826df65b05cad505546a86f0e347674 (patch)
tree       8467d73780d74b0f7ef4c87f8866d3bf0a233254 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent     667077cbd2cc86c4a656233a2d5f579aa4caf1f1 (diff)
For tf.gradients(), do not backpropagate through integer tensors.
All integer tensors are now considered constant with respect to all `xs`. This fixes a bug in gradients through tf.while_loop.

PiperOrigin-RevId: 194438529
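Below is a minimal sketch (my own illustration, not part of this commit) of the behavior described above, assuming the TF 1.x graph-mode API (tf.gradients, tf.Session): an integer tensor derived from a float variable is treated as constant, so the gradient flows only through the float path.

    import tensorflow as tf

    x = tf.Variable([1.0, 2.0])
    idx = tf.argmin(x)  # int64 tensor computed from x
    y = tf.cast(idx, tf.float32) + tf.reduce_sum(x)

    # The integer argmin path is not backpropagated through; only
    # reduce_sum(x) contributes, so dy/dx should be [1., 1.].
    dy_dx, = tf.gradients(y, x)

    with tf.Session() as sess:
      sess.run(x.initializer)
      print(sess.run(dy_dx))  # expected: [1. 1.]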
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py  65
1 file changed, 61 insertions(+), 4 deletions(-)
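The commit message also mentions the tf.while_loop bug. As a rough sketch (again my own, under the same TF 1.x assumptions), this is the kind of loop the change is meant to support: an int32 counter travels alongside a float accumulator, and the gradient with respect to the captured variable should ignore the integer path rather than fail.

    import tensorflow as tf

    q = tf.Variable([7.0, 8.0])
    i0 = tf.constant(0)       # int32 loop counter
    acc0 = tf.constant(0.0)   # float accumulator

    def cond(i, acc):
      return i < 3

    def body(i, acc):
      return [i + 1, acc + tf.reduce_sum(q)]

    _, acc = tf.while_loop(cond, body, [i0, acc0])

    # The int32 counter is treated as constant, so this gradient is
    # well defined: acc == 3 * reduce_sum(q), hence d(acc)/dq == [3., 3.].
    dacc_dq, = tf.gradients(acc, q)

    with tf.Session() as sess:
      sess.run(q.initializer)
      print(sess.run(dacc_dq))  # expected: [3. 3.]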
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 209411cf51..77e6f5f1a0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -2222,14 +2222,14 @@ class ControlFlowTest(test.TestCase):
 
   def testWhileWithRefsWithGradients_1(self):
     with self.test_session() as sess:
-      x = variables.Variable(0)._ref()  # pylint: disable=protected-access
+      x = variables.Variable(0.)._ref()  # pylint: disable=protected-access
       i = constant_op.constant(0)
       c = lambda i, x: math_ops.less(i, 10)
 
-      self.assertEqual(x.dtype, dtypes.int32_ref)
+      self.assertEqual(x.dtype, dtypes.float32_ref)
 
       def body(i, x):
-        self.assertEqual(x.dtype, dtypes.int32_ref)
+        self.assertEqual(x.dtype, dtypes.float32_ref)
         return [i + 1, gen_array_ops.ref_identity(x)]
 
       r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
@@ -2240,7 +2240,7 @@ class ControlFlowTest(test.TestCase):
       variables.global_variables_initializer().run()
 
       self.assertEqual(r[0].dtype, dtypes.int32)
-      self.assertEqual(r[1].dtype, dtypes.int32_ref)
+      self.assertEqual(r[1].dtype, dtypes.float32_ref)
 
       value_i, value_x, value_x_grad = sess.run(r + grad)
@@ -2443,6 +2443,63 @@ class ControlFlowTest(test.TestCase):
       r = gradients_impl.gradients(r, y)[0]
       self.assertEqual(388.0, r.eval())
 
+  def testWhileGradientWithNontrainablePath1(self):
+    q = variables.Variable([7., 8.])
+
+    def cond(_, y):
+      del y
+      return False
+
+    def body(x, _):
+      return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
+
+    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
+    dy_dq, = gradients_impl.gradients(y, q)
+    self.assertIsNotNone(dy_dq)
+    with self.test_session() as sess:
+      sess.run(q.initializer)
+      self.assertAllClose([0., 0.], sess.run(dy_dq))
+
+  def testWhileGradientWithNontrainablePath2(self):
+    q = variables.Variable([7., 8.])
+
+    def cond(_, y):
+      return math_ops.equal(y, 0.)
+
+    def body(x, _):
+      zero = constant_op.constant(0, dtype=dtypes.int64)
+      return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
+
+    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
+    dy_dq, = gradients_impl.gradients(y, q)
+    self.assertIsNotNone(dy_dq)
+    with self.test_session() as sess:
+      sess.run(q.initializer)
+      self.assertAllClose([1., 1.], sess.run(dy_dq))
+
+  def testIssue16504(self):
+    c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
+    w = variables.Variable(
+        initial_value=np.ones(100), dtype=dtypes.float32) / 100
+    k = variables.Variable(0, dtype=dtypes.int32)
+    chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)
+
+    def cond(k, _, chg_w):
+      return math_ops.logical_and(k < 10, chg_w > 1e-3)
+
+    def body(k, w, chg_w):
+      grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
+      w_n = w * math_ops.exp(-0.1 * grad)
+      w_n /= math_ops.reduce_sum(w_n)
+      chg_w = (
+          math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
+              math_ops.abs(w)))
+      return k + 1, w_n, chg_w
+
+    _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w])
+    grad, = gradients_impl.gradients(w, c)
+    self.assertIsNotNone(grad)
+
   def testStopGradMultiFlows(self):
     with self.test_session():