aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-10-05 13:24:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 13:28:41 -0700
commitf14287eabf69c57a2d2e044c311f2db1413cb6a5 (patch)
tree124c7ebf03dd0057ded4a54700c25a21d069d1dd /tensorflow/core
parentec451f5ab43467d7cb4ae7736f2de16331441e0b (diff)
Copy device from If op to the lowered ops.
Enable GPU tests for cond_v2. PiperOrigin-RevId: 215956220
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index a02084f223..9306386117 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
TF_CHECK_OK(if_op_->input_node(0, &pred_));
+ then_call_builder_.Device(if_op_->requested_device());
+ else_call_builder_.Device(if_op_->requested_device());
}
Status CondBuilder::CreatePivotNodes() {
@@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() {
NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
.Input(NodeOut(pred_, 0))
.Input(NodeOut(pred_, 0))
+ .Device(if_op_->requested_device())
.Finalize(graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
.Input(switch_pred, kElseBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_f_));
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
.Input(switch_pred, kThenBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_t_));
return Status::OK();
}
@@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
.Input(src, src_output)
.Input(pred_, 0)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
@@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() {
TF_RETURN_IF_ERROR(
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
.Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Device(if_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
@@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
Status CondBuilder::BuildLoweredIfOutput() {
// Build the identity node output.
NodeBuilder ib(name_, "IdentityN");
- ib.Input(outputs_);
+ ib.Input(outputs_).Device(if_op_->requested_device());
return ib.Finalize(graph_, &lowered_if_output_);
}