author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-20 14:33:19 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-03-20 14:36:15 -0700
commit     546d1d467372a176f337f2614165c6d754a386da (patch)
tree       3cd9382a91d16523b9249c9fbefb2ac8980f8eaf /tensorflow/compiler
parent     a0e07f998b388f0ecc7b7cf2256522f28482b285 (diff)
[XLA] Simplify the HLO proto: don't nest the fusion computation in a fusion HloInstructionProto.
PiperOrigin-RevId: 189811729
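In rough terms, the wire-format change is this: a kFusion instruction used to carry its fused computation nested inside its own proto (field 12); after this change it only records the computation's name in called_computation_names, and the computation is serialized next to the module's other computations. A minimal text-format sketch of a fusion HloInstructionProto, with purely illustrative names and enum spellings:

    # Before this change: the fused computation was nested in the instruction.
    name: "fusion.1"
    opcode: "fusion"
    fusion_kind: "kLoop"                 # spelling illustrative
    fused_instructions_computation {
      name: "fused_computation.1"
      # ... fused instructions ...
    }

    # After this change: the instruction only names its fused computation;
    # the computation itself is a regular entry in HloModuleProto.computations.
    name: "fusion.1"
    opcode: "fusion"
    fusion_kind: "kLoop"
    called_computation_names: "fused_computation.1"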
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto           |  3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc  | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h   | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  | 28
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h   |  7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc       | 22
6 files changed, 30 insertions(+), 58 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index bf903d6a39..b86fbd821b 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -38,6 +38,8 @@ option cc_enable_arenas = true;
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
+ reserved 12;
+ reserved "fused_instructions_computation";
string name = 1;
string opcode = 2;
@@ -58,7 +60,6 @@ message HloInstructionProto {
// Fusion state, only present for kFusion.
string fusion_kind = 11;
- HloComputationProto fused_instructions_computation = 12;
// Index for kGetTupleElement.
int64 tuple_index = 13;
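Reserving both the removed field's number and its name is standard proto3 practice: it stops a later edit from reusing 12 or "fused_instructions_computation" for an unrelated field, which could silently misread older serialized instructions. A generic sketch of the pattern (not the real file):

    syntax = "proto3";

    message Example {
      reserved 12;                                // the retired field number cannot be reused
      reserved "fused_instructions_computation";  // nor can the retired field name
      string name = 1;
    }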
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index f99c7cf5e4..4e852190a8 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -406,18 +406,15 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
HloModule* module, const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation,
- HloInstruction* fusion_instruction) {
+ const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
std::vector<std::unique_ptr<HloInstruction>> instructions;
tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
- HloInstruction::CreateFromProto(
- module, instruction_proto, instruction_map,
- computation_map, add_fused_computation));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloInstruction> instruction,
+ HloInstruction::CreateFromProto(module, instruction_proto,
+ instruction_map, computation_map));
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}
@@ -429,8 +426,9 @@ HloComputation::CreateFromProto(
TF_RET_CHECK(!proto.root_name().empty());
TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name()));
HloInstruction* root = instruction_map.at(proto.root_name());
- return WrapUnique(new HloComputation(
- proto.name(), parameter_count, &instructions, root, fusion_instruction));
+ return WrapUnique(new HloComputation(proto.name(), parameter_count,
+ &instructions, root,
+ /*fusion_instruction=*/nullptr));
}
void HloComputation::FuseInstructionsInto(
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index dd9d346999..630d3675de 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -163,17 +163,9 @@ class HloComputation {
// computation_map: a map from computation name to HloComputation*. This map
// must contain all computations which the newly constructed computation
// calls.
- // add_fused_computation: A function to call to add a fused
- // computation. Used only when the instruction is a fusion instruction.
- // fusion_instruction: if non-null then the newly created computation will
- // be constructed as a fused computation with this instruction as its
- // fusion parent.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
HloModule* module, const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation,
- HloInstruction* fusion_instruction = nullptr);
+ const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map);
// Gets the instructions in this computation.
//
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d33add23d0..83fcc5da6d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -52,9 +53,7 @@ using ::tensorflow::strings::StrCat;
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation) {
+ const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
@@ -76,17 +75,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
// HloInstructionProto and do not appear as an HloComputationProto within the
// HloModuleProto.
if (instruction->opcode() == HloOpcode::kFusion) {
- TF_RET_CHECK(proto.has_fused_instructions_computation());
TF_RET_CHECK(!proto.fusion_kind().empty());
TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
StringToFusionKind(proto.fusion_kind()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation,
- HloComputation::CreateFromProto(
- module, proto.fused_instructions_computation(),
- computation_map, add_fused_computation,
- /*fusion_instruction=*/instruction.get()));
- instruction->called_computations_.push_back(fused_computation.get());
- add_fused_computation(std::move(fused_computation));
+
+ // Find the fused computation and set its fusion instruction.
+ TF_RET_CHECK(proto.called_computation_names_size() == 1)
+ << "Expect 1 called computation for fusion instruction, but sees "
+ << proto.called_computation_names_size();
+ const string& fusion_name = proto.called_computation_names(0);
+ auto* fused_computation = FindPtrOrNull(computation_map, fusion_name);
+ TF_RET_CHECK(fused_computation != nullptr)
+ << "No fusion computation named " << fusion_name;
+ fused_computation->SetFusionInstruction(instruction.get());
+ instruction->called_computations_.push_back(fused_computation);
} else {
for (const string& computation_name : proto.called_computation_names()) {
TF_RET_CHECK(ContainsKey(computation_map, computation_name))
@@ -2330,8 +2332,8 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.set_parameter_number(parameter_number_);
if (opcode() == HloOpcode::kFusion) {
proto.set_fusion_kind(xla::ToString(fusion_kind()));
- *proto.mutable_fused_instructions_computation() =
- fused_instructions_computation()->ToProto();
+ *proto.add_called_computation_names() =
+ fused_instructions_computation()->name();
} else {
for (const HloComputation* computation : called_computations_) {
*proto.add_called_computation_names() = computation->name();
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e4c86214c2..a111e1e4a6 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -184,15 +184,10 @@ class HloInstruction {
// computation_map: a map from computation name to HloComputation*. This map
// must contain all computations which the newly constructed instruction
// calls.
- // add_fused_computation: A function to call to add a fused
- // computation. Used (clearly) when the instruction is a fusion
- // instruction.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation);
+ const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map);
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cdea3d5978..4091ebbfd3 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -207,11 +207,6 @@ HloModuleProto HloModule::ToProto() const {
proto.set_name(name_);
proto.set_entry_computation_name(entry_computation_->name());
for (const HloComputation* computation : MakeComputationPostOrder()) {
- // Fusion computations are added when the fusion instructions are created by
- // HloInstruction::CreateFromProto.
- if (computation->IsFusionComputation()) {
- continue;
- }
HloComputationProto computation_proto = computation->ToProto();
if (computation->name() == entry_computation_->name()) {
*proto.mutable_program_shape() = computation_proto.program_shape();
@@ -256,16 +251,9 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
for (const HloComputationProto& computation_proto : proto.computations()) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloComputation> computation,
- HloComputation::CreateFromProto(
- module.get(), computation_proto, computation_map,
- /*add_fused_computation=*/
- [&module](std::unique_ptr<HloComputation> fused_computation) {
- module->AddComputationInternal(std::move(fused_computation),
- /*is_entry=*/false,
- /*uniquify_names=*/false);
- }));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
+ HloComputation::CreateFromProto(
+ module.get(), computation_proto, computation_map));
CHECK_NE(computation.get(), nullptr);
TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
string computation_name = computation->name();
@@ -283,10 +271,6 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
for (HloComputation* computation : module->computations()) {
- if (computation->IsFusionComputation()) {
- continue;
- }
-
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
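One consequence worth spelling out: since HloInstruction::CreateFromProto now resolves the fusion's called_computation_names against computation_map, the fused computation must appear in HloModuleProto.computations before any computation whose fusion instruction refers to it, which is just the documented requirement that computation_map already contain every computation the new one calls. A hedged text-format sketch of the resulting module layout, names illustrative:

    name: "module"
    entry_computation_name: "entry"
    computations {
      name: "fused_computation.1"      # now serialized like any other computation
      root_name: "multiply.2"
      # ... fused instructions ...
    }
    computations {
      name: "entry"
      root_name: "fusion.1"
      instructions {
        name: "fusion.1"
        opcode: "fusion"
        fusion_kind: "kLoop"           # spelling illustrative
        called_computation_names: "fused_computation.1"
      }
    }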