aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-05 13:24:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 13:28:41 -0700
commitf14287eabf69c57a2d2e044c311f2db1413cb6a5 (patch)
tree124c7ebf03dd0057ded4a54700c25a21d069d1dd /tensorflow/python
parentec451f5ab43467d7cb4ae7736f2de16331441e0b (diff)
Copy device from If op to the lowered ops.
Enable GPU tests for cond_v2. PiperOrigin-RevId: 215956220
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py49
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py5
3 files changed, 25 insertions, 32 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index e055ef1c1b..4e8639dfc8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3255,7 +3255,7 @@ tf_py_test(
tags = ["no_pip"],
)
-tf_py_test(
+cuda_py_test(
name = "cond_v2_test",
size = "medium",
srcs = ["cond_v2_test.py"],
@@ -3272,7 +3272,6 @@ tf_py_test(
"//tensorflow/python:training",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/111656070)
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 377c041675..ec875aae59 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -172,7 +172,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -198,7 +198,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -468,7 +468,6 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
- self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -502,29 +501,29 @@ class CondV2Test(test.TestCase):
return grads, pred_outer, pred_inner
- with ops.Graph().as_default():
+ with ops.Graph().as_default(), self.session(
+ graph=ops.get_default_graph()) as sess:
grads, pred_outer, pred_inner = build_graph()
- with self.session(graph=ops.get_default_graph()) as sess:
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: True
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: False
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: True
- }), [4., 2.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: False
- }), [5., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
def testSecondDerivative(self):
with self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index c7e89dd5f9..7fae5249aa 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -661,7 +660,6 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
- @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
graph = ops.Graph()
with graph.as_default():
@@ -3424,9 +3422,6 @@ class EagerTest(test.TestCase):
# TODO(b/117279927): Re-enable once msan failure is fixed.
def DISABLED_testCondInDefun(self):
- if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
- return unittest.skip("b/113346829 (gpu failure)")
-
with context.eager_mode():
@eager_function.defun