aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc28
1 files changed, 15 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d33add23d0..83fcc5da6d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -52,9 +53,7 @@ using ::tensorflow::strings::StrCat;
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation) {
+ const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
@@ -76,17 +75,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
// HloInstructionProto and do not appear as an HloComputationProto within the
// HloModuleProto.
if (instruction->opcode() == HloOpcode::kFusion) {
- TF_RET_CHECK(proto.has_fused_instructions_computation());
TF_RET_CHECK(!proto.fusion_kind().empty());
TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
StringToFusionKind(proto.fusion_kind()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation,
- HloComputation::CreateFromProto(
- module, proto.fused_instructions_computation(),
- computation_map, add_fused_computation,
- /*fusion_instruction=*/instruction.get()));
- instruction->called_computations_.push_back(fused_computation.get());
- add_fused_computation(std::move(fused_computation));
+
+ // Find the fused computation and set its fusion instruction.
+ TF_RET_CHECK(proto.called_computation_names_size() == 1)
+ << "Expect 1 called computation for fusion instruction, but sees "
+ << proto.called_computation_names_size();
+ const string& fusion_name = proto.called_computation_names(0);
+ auto* fused_computation = FindPtrOrNull(computation_map, fusion_name);
+ TF_RET_CHECK(fused_computation != nullptr)
+ << "No fusion computation named " << fusion_name;
+ fused_computation->SetFusionInstruction(instruction.get());
+ instruction->called_computations_.push_back(fused_computation);
} else {
for (const string& computation_name : proto.called_computation_names()) {
TF_RET_CHECK(ContainsKey(computation_map, computation_name))
@@ -2330,8 +2332,8 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.set_parameter_number(parameter_number_);
if (opcode() == HloOpcode::kFusion) {
proto.set_fusion_kind(xla::ToString(fusion_kind()));
- *proto.mutable_fused_instructions_computation() =
- fused_instructions_computation()->ToProto();
+ *proto.add_called_computation_names() =
+ fused_instructions_computation()->name();
} else {
for (const HloComputation* computation : called_computations_) {
*proto.add_called_computation_names() = computation->name();