diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-26 12:42:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-26 12:45:22 -0700 |
commit | f63750645826df65b05cad505546a86f0e347674 (patch) | |
tree | 8467d73780d74b0f7ef4c87f8866d3bf0a233254 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 667077cbd2cc86c4a656233a2d5f579aa4caf1f1 (diff) |
For tf.gradients(), do not backpropagate through integer tensors.
All integer tensors are now considered constant with respect to all `xs`.
This fixes a bug in gradients through tf.while_loop.
PiperOrigin-RevId: 194438529
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 65 |
1 files changed, 61 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 209411cf51..77e6f5f1a0 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -2222,14 +2222,14 @@ class ControlFlowTest(test.TestCase): def testWhileWithRefsWithGradients_1(self): with self.test_session() as sess: - x = variables.Variable(0)._ref() # pylint: disable=protected-access + x = variables.Variable(0.)._ref() # pylint: disable=protected-access i = constant_op.constant(0) c = lambda i, x: math_ops.less(i, 10) - self.assertEqual(x.dtype, dtypes.int32_ref) + self.assertEqual(x.dtype, dtypes.float32_ref) def body(i, x): - self.assertEqual(x.dtype, dtypes.int32_ref) + self.assertEqual(x.dtype, dtypes.float32_ref) return [i + 1, gen_array_ops.ref_identity(x)] r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5) @@ -2240,7 +2240,7 @@ class ControlFlowTest(test.TestCase): variables.global_variables_initializer().run() self.assertEqual(r[0].dtype, dtypes.int32) - self.assertEqual(r[1].dtype, dtypes.int32_ref) + self.assertEqual(r[1].dtype, dtypes.float32_ref) value_i, value_x, value_x_grad = sess.run(r + grad) @@ -2443,6 +2443,63 @@ class ControlFlowTest(test.TestCase): r = gradients_impl.gradients(r, y)[0] self.assertEqual(388.0, r.eval()) + def testWhileGradientWithNontrainablePath1(self): + q = variables.Variable([7., 8.]) + + def cond(_, y): + del y + return False + + def body(x, _): + return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) + + _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) + dy_dq, = gradients_impl.gradients(y, q) + self.assertIsNotNone(dy_dq) + with self.test_session() as sess: + sess.run(q.initializer) + self.assertAllClose([0., 0.], sess.run(dy_dq)) + + def testWhileGradientWithNontrainablePath2(self): + q = variables.Variable([7., 8.]) + + def cond(_, y): + return math_ops.equal(y, 0.) + + def body(x, _): + zero = constant_op.constant(0, dtype=dtypes.int64) + return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) + + _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) + dy_dq, = gradients_impl.gradients(y, q) + self.assertIsNotNone(dy_dq) + with self.test_session() as sess: + sess.run(q.initializer) + self.assertAllClose([1., 1.], sess.run(dy_dq)) + + def testIssue16504(self): + c = constant_op.constant(np.arange(100), dtype=dtypes.float32) + w = variables.Variable( + initial_value=np.ones(100), dtype=dtypes.float32) / 100 + k = variables.Variable(0, dtype=dtypes.int32) + chg_w = constant_op.constant(np.inf, dtype=dtypes.float32) + + def cond(k, _, chg_w): + return math_ops.logical_and(k < 10, chg_w > 1e-3) + + def body(k, w, chg_w): + grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w) + w_n = w * math_ops.exp(-0.1 * grad) + w_n /= math_ops.reduce_sum(w_n) + chg_w = ( + math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum( + math_ops.abs(w))) + return k + 1, w_n, chg_w + + _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w]) + grad, = gradients_impl.gradients(w, c) + self.assertIsNotNone(grad) + def testStopGradMultiFlows(self): with self.test_session(): |