diff options
-rw-r--r-- | tensorflow/core/common_runtime/lower_if_op.cc | 9 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/cond_v2_test.py | 49 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 5 |
4 files changed, 33 insertions, 33 deletions
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index a02084f223..9306386117 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()), else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) { TF_CHECK_OK(if_op_->input_node(0, &pred_)); + then_call_builder_.Device(if_op_->requested_device()); + else_call_builder_.Device(if_op_->requested_device()); } Status CondBuilder::CreatePivotNodes() { @@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() { NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry()) .Input(NodeOut(pred_, 0)) .Input(NodeOut(pred_, 0)) + .Device(if_op_->requested_device()) .Finalize(graph_, &switch_pred)); control_predecessor_ = switch_pred; TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry()) .Input(switch_pred, kElseBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_f_)); TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry()) .Input(switch_pred, kThenBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_t_)); return Status::OK(); } @@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) { NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry()) .Input(src, src_output) .Input(pred_, 0) + .Device(if_op_->requested_device()) .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); @@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() { TF_RETURN_IF_ERROR( NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry()) .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) + .Device(if_op_->requested_device()) .Finalize(graph_, &merges[i])); outputs_[i] = NodeOut(merges[i], 0); } @@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib, Status CondBuilder::BuildLoweredIfOutput() { // Build the identity node output. NodeBuilder ib(name_, "IdentityN"); - ib.Input(outputs_); + ib.Input(outputs_).Device(if_op_->requested_device()); return ib.Finalize(graph_, &lowered_if_output_); } diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index e055ef1c1b..4e8639dfc8 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3255,7 +3255,7 @@ tf_py_test( tags = ["no_pip"], ) -tf_py_test( +cuda_py_test( name = "cond_v2_test", size = "medium", srcs = ["cond_v2_test.py"], @@ -3272,7 +3272,6 @@ tf_py_test( "//tensorflow/python:training", ], grpc_enabled = True, - tags = ["no_gpu"], # TODO(b/111656070) ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 377c041675..ec875aae59 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -172,7 +172,7 @@ class CondV2Test(test.TestCase): self._testCond(true_fn, false_fn, [y]) def testNestedDefunInCond(self): - self.skipTest("b/110550782") + self.skipTest("b/117284369") x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") @@ -198,7 +198,7 @@ class CondV2Test(test.TestCase): self._testCond(true_fn, false_fn, [y]) def testDoubleNestedDefunInCond(self): - self.skipTest("b/110550782") + self.skipTest("b/117284369") x = constant_op.constant(1.0, name="x") y = constant_op.constant(2.0, name="y") @@ -468,7 +468,6 @@ class CondV2Test(test.TestCase): }), [5., 0.]) def testBuildCondAndGradientInsideDefun(self): - self.skipTest("b/110550782") def build_graph(): pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer") @@ -502,29 +501,29 @@ class CondV2Test(test.TestCase): return grads, pred_outer, pred_inner - with ops.Graph().as_default(): + with ops.Graph().as_default(), self.session( + graph=ops.get_default_graph()) as sess: grads, pred_outer, pred_inner = build_graph() - with self.session(graph=ops.get_default_graph()) as sess: - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: True, - pred_inner: True - }), [0., 0.]) - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: True, - pred_inner: False - }), [0., 0.]) - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: False, - pred_inner: True - }), [4., 2.]) - self.assertSequenceEqual( - sess.run(grads, { - pred_outer: False, - pred_inner: False - }), [5., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: True + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: False + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: True + }), [4., 2.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: False + }), [5., 0.]) def testSecondDerivative(self): with self.cached_session() as sess: diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index c7e89dd5f9..7fae5249aa 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -23,7 +23,6 @@ from __future__ import print_function import collections import math import time -import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -661,7 +660,6 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.cond(pred, fn1, fn2) sess.run(r) - @test_util.disable_control_flow_v2("b/113346829 (gpu failure)") def testCondGrad_1(self): graph = ops.Graph() with graph.as_default(): @@ -3424,9 +3422,6 @@ class EagerTest(test.TestCase): # TODO(b/117279927): Re-enable once msan failure is fixed. def DISABLED_testCondInDefun(self): - if "GPU" in [d.device_type for d in device_lib.list_local_devices()]: - return unittest.skip("b/113346829 (gpu failure)") - with context.eager_mode(): @eager_function.defun |