aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-16 13:31:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 13:34:47 -0800
commit1bde8c61e14b1c42db20a1b07684ea2534cfdf01 (patch)
tree96b49f2ed4f6dc5054c3fa47707dd295ae670dd6
parent6e5ca37827d95b12c4712cf237ec2f8124ed885c (diff)
Automated g4 rollback of changelist 185623948
PiperOrigin-RevId: 186038783
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc3
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc27
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc1
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc212
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc63
11 files changed, 270 insertions, 76 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index cd983bc03e..c812df4235 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -729,7 +729,8 @@ class CopyRemover {
// has a different operand (the operand of the elided copy).
for (const HloUse* copy_use : copy_value_node->uses) {
operand_node->uses.push_back(copy_use);
- if (copy_use->instruction->opcode() == HloOpcode::kCopy) {
+ if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
+ ContainsKey(copy_map_, copy_use->instruction)) {
copy_map_.at(copy_use->instruction).src = operand_node;
}
}
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index cde5877e29..a2d13c013c 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -225,6 +225,7 @@ Status HeapSimulator::RunComputation(
// sub-computations will never be run concurrently.
if (module_sequence_ != nullptr) {
if (instruction->opcode() == HloOpcode::kCall ||
+ instruction->opcode() == HloOpcode::kConditional ||
instruction->opcode() == HloOpcode::kWhile) {
for (const HloComputation* called_computation :
instruction->called_computations()) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 5432419e4a..21e6b2ca73 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -509,13 +509,14 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
"Can't deep copy instruction %s: instruction is not in computation %s",
instruction->name().c_str(), name().c_str());
}
-
if (indices_to_copy != nullptr &&
!ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
return FailedPrecondition(
"Can't deep copy instruction %s: given shape tree of indices to copy "
- "has incompatible shape",
- instruction->name().c_str());
+ "has incompatible shapes: %s vs. %s",
+ instruction->name().c_str(),
+ ShapeUtil::HumanString(instruction->shape()).c_str(),
+ ShapeUtil::HumanString(indices_to_copy->shape()).c_str());
}
ShapeIndex index;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 061c59abe5..39d864efcb 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -77,6 +77,14 @@ class HloComputation {
return last_added_instruction_;
}
+ Status ForEachInstruction(
+ const std::function<Status(const HloInstruction*)>& func) const {
+ for (const auto& instruction : instructions_) {
+ TF_RETURN_IF_ERROR(func(instruction.get()));
+ }
+ return Status::OK();
+ }
+
private:
const string name_;
HloInstruction* last_added_instruction_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 0d9912d07d..d719ff857d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1371,7 +1371,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
break;
case HloOpcode::kRecv:
CHECK_EQ(new_operands.size(), 0);
- clone = CreateRecv(shape, channel_id());
+ // The shape is a tuple, but CreateRecv() wants the raw data shape.
+ clone =
+ CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
break;
case HloOpcode::kRecvDone:
CHECK_EQ(new_operands.size(), 1);
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 60270b0595..cb2fe9f874 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -145,6 +145,21 @@ void HloModule::ReplaceComputations(
}
break;
}
+ case HloOpcode::kConditional: {
+ HloComputation* new_true_computation =
+ tensorflow::gtl::FindWithDefault(
+ replacements, instruction->true_computation(), nullptr);
+ if (new_true_computation != nullptr) {
+ instruction->set_true_computation(new_true_computation);
+ }
+ HloComputation* new_false_computation =
+ tensorflow::gtl::FindWithDefault(
+ replacements, instruction->false_computation(), nullptr);
+ if (new_false_computation != nullptr) {
+ instruction->set_false_computation(new_false_computation);
+ }
+ break;
+ }
case HloOpcode::kSelectAndScatter: {
HloComputation* new_select = tensorflow::gtl::FindWithDefault(
replacements, instruction->select(), nullptr);
@@ -563,6 +578,18 @@ std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
return module;
}
+HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) {
+ HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
+ TF_CHECK_OK(
+ clone->root_instruction()->Accept([this](HloInstruction* instruction) {
+ instruction->ReplaceCalledComputations([this](HloComputation* callee) {
+ return DeepCloneComputation(callee);
+ });
+ return Status::OK();
+ }));
+ return clone;
+}
+
uint64 HloModule::RandomNew64() const {
tensorflow::mutex_lock l(rng_mutex_);
return rng_();
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 4bfe8d89ce..06d92f94fd 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -85,6 +85,10 @@ class HloModule {
// Returns a deep copy of this module including all computations.
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
+ // Performs a deep clone of the computation, by recursively cloning all
+ // the called computations as well.
+ HloComputation* DeepCloneComputation(HloComputation* computation);
+
// Return a pointer to the entry computation of the module..
const HloComputation* entry_computation() const {
CHECK_NE(nullptr, entry_computation_);
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 68e3c9618c..1b24d8da9e 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -186,6 +186,22 @@ bool HloOrdering::UseIsBeforeValueDefinition(
}
}
+ if (use.instruction->opcode() == HloOpcode::kConditional) {
+ const HloInstruction* conditional = use.instruction;
+ if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
+ conditional->true_computation())) {
+ VLOG(4) << " use is conditional " << use.instruction->name()
+ << " and def is in TRUE computation";
+ return true;
+ }
+ if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
+ conditional->false_computation())) {
+ VLOG(4) << " use is conditional " << use.instruction->name()
+ << " and def is in FALSE computation";
+ return true;
+ }
+ }
+
VLOG(4) << " use is not before value";
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index c6b4dc0368..98b8d34be1 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -60,6 +60,7 @@ bool IsRematerializable(const HloInstruction* instruction) {
switch (instruction->opcode()) {
case HloOpcode::kCall:
case HloOpcode::kConstant:
+ case HloOpcode::kConditional:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kCustomCall:
case HloOpcode::kParameter:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index fce135ef61..0668f66051 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -53,6 +53,83 @@ limitations under the License.
namespace xla {
+// For now moving only one API here, but we should have a single top level
+// anonymous namespace, instead of three or four spread all over this file.
+namespace {
+
+// Creates and returns a copy of the given instruction with a different
+// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
+// instruction producing the copy is returned.
+StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+ const Shape& shape_with_layout, HloInstruction* instruction) {
+ TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
+ DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
+ << ShapeUtil::HumanString(shape_with_layout) << " "
+ << ShapeUtil::HumanString(instruction->shape())
+ << " instruction: " << instruction->ToString();
+
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ // Deep-copy tuples.
+ std::vector<HloInstruction*> element_copies;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
+ ++i) {
+ HloInstruction* gte = instruction->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
+ i));
+
+ // Recurse to copy each element.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * element_copy,
+ CreateCopyWithNewLayout(
+ ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
+ element_copies.push_back(element_copy);
+ }
+ // Gather element copies into a tuple with a new Tuple instruction.
+ HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
+ HloInstruction::CreateTuple(element_copies));
+ LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ shape_with_layout, tuple_copy->mutable_shape()));
+ return tuple_copy;
+ } else if (ShapeUtil::IsArray(instruction->shape())) {
+ HloInstruction* copy =
+ instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
+ instruction->shape(), HloOpcode::kCopy, instruction));
+ LayoutUtil::ClearLayout(copy->mutable_shape());
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ shape_with_layout, copy->mutable_shape()));
+
+ return copy;
+ } else {
+ return FailedPrecondition(
+ "Can only copy array and tuple shaped instructions");
+ }
+}
+
+// Creates a copy of the given operand if the operand's layout does not match
+// the given layout. This copy replaces the use in the given instruction. Tuple
+// operands will be deep-copied.
+Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+ HloInstruction* instruction,
+ int64 operand_no) {
+ HloInstruction* operand = instruction->mutable_operand(operand_no);
+ TF_RET_CHECK(operand_layout.LayoutIsSet());
+ TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
+
+ if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
+ // Operand layout already matches our constraint. Nothing to do.
+ return Status::OK();
+ }
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
+ CreateCopyWithNewLayout(operand_layout.shape(), operand));
+
+ return instruction->ReplaceOperandWith(operand_no, operand_copy);
+}
+
+} // namespace
+
std::ostream& operator<<(std::ostream& out,
const LayoutConstraint& constraint) {
out << constraint.ToString();
@@ -512,6 +589,36 @@ Status LayoutAssignment::AddMandatoryConstraints(
body_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
body_layout.result_shape(), instruction, 0));
+ } else if (instruction->opcode() == HloOpcode::kConditional) {
+ // The layout of the true and false computations must match, and must
+ // be the layout of the kConditional instruction.
+ TF_RET_CHECK(instruction->operand_count() == 3);
+
+ HloComputation* true_computation = instruction->true_computation();
+ HloComputation* false_computation = instruction->false_computation();
+ const HloInstruction* true_operand = instruction->operand(1);
+ const HloInstruction* false_operand = instruction->operand(2);
+
+ TF_RET_CHECK(true_computation->num_parameters() == 1);
+ TF_RET_CHECK(false_computation->num_parameters() == 1);
+ ComputationLayout& true_computation_layout =
+ FindOrDie(computation_layouts_, true_computation);
+ ComputationLayout& false_computation_layout =
+ FindOrDie(computation_layouts_, false_computation);
+
+ DCHECK(ShapeUtil::Compatible(true_operand->shape(),
+ true_computation_layout.parameter_shape(0)));
+ DCHECK(ShapeUtil::Compatible(
+ false_operand->shape(), false_computation_layout.parameter_shape(0)));
+
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+ true_computation_layout.result_shape(), instruction));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ true_computation_layout.parameter_shape(0), instruction, 1,
+ /*mandatory=*/true));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ false_computation_layout.parameter_shape(0), instruction, 2,
+ /*mandatory=*/true));
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
if (!CustomCallRequiresMajorFirstLayout(instruction)) {
continue;
@@ -598,6 +705,33 @@ Status CheckWhileLayout(HloInstruction* while_inst,
return Status::OK();
}
+Status CheckConditionalLayout(
+ HloInstruction* instruction,
+ const ComputationLayout& true_computation_layout,
+ const ComputationLayout& false_computation_layout) {
+ HloComputation* true_computation = instruction->true_computation();
+ HloComputation* false_computation = instruction->false_computation();
+ const HloInstruction* true_operand = instruction->operand(1);
+ const HloInstruction* false_operand = instruction->operand(2);
+
+ TF_RET_CHECK(true_computation_layout.result_layout() ==
+ false_computation_layout.result_layout());
+ TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
+ instruction->shape()));
+ TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
+ true_computation->root_instruction()->shape()));
+ TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
+ instruction->shape()));
+ TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
+ false_computation->root_instruction()->shape()));
+ TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+ true_operand->shape()));
+ TF_RET_CHECK(
+ false_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+ false_operand->shape()));
+ return Status::OK();
+}
+
// Fusion parameters must match the layout of the fusion instructions operands,
// and the root of the fusion expression must match the layout of the fusion
// instruction.
@@ -710,6 +844,13 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, instruction->while_condition()),
FindOrDie(computation_layouts_, instruction->while_body())));
break;
+ case HloOpcode::kConditional:
+ TF_RETURN_IF_ERROR(CheckConditionalLayout(
+ instruction,
+ FindOrDie(computation_layouts_, instruction->true_computation()),
+ FindOrDie(computation_layouts_,
+ instruction->false_computation())));
+ break;
default:
break;
}
@@ -1165,77 +1306,6 @@ StatusOr<Layout> InferArrayLayout(
return *first_buffer_layout;
}
-// Creates and returns a copy of the given instruction with a different
-// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
-// instruction producing the copy is returned.
-StatusOr<HloInstruction*> CreateCopyWithNewLayout(
- const Shape& shape_with_layout, HloInstruction* instruction) {
- TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
- DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
- << ShapeUtil::HumanString(shape_with_layout) << " "
- << ShapeUtil::HumanString(instruction->shape())
- << " instruction: " << instruction->ToString();
-
- if (ShapeUtil::IsTuple(instruction->shape())) {
- // Deep-copy tuples.
- std::vector<HloInstruction*> element_copies;
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
- ++i) {
- HloInstruction* gte = instruction->parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
- i));
-
- // Recurse to copy each elements.
- TF_ASSIGN_OR_RETURN(
- HloInstruction * element_copy,
- CreateCopyWithNewLayout(
- ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
- element_copies.push_back(element_copy);
- }
- // Gather element copies into a tuple with a new Tuple instruction.
- HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
- HloInstruction::CreateTuple(element_copies));
- LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
- TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
- shape_with_layout, tuple_copy->mutable_shape()));
- return tuple_copy;
- } else if (ShapeUtil::IsArray(instruction->shape())) {
- HloInstruction* copy =
- instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
- instruction->shape(), HloOpcode::kCopy, instruction));
- LayoutUtil::ClearLayout(copy->mutable_shape());
- TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
- shape_with_layout, copy->mutable_shape()));
-
- return copy;
- } else {
- return FailedPrecondition(
- "Can only copy array and tuple shaped instructions");
- }
-}
-
-// Creates a copy of the given operand if the operand's layout does not match
-// the given layout. This copy replaces the use in the given instruction. Tuple
-// operands will be deep-copied.
-Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
- HloInstruction* instruction,
- int64 operand_no) {
- HloInstruction* operand = instruction->mutable_operand(operand_no);
- TF_RET_CHECK(operand_layout.LayoutIsSet());
- TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
-
- if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
- // Operand layout already matches our constraint. Nothing to do.
- return Status::OK();
- }
-
- TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
- CreateCopyWithNewLayout(operand_layout.shape(), operand));
-
- return instruction->ReplaceOperandWith(operand_no, operand_copy);
-}
-
// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index e269a13459..dd0fba2758 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -658,5 +658,68 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
ElementsAre(2, 1, 0));
}
+TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
+ auto builder = HloComputation::Builder(TestName());
+ auto module = CreateNewModule();
+ Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
+ Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
+ Shape result_tshape = ShapeUtil::MakeTupleShape({shape});
+
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "param1"));
+ auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
+ 2, ShapeUtil::MakeShape(PRED, {}), "param2"));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
+
+ auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
+ {
+ auto param = true_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tshape, "param"));
+ auto gte0 = true_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, param, 0));
+ auto gte1 = true_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, param, 1));
+ auto add = true_builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
+ true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
+ }
+ HloComputation* true_computation =
+ module->AddEmbeddedComputation(true_builder.Build());
+
+ auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
+ {
+ Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
+ false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tshape, "param"));
+ // Use infeed, since layout assignment does not mess with it.
+ auto infeed =
+ false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, ""));
+ false_builder.AddInstruction(HloInstruction::CreateTuple({infeed}));
+ }
+ HloComputation* false_computation =
+ module->AddEmbeddedComputation(false_builder.Build());
+ builder.AddInstruction(HloInstruction::CreateConditional(
+ result_tshape, pred, tuple, true_computation, tuple, false_computation));
+
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+
+ AssignLayouts(module.get(), &computation_layout);
+
+ const HloInstruction* true_root = true_computation->root_instruction();
+ const HloInstruction* false_root = false_computation->root_instruction();
+ EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
+ EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);
+
+ const HloInstruction* true_result = true_root->operand(0);
+ const HloInstruction* false_result = false_root->operand(0);
+ EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
+ false_result->shape().layout()));
+ EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
+}
+
} // namespace
} // namespace xla