From f14287eabf69c57a2d2e044c311f2db1413cb6a5 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Fri, 5 Oct 2018 13:24:34 -0700 Subject: Copy device from If op to the lowered ops. Enable GPU tests for cond_v2. PiperOrigin-RevId: 215956220 --- tensorflow/core/common_runtime/lower_if_op.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'tensorflow/core') diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index a02084f223..9306386117 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()), else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) { TF_CHECK_OK(if_op_->input_node(0, &pred_)); + then_call_builder_.Device(if_op_->requested_device()); + else_call_builder_.Device(if_op_->requested_device()); } Status CondBuilder::CreatePivotNodes() { @@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() { NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry()) .Input(NodeOut(pred_, 0)) .Input(NodeOut(pred_, 0)) + .Device(if_op_->requested_device()) .Finalize(graph_, &switch_pred)); control_predecessor_ = switch_pred; TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry()) .Input(switch_pred, kElseBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_f_)); TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry()) .Input(switch_pred, kThenBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_t_)); return Status::OK(); } @@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) { NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry()) .Input(src, src_output) .Input(pred_, 0) + .Device(if_op_->requested_device()) .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); @@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() { TF_RETURN_IF_ERROR( NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry()) .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) + .Device(if_op_->requested_device()) .Finalize(graph_, &merges[i])); outputs_[i] = NodeOut(merges[i], 0); } @@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib, Status CondBuilder::BuildLoweredIfOutput() { // Build the identity node output. NodeBuilder ib(name_, "IdentityN"); - ib.Input(outputs_); + ib.Input(outputs_).Device(if_op_->requested_device()); return ib.Finalize(graph_, &lowered_if_output_); } -- cgit v1.2.3