diff options
author | Saurabh Saxena <srbs@google.com> | 2018-10-05 13:24:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 13:28:41 -0700 |
commit | f14287eabf69c57a2d2e044c311f2db1413cb6a5 (patch) | |
tree | 124c7ebf03dd0057ded4a54700c25a21d069d1dd /tensorflow/core | |
parent | ec451f5ab43467d7cb4ae7736f2de16331441e0b (diff) |
Copy device from If op to the lowered ops.
Enable GPU tests for cond_v2.
PiperOrigin-RevId: 215956220
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/common_runtime/lower_if_op.cc | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index a02084f223..9306386117 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()), else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) { TF_CHECK_OK(if_op_->input_node(0, &pred_)); + then_call_builder_.Device(if_op_->requested_device()); + else_call_builder_.Device(if_op_->requested_device()); } Status CondBuilder::CreatePivotNodes() { @@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() { NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry()) .Input(NodeOut(pred_, 0)) .Input(NodeOut(pred_, 0)) + .Device(if_op_->requested_device()) .Finalize(graph_, &switch_pred)); control_predecessor_ = switch_pred; TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry()) .Input(switch_pred, kElseBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_f_)); TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry()) .Input(switch_pred, kThenBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_t_)); return Status::OK(); } @@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) { NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry()) .Input(src, src_output) .Input(pred_, 0) + .Device(if_op_->requested_device()) .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); @@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() { TF_RETURN_IF_ERROR( NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry()) .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) + .Device(if_op_->requested_device()) .Finalize(graph_, &merges[i])); outputs_[i] = NodeOut(merges[i], 0); } @@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib, Status CondBuilder::BuildLoweredIfOutput() { // Build the identity node output. NodeBuilder ib(name_, "IdentityN"); - ib.Input(outputs_); + ib.Input(outputs_).Device(if_op_->requested_device()); return ib.Finalize(graph_, &lowered_if_output_); } |