aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/gradients_test.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-05-07 14:29:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 17:15:37 -0700
commitdfae6ff29e95345c7c6c0ef50fd5f45bd458cfdc (patch)
tree0164da6c20fb0d74e16ef941acd57e6642730e9f /tensorflow/python/ops/gradients_test.py
parentcd065ca7be11a4c87c9a5e68271cbc2d9aaaa260 (diff)
Fix resource variable in cond gradient.
PiperOrigin-RevId: 195722449
Diffstat (limited to 'tensorflow/python/ops/gradients_test.py')
-rw-r--r--tensorflow/python/ops/gradients_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 5e8b8822ef..e729950201 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -944,6 +944,21 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
# Smoke test to ensure numpy inputs are accepted
F(x)
+ def testRVGradientsDynamicCond(self):
+ with self.test_session():
+ alpha = resource_variable_ops.ResourceVariable(
+ np.random.random((1,)),
+ dtype="float32")
+
+ conditional = array_ops.placeholder_with_default(True, shape=())
+ output = control_flow_ops.cond(
+ conditional, lambda: alpha * 2, lambda: alpha * 3)
+
+ g, = gradients_impl.gradients(output, alpha)
+ variables.global_variables_initializer().run()
+ self.assertAllEqual(g.eval(), [2.0])
+ self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
+
if __name__ == "__main__":
googletest.main()