author    A. Unique TensorFlower <gardener@tensorflow.org>   2018-02-13 18:57:53 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-02-13 19:01:45 -0800
commit    7325919e07e2ae45b3b5436db1dc9f26a51af6c6 (patch)
tree      bfa9931089ce6339846c44c8d6a3b44a625c6152 /tensorflow/compiler/xla
parent    e5840b71a2199ec4b1f04281a7c45cbb4157c510 (diff)
Automated g4 rollback of changelist 185598764
PiperOrigin-RevId: 185623948
Diffstat (limited to 'tensorflow/compiler/xla')
 tensorflow/compiler/xla/service/copy_insertion.cc         |   3 +-
 tensorflow/compiler/xla/service/heap_simulator.cc         |   1 -
 tensorflow/compiler/xla/service/hlo_computation.cc        |   7 +-
 tensorflow/compiler/xla/service/hlo_computation.h         |   8 -
 tensorflow/compiler/xla/service/hlo_module.cc             |  15 -
 tensorflow/compiler/xla/service/hlo_ordering.cc           |  16 -
 tensorflow/compiler/xla/service/hlo_rematerialization.cc  |   1 -
 tensorflow/compiler/xla/service/layout_assignment.cc      | 212 +++++------
 tensorflow/compiler/xla/service/layout_assignment_test.cc |  63 ---
9 files changed, 75 insertions, 251 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index c812df4235..cd983bc03e 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -729,8 +729,7 @@ class CopyRemover {
     // has a different operand (the operand of the elided copy).
     for (const HloUse* copy_use : copy_value_node->uses) {
       operand_node->uses.push_back(copy_use);
-      if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
-          ContainsKey(copy_map_, copy_use->instruction)) {
+      if (copy_use->instruction->opcode() == HloOpcode::kCopy) {
         copy_map_.at(copy_use->instruction).src = operand_node;
       }
     }
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index a2d13c013c..cde5877e29 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -225,7 +225,6 @@ Status HeapSimulator::RunComputation(
     // sub-computations will never be run concurrently.
     if (module_sequence_ != nullptr) {
       if (instruction->opcode() == HloOpcode::kCall ||
-          instruction->opcode() == HloOpcode::kConditional ||
           instruction->opcode() == HloOpcode::kWhile) {
         for (const HloComputation* called_computation :
              instruction->called_computations()) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 21e6b2ca73..5432419e4a 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -509,14 +509,13 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
         "Can't deep copy instruction %s: instruction is not in computation %s",
         instruction->name().c_str(), name().c_str());
   }
+
   if (indices_to_copy != nullptr &&
       !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
     return FailedPrecondition(
         "Can't deep copy instruction %s: given shape tree of indices to copy "
-        "has incompatible shapes: %s vs. %s",
-        instruction->name().c_str(),
-        ShapeUtil::HumanString(instruction->shape()).c_str(),
-        ShapeUtil::HumanString(indices_to_copy->shape()).c_str());
+        "has incompatible shape",
+        instruction->name().c_str());
   }
 
   ShapeIndex index;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 39d864efcb..061c59abe5 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -77,14 +77,6 @@ class HloComputation {
       return last_added_instruction_;
     }
 
-    Status ForEachInstruction(
-        const std::function<Status(const HloInstruction*)>& func) const {
-      for (const auto& instruction : instructions_) {
-        TF_RETURN_IF_ERROR(func(instruction.get()));
-      }
-      return Status::OK();
-    }
-
    private:
     const string name_;
     HloInstruction* last_added_instruction_;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 1e0e17f22f..60270b0595 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -145,21 +145,6 @@ void HloModule::ReplaceComputations(
         }
         break;
       }
-      case HloOpcode::kConditional: {
-        HloComputation* new_true_computation =
-            tensorflow::gtl::FindWithDefault(
-                replacements, instruction->true_computation(), nullptr);
-        if (new_true_computation != nullptr) {
-          instruction->set_true_computation(new_true_computation);
-        }
-        HloComputation* new_false_computation =
-            tensorflow::gtl::FindWithDefault(
-                replacements, instruction->false_computation(), nullptr);
-        if (new_false_computation != nullptr) {
-          instruction->set_false_computation(new_false_computation);
-        }
-        break;
-      }
       case HloOpcode::kSelectAndScatter: {
         HloComputation* new_select = tensorflow::gtl::FindWithDefault(
             replacements, instruction->select(), nullptr);
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 1b24d8da9e..68e3c9618c 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -186,22 +186,6 @@ bool HloOrdering::UseIsBeforeValueDefinition(
     }
   }
 
-  if (use.instruction->opcode() == HloOpcode::kConditional) {
-    const HloInstruction* conditional = use.instruction;
-    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
-                                           conditional->true_computation())) {
-      VLOG(4) << " use is conditional " << use.instruction->name()
-              << " and def is in TRUE computation";
-      return true;
-    }
-    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
-                                           conditional->false_computation())) {
-      VLOG(4) << " use is conditional " << use.instruction->name()
-              << " and def is in FALSE computation";
-      return true;
-    }
-  }
-
   VLOG(4) << " use is not before value";
   return false;
 }
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 98b8d34be1..c6b4dc0368 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -60,7 +60,6 @@ bool IsRematerializable(const HloInstruction* instruction) {
   switch (instruction->opcode()) {
     case HloOpcode::kCall:
     case HloOpcode::kConstant:
-    case HloOpcode::kConditional:
     case HloOpcode::kCrossReplicaSum:
     case HloOpcode::kCustomCall:
     case HloOpcode::kParameter:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 0668f66051..fce135ef61 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -53,83 +53,6 @@ limitations under the License.
 
 namespace xla {
 
-// For now moving only one API here, but we should have a single top level
-// anonymous namespace, instead of three or four spread all over this file.
-namespace {
-
-// Creates and returns a copy of the given instruction with a different
-// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
-// instruction producing the copy is returned.
-StatusOr<HloInstruction*> CreateCopyWithNewLayout(
-    const Shape& shape_with_layout, HloInstruction* instruction) {
-  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
-  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
-      << ShapeUtil::HumanString(shape_with_layout) << " "
-      << ShapeUtil::HumanString(instruction->shape())
-      << " instruction: " << instruction->ToString();
-
-  if (ShapeUtil::IsTuple(instruction->shape())) {
-    // Deep-copy tuples.
-    std::vector<HloInstruction*> element_copies;
-    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
-         ++i) {
-      HloInstruction* gte = instruction->parent()->AddInstruction(
-          HloInstruction::CreateGetTupleElement(
-              ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
-              i));
-
-      // Recurse to copy each elements.
-      TF_ASSIGN_OR_RETURN(
-          HloInstruction * element_copy,
-          CreateCopyWithNewLayout(
-              ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
-      element_copies.push_back(element_copy);
-    }
-    // Gather element copies into a tuple with a new Tuple instruction.
-    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
-        HloInstruction::CreateTuple(element_copies));
-    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
-    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
-        shape_with_layout, tuple_copy->mutable_shape()));
-    return tuple_copy;
-  } else if (ShapeUtil::IsArray(instruction->shape())) {
-    HloInstruction* copy =
-        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
-            instruction->shape(), HloOpcode::kCopy, instruction));
-    LayoutUtil::ClearLayout(copy->mutable_shape());
-    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
-        shape_with_layout, copy->mutable_shape()));
-
-    return copy;
-  } else {
-    return FailedPrecondition(
-        "Can only copy array and tuple shaped instructions");
-  }
-}
-
-// Creates a copy of the given operand if the operand's layout does not match
-// the given layout. This copy replaces the use in the given instruction. Tuple
-// operands will be deep-copied.
-Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
-                                  HloInstruction* instruction,
-                                  int64 operand_no) {
-  HloInstruction* operand = instruction->mutable_operand(operand_no);
-  TF_RET_CHECK(operand_layout.LayoutIsSet());
-  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
-
-  if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
-    // Operand layout already matches our constraint. Nothing to do.
-    return Status::OK();
-  }
-
-  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
-                      CreateCopyWithNewLayout(operand_layout.shape(), operand));
-
-  return instruction->ReplaceOperandWith(operand_no, operand_copy);
-}
-
-}  // namespace
-
 std::ostream& operator<<(std::ostream& out,
                          const LayoutConstraint& constraint) {
   out << constraint.ToString();
@@ -589,36 +512,6 @@ Status LayoutAssignment::AddMandatoryConstraints(
         body_layout.result_shape(), instruction));
     TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
         body_layout.result_shape(), instruction, 0));
-  } else if (instruction->opcode() == HloOpcode::kConditional) {
-    // The layout of the true and false computations must match, and must
-    // be the layout of the kConditional instruction.
-    TF_RET_CHECK(instruction->operand_count() == 3);
-
-    HloComputation* true_computation = instruction->true_computation();
-    HloComputation* false_computation = instruction->false_computation();
-    const HloInstruction* true_operand = instruction->operand(1);
-    const HloInstruction* false_operand = instruction->operand(2);
-
-    TF_RET_CHECK(true_computation->num_parameters() == 1);
-    TF_RET_CHECK(false_computation->num_parameters() == 1);
-    ComputationLayout& true_computation_layout =
-        FindOrDie(computation_layouts_, true_computation);
-    ComputationLayout& false_computation_layout =
-        FindOrDie(computation_layouts_, false_computation);
-
-    DCHECK(ShapeUtil::Compatible(true_operand->shape(),
-                                 true_computation_layout.parameter_shape(0)));
-    DCHECK(ShapeUtil::Compatible(
-        false_operand->shape(), false_computation_layout.parameter_shape(0)));
-
-    TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
-        true_computation_layout.result_shape(), instruction));
-    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
-        true_computation_layout.parameter_shape(0), instruction, 1,
-        /*mandatory=*/true));
-    TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
-        false_computation_layout.parameter_shape(0), instruction, 2,
-        /*mandatory=*/true));
   } else if (instruction->opcode() == HloOpcode::kCustomCall) {
     if (!CustomCallRequiresMajorFirstLayout(instruction)) {
       continue;
@@ -705,33 +598,6 @@ Status CheckWhileLayout(HloInstruction* while_inst,
   return Status::OK();
 }
 
-Status CheckConditionalLayout(
-    HloInstruction* instruction,
-    const ComputationLayout& true_computation_layout,
-    const ComputationLayout& false_computation_layout) {
-  HloComputation* true_computation = instruction->true_computation();
-  HloComputation* false_computation = instruction->false_computation();
-  const HloInstruction* true_operand = instruction->operand(1);
-  const HloInstruction* false_operand = instruction->operand(2);
-
-  TF_RET_CHECK(true_computation_layout.result_layout() ==
-               false_computation_layout.result_layout());
-  TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
-      instruction->shape()));
-  TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape(
-      true_computation->root_instruction()->shape()));
-  TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
-      instruction->shape()));
-  TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape(
-      false_computation->root_instruction()->shape()));
-  TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape(
-      true_operand->shape()));
-  TF_RET_CHECK(
-      false_computation_layout.parameter_layout(0).MatchesLayoutInShape(
-          false_operand->shape()));
-  return Status::OK();
-}
-
 // Fusion parameters must match the layout of the fusion instructions operands,
 // and the root of the fusion expression must match the layout of the fusion
 // instruction.
@@ -844,13 +710,6 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
             FindOrDie(computation_layouts_, instruction->while_condition()),
             FindOrDie(computation_layouts_, instruction->while_body())));
         break;
-      case HloOpcode::kConditional:
-        TF_RETURN_IF_ERROR(CheckConditionalLayout(
-            instruction,
-            FindOrDie(computation_layouts_, instruction->true_computation()),
-            FindOrDie(computation_layouts_,
-                      instruction->false_computation())));
-        break;
       default:
         break;
     }
@@ -1306,6 +1165,77 @@ StatusOr<Layout> InferArrayLayout(
   return *first_buffer_layout;
 }
 
+// Creates and returns a copy of the given instruction with a different
+// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
+// instruction producing the copy is returned.
+StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+    const Shape& shape_with_layout, HloInstruction* instruction) {
+  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
+  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
+      << ShapeUtil::HumanString(shape_with_layout) << " "
+      << ShapeUtil::HumanString(instruction->shape())
+      << " instruction: " << instruction->ToString();
+
+  if (ShapeUtil::IsTuple(instruction->shape())) {
+    // Deep-copy tuples.
+    std::vector<HloInstruction*> element_copies;
+    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
+         ++i) {
+      HloInstruction* gte = instruction->parent()->AddInstruction(
+          HloInstruction::CreateGetTupleElement(
+              ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
+              i));
+
+      // Recurse to copy each elements.
+      TF_ASSIGN_OR_RETURN(
+          HloInstruction * element_copy,
+          CreateCopyWithNewLayout(
+              ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
+      element_copies.push_back(element_copy);
+    }
+    // Gather element copies into a tuple with a new Tuple instruction.
+    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
+        HloInstruction::CreateTuple(element_copies));
+    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
+    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+        shape_with_layout, tuple_copy->mutable_shape()));
+    return tuple_copy;
+  } else if (ShapeUtil::IsArray(instruction->shape())) {
+    HloInstruction* copy =
+        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
+            instruction->shape(), HloOpcode::kCopy, instruction));
+    LayoutUtil::ClearLayout(copy->mutable_shape());
+    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+        shape_with_layout, copy->mutable_shape()));
+
+    return copy;
+  } else {
+    return FailedPrecondition(
+        "Can only copy array and tuple shaped instructions");
+  }
+}
+
+// Creates a copy of the given operand if the operand's layout does not match
+// the given layout. This copy replaces the use in the given instruction. Tuple
+// operands will be deep-copied.
+Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+                                  HloInstruction* instruction,
+                                  int64 operand_no) {
+  HloInstruction* operand = instruction->mutable_operand(operand_no);
+  TF_RET_CHECK(operand_layout.LayoutIsSet());
+  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
+
+  if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
+    // Operand layout already matches our constraint. Nothing to do.
+    return Status::OK();
+  }
+
+  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
+                      CreateCopyWithNewLayout(operand_layout.shape(), operand));
+
+  return instruction->ReplaceOperandWith(operand_no, operand_copy);
+}
+
 // For fusion instructions, set the layout of each fused parameter instruction
 // to match the layout of its corresponding fusion instruction operand. Also,
 // set the layout of the fused root to match the layout of the fusion
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index dd0fba2758..e269a13459 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -658,68 +658,5 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
               ElementsAre(2, 1, 0));
 }
 
-TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
-  auto builder = HloComputation::Builder(TestName());
-  auto module = CreateNewModule();
-  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
-  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
-  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});
-
-  auto param0 = builder.AddInstruction(
-      HloInstruction::CreateParameter(0, shape, "param0"));
-  auto param1 = builder.AddInstruction(
-      HloInstruction::CreateParameter(1, shape, "param1"));
-  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
-      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
-  auto tuple =
-      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
-
-  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
-  {
-    auto param = true_builder.AddInstruction(
-        HloInstruction::CreateParameter(0, tshape, "param"));
-    auto gte0 = true_builder.AddInstruction(
-        HloInstruction::CreateGetTupleElement(shape, param, 0));
-    auto gte1 = true_builder.AddInstruction(
-        HloInstruction::CreateGetTupleElement(shape, param, 1));
-    auto add = true_builder.AddInstruction(
-        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
-    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
-  }
-  HloComputation* true_computation =
-      module->AddEmbeddedComputation(true_builder.Build());
-
-  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
-  {
-    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
-    false_builder.AddInstruction(
-        HloInstruction::CreateParameter(0, tshape, "param"));
-    // Using infeed as layout assignment does not mess up with it.
-    auto infeed =
-        false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, ""));
-    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed}));
-  }
-  HloComputation* false_computation =
-      module->AddEmbeddedComputation(false_builder.Build());
-  builder.AddInstruction(HloInstruction::CreateConditional(
-      result_tshape, pred, tuple, true_computation, tuple, false_computation));
-
-  HloComputation* computation = module->AddEntryComputation(builder.Build());
-  ComputationLayout computation_layout(computation->ComputeProgramShape());
-
-  AssignLayouts(module.get(), &computation_layout);
-
-  const HloInstruction* true_root = true_computation->root_instruction();
-  const HloInstruction* false_root = false_computation->root_instruction();
-  EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
-  EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);
-
-  const HloInstruction* true_result = true_root->operand(0);
-  const HloInstruction* false_result = false_root->operand(0);
-  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
-                                false_result->shape().layout()));
-  EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
-}
-
 }  // namespace
 }  // namespace xla
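
For context on the feature this commit rolls back: the deleted ConditionalAsymmetricLayout test shows how an HLO kConditional was constructed with the builder API of this era. The sketch below condenses that pattern into a minimal free-standing helper. It is illustrative only, not code from this commit: the helper name AddConditional and the negate/exp branch bodies are hypothetical stand-ins, while the CreateConditional argument order and the header paths are taken from the deleted test and the surrounding tree.

// Sketch, assuming the 2018-era XLA builder API exercised by the deleted
// ConditionalAsymmetricLayout test.
#include <string>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Adds pred ? Negate(x) : Exp(x) to `builder` as a kConditional, registering
// the two branch computations on `module`. (Hypothetical helper.)
HloInstruction* AddConditional(HloModule* module,
                               HloComputation::Builder* builder) {
  const Shape shape = ShapeUtil::MakeShape(F32, {128, 8});

  // Entry-side operands: one data input and a scalar predicate.
  HloInstruction* x = builder->AddInstruction(
      HloInstruction::CreateParameter(0, shape, "x"));
  HloInstruction* pred =
      builder->AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(PRED, {}), "pred"));

  // Each branch is its own embedded computation taking one parameter.
  auto build_branch = [&](const std::string& name, HloOpcode op) {
    HloComputation::Builder b(name);
    HloInstruction* param =
        b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param"));
    b.AddInstruction(HloInstruction::CreateUnary(shape, op, param));
    return module->AddEmbeddedComputation(b.Build());
  };
  HloComputation* true_computation =
      build_branch("TrueBranch", HloOpcode::kNegate);
  HloComputation* false_computation =
      build_branch("FalseBranch", HloOpcode::kExp);

  // Argument order as used by the deleted test: (result_shape, pred,
  // true_operand, true_computation, false_operand, false_computation).
  return builder->AddInstruction(HloInstruction::CreateConditional(
      shape, pred, x, true_computation, x, false_computation));
}

}  // namespace xla

The rolled-back layout-assignment support constrained both branch computations to produce the kConditional's own layout, and the deleted test asserted the consequence: when one branch root carried a different layout (the false branch's infeed with layout {0, 1}), a kCopy was inserted at that root so the two branch results matched.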