aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc9
-rw-r--r--tensorflow/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py49
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py5
4 files changed, 33 insertions, 33 deletions
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index a02084f223..9306386117 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
TF_CHECK_OK(if_op_->input_node(0, &pred_));
+ then_call_builder_.Device(if_op_->requested_device());
+ else_call_builder_.Device(if_op_->requested_device());
}
Status CondBuilder::CreatePivotNodes() {
@@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() {
NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
.Input(NodeOut(pred_, 0))
.Input(NodeOut(pred_, 0))
+ .Device(if_op_->requested_device())
.Finalize(graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
.Input(switch_pred, kElseBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_f_));
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
.Input(switch_pred, kThenBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_t_));
return Status::OK();
}
@@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
.Input(src, src_output)
.Input(pred_, 0)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
@@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() {
TF_RETURN_IF_ERROR(
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
.Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Device(if_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
@@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
Status CondBuilder::BuildLoweredIfOutput() {
// Build the identity node output.
NodeBuilder ib(name_, "IdentityN");
- ib.Input(outputs_);
+ ib.Input(outputs_).Device(if_op_->requested_device());
return ib.Finalize(graph_, &lowered_if_output_);
}
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index e055ef1c1b..4e8639dfc8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3255,7 +3255,7 @@ tf_py_test(
tags = ["no_pip"],
)
-tf_py_test(
+cuda_py_test(
name = "cond_v2_test",
size = "medium",
srcs = ["cond_v2_test.py"],
@@ -3272,7 +3272,6 @@ tf_py_test(
"//tensorflow/python:training",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/111656070)
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 377c041675..ec875aae59 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -172,7 +172,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -198,7 +198,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -468,7 +468,6 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
- self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -502,29 +501,29 @@ class CondV2Test(test.TestCase):
return grads, pred_outer, pred_inner
- with ops.Graph().as_default():
+ with ops.Graph().as_default(), self.session(
+ graph=ops.get_default_graph()) as sess:
grads, pred_outer, pred_inner = build_graph()
- with self.session(graph=ops.get_default_graph()) as sess:
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: True
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: False
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: True
- }), [4., 2.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: False
- }), [5., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
def testSecondDerivative(self):
with self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index c7e89dd5f9..7fae5249aa 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -661,7 +660,6 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
- @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
graph = ops.Graph()
with graph.as_default():
@@ -3424,9 +3422,6 @@ class EagerTest(test.TestCase):
# TODO(b/117279927): Re-enable once msan failure is fixed.
def DISABLED_testCondInDefun(self):
- if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
- return unittest.skip("b/113346829 (gpu failure)")
-
with context.eager_mode():
@eager_function.defun