diff options
author | Saurabh Saxena <srbs@google.com> | 2018-10-05 13:24:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 13:28:41 -0700 |
commit | f14287eabf69c57a2d2e044c311f2db1413cb6a5 (patch) | |
tree | 124c7ebf03dd0057ded4a54700c25a21d069d1dd /tensorflow/python | |
parent | ec451f5ab43467d7cb4ae7736f2de16331441e0b (diff) |
Copy device from If op to the lowered ops.
Enable GPU tests for cond_v2.
PiperOrigin-RevId: 215956220
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 49 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 5 |
3 files changed, 25 insertions, 32 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index e055ef1c1b..4e8639dfc8 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3255,7 +3255,7 @@ tf_py_test( tags = ["no_pip"], ) -tf_py_test( +cuda_py_test( name = "cond_v2_test", size = "medium", srcs = ["cond_v2_test.py"], @@ -3272,7 +3272,6 @@ tf_py_test( "//tensorflow/python:training", ], grpc_enabled = True, - tags = ["no_gpu"], # TODO(b/111656070) ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 377c041675..ec875aae59 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -172,7 +172,7 @@ class CondV2Test(test.TestCase): self._testCond(true_fn, false_fn, [y]) def testNestedDefunInCond(self): - self.skipTest("b/110550782") + self.skipTest("b/117284369") x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") @@ -198,7 +198,7 @@ class CondV2Test(test.TestCase): self._testCond(true_fn, false_fn, [y]) def testDoubleNestedDefunInCond(self): - self.skipTest("b/110550782") + self.skipTest("b/117284369") x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") @@ -468,7 +468,6 @@ class CondV2Test(test.TestCase): }), [5., 0.]) def testBuildCondAndGradientInsideDefun(self): - self.skipTest("b/110550782") def build_graph(): pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer") @@ -502,29 +501,29 @@ class CondV2Test(test.TestCase): return grads, pred_outer, pred_inner - with ops.Graph().as_default(): + with ops.Graph().as_default(), self.session( + graph=ops.get_default_graph()) as sess: grads, pred_outer, pred_inner = build_graph() - with self.session(graph=ops.get_default_graph()) as sess: - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: True, - pred_inner: True - }), [0., 0.]) - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: True, - pred_inner: False - }), [0., 0.]) - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: False, - pred_inner: True - }), [4., 2.]) - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: False, - pred_inner: False - }), [5., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: True + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: False + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: True + }), [4., 2.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: False + }), [5., 0.]) def testSecondDerivative(self): with self.cached_session() as sess: diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index c7e89dd5f9..7fae5249aa 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -23,7 +23,6 @@ from __future__ import print_function import collections import math import time -import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -661,7 +660,6 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.cond(pred, fn1, fn2) sess.run(r) - @test_util.disable_control_flow_v2("b/113346829 (gpu failure)") def testCondGrad_1(self): graph = ops.Graph() with graph.as_default(): @@ -3424,9 +3422,6 @@ class EagerTest(test.TestCase): # TODO(b/117279927): Re-enable once msan failure is fixed. def DISABLED_testCondInDefun(self): - if "GPU" in [d.device_type for d in device_lib.list_local_devices()]: - return unittest.skip("b/113346829 (gpu failure)") - with context.eager_mode(): @eager_function.defun |