diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d33add23d0..83fcc5da6d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -52,9 +53,7 @@ using ::tensorflow::strings::StrCat; StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map, - const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map, - const std::function<void(std::unique_ptr<HloComputation>)>& - add_fused_computation) { + const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -76,17 +75,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( // HloInstructionProto and do not appear as an HloComputationProto within the // HloModuleProto. if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(proto.has_fused_instructions_computation()); TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), - computation_map, add_fused_computation, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back(fused_computation.get()); - add_fused_computation(std::move(fused_computation)); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_names_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_names_size(); + const string& fusion_name = proto.called_computation_names(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_name); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation named " << fusion_name; + fused_computation->SetFusionInstruction(instruction.get()); + instruction->called_computations_.push_back(fused_computation); } else { for (const string& computation_name : proto.called_computation_names()) { TF_RET_CHECK(ContainsKey(computation_map, computation_name)) @@ -2330,8 +2332,8 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_parameter_number(parameter_number_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); - *proto.mutable_fused_instructions_computation() = - fused_instructions_computation()->ToProto(); + *proto.add_called_computation_names() = + fused_instructions_computation()->name(); } else { for (const HloComputation* computation : called_computations_) { *proto.add_called_computation_names() = computation->name(); |