diff options
author | 2018-05-07 14:29:03 -0700 | |
---|---|---|
committer | 2018-05-07 17:15:37 -0700 | |
commit | dfae6ff29e95345c7c6c0ef50fd5f45bd458cfdc (patch) | |
tree | 0164da6c20fb0d74e16ef941acd57e6642730e9f /tensorflow/python/ops/gradients_test.py | |
parent | cd065ca7be11a4c87c9a5e68271cbc2d9aaaa260 (diff) |
Fix resource variable in cond gradient.
PiperOrigin-RevId: 195722449
Diffstat (limited to 'tensorflow/python/ops/gradients_test.py')
-rw-r--r-- | tensorflow/python/ops/gradients_test.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 5e8b8822ef..e729950201 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -944,6 +944,21 @@ class CustomGradientTest(test_util.TensorFlowTestCase): # Smoke test to ensure numpy inputs are accepted F(x) + def testRVGradientsDynamicCond(self): + with self.test_session(): + alpha = resource_variable_ops.ResourceVariable( + np.random.random((1,)), + dtype="float32") + + conditional = array_ops.placeholder_with_default(True, shape=()) + output = control_flow_ops.cond( + conditional, lambda: alpha * 2, lambda: alpha * 3) + + g, = gradients_impl.gradients(output, alpha) + variables.global_variables_initializer().run() + self.assertAllEqual(g.eval(), [2.0]) + self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0]) + if __name__ == "__main__": googletest.main() |