-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto                   |   3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc          |  52
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h           |  20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc         |  43
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.h          |   5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc          | 100
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h           |  59
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc               |  12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h                |  19
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc             |  71
-rw-r--r--  tensorflow/compiler/xla/statusor.h                          |  11
-rw-r--r--  tensorflow/compiler/xla/statusor_test.cc                    |   8
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.cc          |  10
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc     |  22
14 files changed, 259 insertions, 176 deletions
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index aa6860880b..1f7c1cffd3 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -147,6 +147,9 @@ message HloInstructionProto {
repeated int64 called_computation_ids = 38;
xla.OpSharding sharding = 40;
+
+ // Backend configuration for the instruction. Has backend-specific meaning.
+ string backend_config = 43;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 594413e88f..17e43c3cb8 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -347,6 +347,11 @@ std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
// construction.
+ //
+ // TODO(b/78350259): This violates const-correctness, since while the original
+ // computation is not returned, we still retrieve non-const computations from
+ // a const one. Consider also avoiding const for HloComputation, or review XLA
+ // for const-correctness of non-HloInstruction* types like this.
ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
&post_order);
@@ -723,18 +728,25 @@ Status HloComputation::Accept(
return this->Accept(&visitor);
}
-std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix,
- HloModule* module) {
+std::unique_ptr<HloComputation> HloComputation::Clone(
+ const string& suffix, HloModule* module,
+ HloInstruction::CloneMap* clone_map) {
return CloneWithReplacements(
/*replacements=*/std::unordered_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
- module, suffix);
+ module, clone_map, suffix);
}
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloModule* module, const string& suffix) {
+ HloModule* module, HloInstruction::CloneMap* clone_map,
+ const string& suffix) {
+ HloInstruction::CloneMap local_clone_map;
+ if (clone_map == nullptr) {
+ clone_map = &local_clone_map;
+ }
+
// Look up instr in the replacements map, and return either the replacement,
// or instr, if the replacement isn't present.
//
@@ -756,24 +768,19 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
}
}
- std::unordered_map<HloInstruction*, HloInstruction*> clone_map;
std::vector<std::unique_ptr<HloInstruction>> instructions;
std::unique_ptr<HloInstruction> new_instr = nullptr;
for (auto instr : postorder) {
std::vector<HloInstruction*> new_operands;
for (auto operand : instr->operands()) {
auto replaced_operand = replace(operand);
- // If replaced_operand is null, that means 'replacements' asked us not to
- // include operand in the new computation. But we can't do that, because
- // operand is used by instr.
CHECK_NE(replaced_operand, nullptr)
- << "replacements map tried to eliminate a used instruction "
- << operand->ToString() << ", used by " << instr->ToString();
- new_operands.push_back(FindOrDie(clone_map, replaced_operand));
+ << "Replacements map specifies to leave out " << operand->ToString()
+ << ", but it is used by " << instr->ToString() << ".";
+ new_operands.push_back(FindOrDie(*clone_map, replaced_operand));
}
- new_instr =
- instr->CloneWithNewOperands(instr->shape(), new_operands, module);
- InsertOrDie(&clone_map, instr, new_instr.get());
+ new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands,
+ module, clone_map);
instructions.push_back(std::move(new_instr));
}
Builder builder(name() + "." + suffix);
@@ -781,27 +788,24 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
builder.AddInstruction(std::move(instr));
}
auto result = builder.Build(
- /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction())));
+ /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction())));
// Clone control dependencies.
for (auto instr : postorder) {
- HloInstruction* new_instr = FindOrDie(clone_map, instr);
+ HloInstruction* new_instr = FindOrDie(*clone_map, instr);
for (auto successor : instr->control_successors()) {
auto replaced_successor = replace(successor);
-
- // successor may not be in clone_map, because it might have been
- // removed by the replacements map.
- if (replaced_successor == nullptr) {
- continue;
- }
+ CHECK_NE(replaced_successor, nullptr)
+ << "Replacements map specifies to leave out " << successor->ToString()
+ << ", but it is control-depended-on by " << instr->ToString() << ".";
TF_CHECK_OK(new_instr->AddControlDependencyTo(
- FindOrDie(clone_map, replaced_successor)));
+ FindOrDie(*clone_map, replaced_successor)));
}
}
// We cloned the elements of 'replacements', so they're all going to be
- // destroyed. HloInstructions need to be detached from their operands before
+ // destroyed. HloInstructions need to be detached from their operands before
// they're destroyed, otherwise they stick around in the operands' users lists
// and cause use-after-frees.
for (auto& kv : replacements) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 9d3f6e9a2c..9898355625 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -291,11 +291,17 @@ class HloComputation {
const std::function<Status(const HloInstruction*)>& visitor_func) const;
// Returns a deep copy of this computation including all instructions.
- // If the module pointer is not nullptr, it will be the module where
- // the cloned computations will be added to (in order to support deep
- // cloning).
- std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
- HloModule* module = nullptr);
+ //
+ // If the module pointer is not nullptr, then the cloned computations will be
+ // added to this module in order to support deep cloning. Otherwise the module
+ // of the computation is used.
+ //
+ // If clone_map is not nullptr, then each original instruction that is cloned
+ // will be inserted into it, mapping to its clone. clone_map must not already
+ // contain any of the instructions to be cloned.
+ std::unique_ptr<HloComputation> Clone(
+ const string& suffix = "clone", HloModule* module = nullptr,
+ HloInstruction::CloneMap* clone_map = nullptr);
// Like Clone(), but if an instruction is present in replacement_map, we use
// the map's value to replace that instruction in the cloned computation.
@@ -305,7 +311,9 @@ class HloComputation {
std::unique_ptr<HloComputation> CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloModule* module = nullptr, const string& suffix = "clone");
+ HloModule* module = nullptr,
+ HloInstruction::CloneMap* clone_map = nullptr,
+ const string& suffix = "clone");
// Returns true if the given instruction can be removed from the computation.
// Parameter instructions cannot be removed without violating invariants of
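For illustration, a minimal sketch of how a caller might use the new clone_map out-parameter; the computation and module pointers are assumed to exist in the caller, and FindOrDie comes from tensorflow/core/lib/gtl/map_util.h:

    // Clone a computation and record the original->clone instruction mapping.
    HloInstruction::CloneMap clone_map;
    std::unique_ptr<HloComputation> cloned =
        computation->Clone(/*suffix=*/"clone", /*module=*/module, &clone_map);

    // Every original instruction now maps to its counterpart in the clone.
    HloInstruction* new_root =
        FindOrDie(clone_map, computation->root_instruction());
    CHECK_EQ(new_root, cloned->root_instruction());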
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index bb4db89f0a..794f1b4682 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -322,11 +322,13 @@ class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
const DebugOptions& debug_options, bool show_metadata,
- const HloExecutionProfile* profile, NodeFilter filter)
+ bool show_backend_config, const HloExecutionProfile* profile,
+ NodeFilter filter)
: computation_(computation),
label_(label.ToString()),
debug_options_(debug_options),
show_metadata_(show_metadata),
+ show_backend_config_(show_backend_config),
profile_(profile),
filter_(std::move(filter)) {}
@@ -365,6 +367,7 @@ class HloDotDumper {
string GetInstructionNodeShape(const HloInstruction* instr);
string GetInstructionNodeLabel(const HloInstruction* instr);
string GetInstructionNodeMetadata(const HloInstruction* instr);
+ string GetInstructionNodeBackendConfig(const HloInstruction* instr);
string GetInstructionNodeExtraInfo(const HloInstruction* instr);
string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
void AddInstructionIncomingEdges(const HloInstruction* instr);
@@ -393,6 +396,7 @@ class HloDotDumper {
const string label_; // overall name for the graph
const DebugOptions& debug_options_;
const bool show_metadata_;
+ const bool show_backend_config_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@@ -611,6 +615,10 @@ tooltip = " ";
if (!extra_info.empty()) {
StrAppend(&subcomp_label, "<br/>", extra_info);
}
+ string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
+ if (!node_backend_config.empty()) {
+ StrAppend(&subcomp_label, "<br/>", node_backend_config);
+ }
bool highlight = filter_.Highlight(parent_instr);
const char* fillcolor;
@@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string node_shape = GetInstructionNodeShape(instr);
string node_label = GetInstructionNodeLabel(instr);
string node_metadata = GetInstructionNodeMetadata(instr);
+ string node_backend_config = GetInstructionNodeBackendConfig(instr);
string extra_info = GetInstructionNodeExtraInfo(instr);
string inlined_constants = GetInstructionNodeInlinedOperands(instr);
string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
@@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
}
// Build the text that will be displayed inside the node.
string node_body = node_label;
- for (const string& s :
- {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
+ for (const string& s : {trivial_subcomputation, node_metadata,
+ node_backend_config, extra_info, inlined_constants}) {
if (!s.empty()) {
StrAppend(&node_body, "<br/>", s);
}
@@ -1078,6 +1087,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
return Join(lines, "<br/>");
}
+string HloDotDumper::GetInstructionNodeBackendConfig(
+ const HloInstruction* instr) {
+ if (!show_backend_config_ || instr->backend_config().empty()) {
+ return "";
+ }
+
+ return StrCat("backend_config=\"", instr->backend_config(), "\"");
+}
+
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
std::vector<string> lines;
@@ -1404,7 +1422,7 @@ string ExportGraph(const string& graph,
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile,
- bool show_metadata) {
+ bool show_metadata, bool show_backend_config) {
GraphRendererInterface::GraphKind graph_kind;
string graph;
if (debug_options.xla_hlo_dump_as_graphdef()) {
@@ -1414,9 +1432,10 @@ string DumpGraph(const HloComputation& computation, const string& label,
&graph));
graph_kind = GraphRendererInterface::TF_GRAPHDEF;
} else {
- graph = HloDotDumper(&computation, label, debug_options, show_metadata,
- hlo_execution_profile, NodeFilter())
- .Dump();
+ graph =
+ HloDotDumper(&computation, label, debug_options, show_metadata,
+ show_backend_config, hlo_execution_profile, NodeFilter())
+ .Dump();
graph_kind = GraphRendererInterface::DOT_GRAPH;
}
@@ -1427,15 +1446,15 @@ string DumpGraph(const HloComputation& computation, const string& label,
}
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
- bool show_metadata) {
+ bool show_metadata, bool show_backend_config) {
auto debug_options = node.GetModule()->config().debug_options();
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
NodeFilter filter = MakeNodeFilter(&node, radius);
- string graph =
- HloDotDumper(node.parent(), label, debug_options, show_metadata,
- /*profile=*/nullptr, filter)
- .Dump();
+ string graph = HloDotDumper(node.parent(), label, debug_options,
+ show_metadata, show_backend_config,
+ /*profile=*/nullptr, filter)
+ .Dump();
return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
index 2704aae1e3..fc8e1468ac 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
@@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label,
string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile = nullptr,
- bool show_metadata = false);
+ bool show_metadata = false, bool show_backend_config = false);
// Like DumpGraph, but renders only nodes "near" the given node in the graph.
//
@@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
// (roughly) corresponds to the max distance a node may be from the primary node
// before it's omitted from the graph.
string DumpNeighborhoodAround(const HloInstruction& node, int radius,
- bool show_metadata = false);
+ bool show_metadata = false,
+ bool show_backend_config = false);
// Dumps the HloModule::ToString() as a file into the provided directory path
// suffixed with the provided label.
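A rough usage sketch of the extended dumper entry points, assuming a computation and an instr pointer are in scope and that the call site is inside namespace xla:

    // Render the whole computation, including backend_config annotations.
    string graph = hlo_graph_dumper::DumpGraph(
        *computation, /*label=*/"post-fusion",
        computation->parent()->config().debug_options(),
        /*hlo_execution_profile=*/nullptr,
        /*show_metadata=*/false, /*show_backend_config=*/true);

    // Or render only the nodes within two hops of a single instruction.
    string neighborhood = hlo_graph_dumper::DumpNeighborhoodAround(
        *instr, /*radius=*/2,
        /*show_metadata=*/false, /*show_backend_config=*/true);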
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index a714d0e114..2c733726a6 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -109,6 +109,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->name_ = proto.name();
instruction->metadata_ = proto.metadata();
+ instruction->set_backend_config(proto.backend_config());
if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(instruction->literal_,
Literal::CreateFromProto(proto.literal()));
@@ -1231,12 +1232,15 @@ bool HloInstruction::HasSideEffect() const {
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloModule* module) const {
+ HloModule* module, CloneMap* clone_map) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
VLOG(3) << " %" << new_operand->name();
}
+ if (module == nullptr) {
+ module = GetModule();
+ }
std::unique_ptr<HloInstruction> clone;
@@ -1342,7 +1346,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
break;
case HloOpcode::kFft:
CHECK_EQ(new_operands.size(), 1);
- return CreateFft(shape, new_operands[0], fft_type_, fft_length_);
+ clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
+ break;
case HloOpcode::kCrossReplicaSum:
clone = CreateCrossReplicaSum(shape, new_operands);
break;
@@ -1415,9 +1420,15 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kConstant:
clone = CreateConstant(literal_->CloneToUnique());
break;
- case HloOpcode::kFusion:
- clone = CloneFusionWithNewOperands(shape, new_operands, module);
+ case HloOpcode::kFusion: {
+ CHECK_NE(module, nullptr);
+ auto new_fused_computation = module->AddEmbeddedComputation(
+ fused_instructions_computation()->Clone("clone", module, clone_map));
+ clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(),
+ /*operands=*/new_operands,
+ /*fusion_computation=*/new_fused_computation);
break;
+ }
case HloOpcode::kParameter:
clone = CreateParameter(parameter_number_, shape, name_);
break;
@@ -1481,15 +1492,19 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
}
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
+ clone->set_backend_config(backend_config());
+ if (clone_map != nullptr) {
+ InsertOrDie(clone_map, this, clone.get());
+ }
return clone;
}
HloInstruction::~HloInstruction() {}
-std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
- HloModule* module) const {
+std::unique_ptr<HloInstruction> HloInstruction::Clone(
+ const string& suffix, HloModule* module, CloneMap* clone_map) const {
std::unique_ptr<HloInstruction> clone =
- CloneWithNewOperands(shape_, operands_, module);
+ CloneWithNewOperands(shape_, operands_, module, clone_map);
if (suffix.empty()) {
clone->name_ = name();
} else {
@@ -1526,71 +1541,6 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
return clone;
}
-std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloModule* module) const {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(parent() != nullptr);
-
- auto new_instruction =
- WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
- // Add the operands to our new fusion instruction.
- for (HloInstruction* new_operand : operands) {
- new_instruction->AppendOperand(new_operand);
- }
- // Clone all the fused instructions for the new fusion instruction.
- HloInstructionMap<HloInstruction*> old_to_new;
- std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
- // Create the list of fused parameters by mapping through the cloned,
- // fused instructions.
- for (HloInstruction* old_fused_parameter :
- fused_instructions_computation()->parameter_instructions()) {
- new_fused_instructions.push_back(
- old_fused_parameter->Clone("clone", module));
- HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
- InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
- }
- for (auto old_fused_instruction :
- fused_instructions_computation()->MakeInstructionPostOrder()) {
- if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
- FindOrDie(old_to_new, old_fused_instruction);
- continue;
- }
- std::vector<HloInstruction*> new_operands;
- for (int64 operand_idx = 0;
- operand_idx < old_fused_instruction->operand_count(); ++operand_idx) {
- HloInstruction* old_operand =
- old_fused_instruction->mutable_operand(operand_idx);
- new_operands.push_back(FindOrDie(old_to_new, old_operand));
- }
- new_fused_instructions.push_back(
- old_fused_instruction->CloneWithNewOperands(
- old_fused_instruction->shape(), new_operands, module));
- HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
- new_fused_instruction->set_parent(parent_);
- InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
- }
- new_instruction->fusion_kind_ = fusion_kind_;
- auto computation_builder = HloComputation::Builder(
- fused_instructions_computation()->name() + ".clone",
- new_instruction.get());
- // We iterated the fusion instructions in reverse post order which means
- // that we must reverse our new list of fusion instructions.
- for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
- new_fused_instruction_iter != new_fused_instructions.rend();
- ++new_fused_instruction_iter) {
- computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
- }
- if (module == nullptr) {
- module = GetModule();
- }
- auto fused_root_ = fused_expression_root();
- new_instruction->called_computations_.push_back(
- CHECK_NOTNULL(module)->AddEmbeddedComputation(
- computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
- return new_instruction;
-}
-
std::pair<const HloInstruction*, ShapeIndex>
HloInstruction::LatestNonGteAncestorAndIndex() const {
const HloInstruction* hlo = this;
@@ -2172,6 +2122,9 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
!metadata_.source_file().empty())) {
StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
}
+ if (options.print_backend_config() && !backend_config().empty()) {
+ StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\"");
+ }
return result;
}
@@ -2357,6 +2310,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
}
+
return extra;
}
@@ -2386,6 +2340,7 @@ HloInstructionProto HloInstruction::ToProto() const {
}
*proto.mutable_metadata() = metadata_;
+ proto.set_backend_config(backend_config());
if (literal_ != nullptr) {
*proto.mutable_literal() = literal_->ToProto();
}
@@ -2971,6 +2926,7 @@ Status HloInstruction::AcceptOrdered(
continue;
}
+ // TODO(b/78350259): Eliminate const laundering.
HloInstruction* instruction =
const_cast<HloInstruction*>(const_instruction);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index a5e9aecb9e..19c8c11453 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -66,6 +66,7 @@ class HloPrintOptions {
: print_large_constants_(false),
print_subcomputation_references_(true),
print_metadata_(true),
+ print_backend_config_(true),
compact_operands_(false),
print_operand_shape_(true),
print_program_shape_(true),
@@ -77,6 +78,7 @@ class HloPrintOptions {
.set_print_large_constants(true)
.set_print_subcomputation_references(true)
.set_print_metadata(false)
+ .set_print_backend_config(false)
.set_print_operand_shape(false)
.set_print_program_shape(false)
.set_print_percent(false);
@@ -99,12 +101,18 @@ class HloPrintOptions {
return *this;
}
- // If true, metatdata will be printed.
+ // If true, metadata will be printed.
HloPrintOptions& set_print_metadata(bool value) {
print_metadata_ = value;
return *this;
}
+ // If true, backend_config will be printed.
+ HloPrintOptions& set_print_backend_config(bool value) {
+ print_backend_config_ = value;
+ return *this;
+ }
+
// If true, operands' shapes will be printed.
HloPrintOptions& set_print_operand_shape(bool value) {
print_operand_shape_ = value;
@@ -141,6 +149,7 @@ class HloPrintOptions {
return print_subcomputation_references_;
}
bool print_metadata() const { return print_metadata_; }
+ bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
@@ -151,6 +160,7 @@ class HloPrintOptions {
bool print_large_constants_;
bool print_subcomputation_references_;
bool print_metadata_;
+ bool print_backend_config_;
bool compact_operands_;
bool print_operand_shape_;
bool print_program_shape_;
@@ -643,6 +653,8 @@ class HloInstruction {
// Detaches an instruction from its operands. That is, remove the instruction
// from each operand's user set. This should only be called prior to
// deallocating the instruction.
+ //
+ // TODO(b/78305363): Make this automatic when deleting an instruction.
void DetachFromOperands();
// Performs a postorder DFS visit using this node as the root. If
@@ -1157,23 +1169,30 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kRng
RandomDistribution random_distribution() const;
+ // See documentation for Clone().
+ using CloneMap = std::unordered_map<const HloInstruction*, HloInstruction*>;
+
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
- // the instruction to form the name of the cloned instruction. If the module
- // pointer is not nullptr, it will be the module where the cloned computations
- // will be added to (in order to support deep cloning). Ignores the control
- // predecessors and successors of this HLO instruction.
+ // the instruction to form the name of the cloned instruction. Ignores the
+ // control predecessors and successors of this HLO instruction.
+ //
+ // If the module pointer is not nullptr, then any cloned computations will be
+ // added to this module in order to support deep cloning. Otherwise the module
+ // of the instruction is used.
+ //
+ // If clone_map is not nullptr, then each original instruction that is cloned
+ // will be inserted into it, mapping to its clone. clone_map must not already
+ // contain any of the instructions to be cloned.
std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
- HloModule* module = nullptr) const;
+ HloModule* module = nullptr,
+ CloneMap* clone_map = nullptr) const;
- // Clones the HLO instruction as above but with new shape and operands. If
- // the module pointer is not nullptr, it will be the module where the cloned
- // computations will be added to (in order to support deep cloning). Ignores
- // the control predecessors and successors of this HLO instruction.
+ // Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloModule* module = nullptr) const;
+ HloModule* module = nullptr, CloneMap* clone_map = nullptr) const;
// Returns the computations this instruction directly calls (if any).
const std::vector<HloComputation*>& called_computations() const {
@@ -1262,6 +1281,19 @@ class HloInstruction {
// if no id has been assigned yet).
int unique_id() const { return unique_id_; }
+ // Returns the backend-specific configuration for how a backend should compile
+ // this HLO. The meaning of the field is backend specific. Not for use before
+ // or during general HLO optimization, since HLO optimizations do not preserve
+ // this field and they cannot interpret it due to its meaning being backend
+ // specific.
+ //
+ // TODO(b/78194644): Introduce structured configuration format as per
+ // go/xla-heuristics.
+ const string& backend_config() const { return backend_config_; }
+ void set_backend_config(string backend_config) {
+ backend_config_ = std::move(backend_config);
+ }
+
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
@@ -1283,6 +1315,7 @@ class HloInstruction {
// Get/Set the number of partitions per outer dimension (in order, starting
// with outer-most dimension first). Currently used by the parallel cpu
// backend to partition HLOs into parallel tasks.
+ //
// TODO(b/62783254) Replace these methods with a more general way to
// annotate HLOs with backend-specific information.
const std::vector<int64>& outer_dimension_partitions() const {
@@ -1510,6 +1543,10 @@ class HloInstruction {
// The string representation of the infeed configuration.
string infeed_config_;
+ // The backend-specific configuration for how a backend should compile this
+ // HLO. See the documentation on backend_config().
+ string backend_config_;
+
// String identifier for instruction.
string name_;
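A small sketch of the new accessor paired with the print option added above; instr is assumed to exist, and the configuration string is an arbitrary placeholder since its meaning is backend-specific:

    // Attach an opaque, backend-specific configuration string to an HLO.
    instr->set_backend_config("unroll_factor=4");

    // The field is printed by default and round-trips through the parser...
    string with_config = instr->ToString();

    // ...and can be suppressed, e.g. for backend-agnostic comparisons.
    string without_config =
        instr->ToString(HloPrintOptions().set_print_backend_config(false));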
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index c7a7192867..5308fb5848 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config)
config_(config),
unique_id_(next_unique_module_id_++) {}
+StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
+ const HloInstruction* hlo) {
+ if (hlo == nullptr) {
+ return nullptr;
+ }
+
+ TF_RET_CHECK(hlo->GetModule() == this);
+
+ // TODO(b/78350259): Eliminate const laundering.
+ return const_cast<HloInstruction*>(hlo);
+}
+
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index f9674df812..1604a72612 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -217,6 +217,25 @@ class HloModule {
// the lifetime of this process.
int unique_id() const { return unique_id_; }
+ // Returns a non-const version of the passed-in const HloInstruction*. This is
+ // safe on the argument that if you have a non-const module, then you can
+ // access all instructions in the module as non-const.
+ //
+ // Returns an error if the passed-in instruction is not from this module,
+ // except that it is allowed to pass in a null pointer.
+ //
+ // TODO(b/78350259): Eliminate const laundering. The argument above is not
+ // reliable since at any time someone could add or discover a way for a
+ // non-const module to transitively contain a const HloInstruction. The
+ // reliable way to do this would be to create a const laundering map from a
+ // module, mapping each encountered HloInstruction to its non-const version
+ // and then look up each instruction in need of laundering in that map, but
+ // this is much more expensive and complicated. This returns a Status instead
+ // of doing a CHECK-failure in part to make it strongly apparent that this is
+ // something that can fail.
+ StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
+ const HloInstruction* hlo);
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
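A sketch of how the laundering helper might be used from code that already returns a Status or StatusOr; the function and variable names here are illustrative only:

    StatusOr<HloInstruction*> GetMutableRoot(HloModule* module,
                                             const HloInstruction* const_root) {
      // Returns an error (rather than CHECK-failing) if const_root belongs to
      // a different module.
      TF_ASSIGN_OR_RETURN(HloInstruction* mutable_root,
                          module->LaunderConstInstructionFromModule(const_root));
      return mutable_root;
    }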
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 8a30cbf9cd..096ebb7946 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -116,7 +116,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
// produces no HLO value in the graph.
if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
outfeed->operand(0)->shape())) {
- return InvalidArgument(
+ return InternalError(
"Expected outfeed to have shape compatible with operand's shape %s, "
"actual shape is %s:\n%s",
ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
@@ -200,7 +200,7 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
transpose->operand(0)->shape(), transpose->dimensions()));
}
-Status ShapeVerifier::HandleParameter(HloInstruction*) {
+Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
return tensorflow::Status::OK();
}
@@ -410,7 +410,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
if (fp_type == PRIMITIVE_TYPE_INVALID) {
fp_type = subshape.element_type();
} else if (fp_type != subshape.element_type()) {
- return FailedPrecondition(
+ return InternalError(
"Seen floating point types of different precisions in "
"%s, but mixed precision is disallowed.",
instruction->ToString().c_str());
@@ -490,7 +490,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
}
if (!compatible) {
- return InvalidArgument(
+ return InternalError(
"Expected instruction to have shape compatible with %s, actual "
"shape is %s:\n%s",
ShapeUtil::HumanString(inferred_shape).c_str(),
@@ -541,7 +541,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2) {
if (instr1->channel_id() != instr2->channel_id()) {
- return FailedPrecondition(
+ return InternalError(
"Expected to have the same channel id, actual channel ids are: %s "
"(%lld), %s (%lld)",
instr1->ToString().c_str(), instr1->channel_id(),
@@ -571,22 +571,22 @@ string ComputationsToString(
Status VerifyHloStructure(HloModule* module) {
for (const HloComputation* computation : module->computations()) {
if (computation->parent() == nullptr) {
- return FailedPrecondition("Computation %s has a null parent pointer",
- computation->name().c_str());
+ return InternalError("Computation %s has a null parent pointer",
+ computation->name().c_str());
}
if (computation->parent() != module) {
- return FailedPrecondition(
+ return InternalError(
"Computation %s parent() does not point to parent module",
computation->name().c_str());
}
for (const HloInstruction* instruction : computation->instructions()) {
if (instruction->parent() == nullptr) {
- return FailedPrecondition("Instruction %s has a null parent pointer",
- instruction->name().c_str());
+ return InternalError("Instruction %s has a null parent pointer",
+ instruction->name().c_str());
}
if (instruction->parent() != computation) {
- return FailedPrecondition(
+ return InternalError(
"Instruction %s parent() does not point to parent computation",
instruction->name().c_str());
}
@@ -602,7 +602,7 @@ Status VerifyHloStructure(HloModule* module) {
for (int i = 0; i < instruction->operand_count(); ++i) {
const HloInstruction* operand = instruction->operand(i);
if (operand->parent() != instruction->parent()) {
- return FailedPrecondition(
+ return InternalError(
"Operand %d (%s) of instruction %s is in a different "
"computation: %s vs %s",
i, operand->name().c_str(), instruction->name().c_str(),
@@ -619,7 +619,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
- return FailedPrecondition(
+ return InternalError(
"Instruction of fused computation does not match expected instruction "
"%s.",
fusion->ToString().c_str());
@@ -635,37 +635,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
for (auto* instruction : fused_computation->instructions()) {
if (fused_root == instruction) {
if (root_owned) {
- return FailedPrecondition("Root appears more than once in %s.",
- fusion->ToString().c_str());
+ return InternalError("Root appears more than once in %s.",
+ fusion->ToString().c_str());
}
root_owned = true;
}
for (int i = 0; i < fused_parameters.size(); ++i) {
if (fused_parameters[i] == instruction) {
if (parameter_owned[i]) {
- return FailedPrecondition("Parameter appears more than once in %s.",
- fusion->ToString().c_str());
+ return InternalError("Parameter appears more than once in %s.",
+ fusion->ToString().c_str());
}
parameter_owned[i] = true;
}
}
}
if (!root_owned) {
- return FailedPrecondition("Root not found in computation of %s.",
- fusion->ToString().c_str());
+ return InternalError("Root not found in computation of %s.",
+ fusion->ToString().c_str());
}
// Make sure all the parameter_owned entries are set
for (int i = 0; i < parameter_owned.size(); i++) {
if (!parameter_owned[i]) {
- return FailedPrecondition("Parameter %d not found in computation of %s.",
- i, fusion->ToString().c_str());
+ return InternalError("Parameter %d not found in computation of %s.", i,
+ fusion->ToString().c_str());
}
}
// Fused root must have no users.
if (fused_root->user_count() != 0) {
- return FailedPrecondition("Root of %s may not have users.",
- fusion->ToString().c_str());
+ return InternalError("Root of %s may not have users.",
+ fusion->ToString().c_str());
}
// All uses of fused instructions must be in the fusion computation, and every
@@ -674,13 +674,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
fusion->fused_instructions_computation()->instructions()) {
if (instruction != fused_root) {
if (instruction->user_count() == 0) {
- return FailedPrecondition(
- "Non-root instruction %s in %s must have users.",
- instruction->ToString().c_str(), fusion->ToString().c_str());
+ return InternalError("Non-root instruction %s in %s must have users.",
+ instruction->ToString().c_str(),
+ fusion->ToString().c_str());
}
for (auto& user : instruction->users()) {
if (fused_computation != user->parent()) {
- return FailedPrecondition(
+ return InternalError(
"Non-root instruction %s in %s may not have external users.",
instruction->ToString().c_str(), fusion->ToString().c_str());
}
@@ -695,34 +695,33 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
for (auto fused_param : fused_parameters) {
int64 param_no = fused_param->parameter_number();
if (param_no < 0) {
- return FailedPrecondition(
- "Unexpected negative parameter number %lld in %s.", param_no,
- fusion->ToString().c_str());
+ return InternalError("Unexpected negative parameter number %lld in %s.",
+ param_no, fusion->ToString().c_str());
}
if (param_no >= fused_parameters.size()) {
- return FailedPrecondition(
+ return InternalError(
"Unexpected parameter number %lld in %s: higher then number of "
"parameters %lu.",
param_no, fusion->ToString().c_str(), fused_parameters.size());
}
if (parameter_numbers[param_no]) {
- return FailedPrecondition(
+ return InternalError(
"Did not expect parameter number %lld more than once in %s.",
param_no, fusion->ToString().c_str());
}
parameter_numbers[param_no] = true;
if (!ShapeUtil::Compatible(fused_param->shape(),
fusion->operand(param_no)->shape())) {
- return FailedPrecondition(
+ return InternalError(
"Shape mismatch between parameter number %lld and its operand in %s.",
param_no, fusion->ToString().c_str());
}
}
- // Make sure all the parameter_numbers entries were seen
+ // Make sure all the parameter_numbers entries were seen.
for (int i = 0; i < parameter_numbers.size(); i++) {
if (!parameter_numbers[i]) {
- return FailedPrecondition("Did not see parameter number %d in %s.", i,
- fusion->ToString().c_str());
+ return InternalError("Did not see parameter number %d in %s.", i,
+ fusion->ToString().c_str());
}
}
diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h
index cccbce5fc8..0e1387c939 100644
--- a/tensorflow/compiler/xla/statusor.h
+++ b/tensorflow/compiler/xla/statusor.h
@@ -13,13 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// StatusOr<T> is the union of a Status object and a T
-// object. StatusOr models the concept of an object that is either a
-// usable value, or an error Status explaining why such a value is
-// not present. To this end, StatusOr<T> does not allow its Status
-// value to be Status::OK. Furthermore, the value of a StatusOr<T*>
-// must not be null. This is enforced by a debug check in most cases,
-// but even when it is not, clients must not set the value to null.
+// StatusOr<T> is the union of a Status object and a T object. StatusOr models
+// the concept of an object that is either a value, or an error Status
+// explaining why such a value is not present. To this end, StatusOr<T> does not
+// allow its Status value to be Status::OK.
//
// The primary use-case for StatusOr<T> is as the return value of a
// function which may fail.
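For reference, the typical shape of such a function looks like the following; this is a generic sketch (SafeDivide is made up, InvalidArgument is the helper from tensorflow/compiler/xla/util.h), not code from this change:

    StatusOr<int64> SafeDivide(int64 numerator, int64 denominator) {
      if (denominator == 0) {
        return InvalidArgument("Division by zero: %lld / 0", numerator);
      }
      return numerator / denominator;
    }

    // Callers either test ok() and call ValueOrDie(), or propagate the error:
    //   TF_ASSIGN_OR_RETURN(int64 quotient, SafeDivide(a, b));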
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc
index f9d25945bc..7d76370e85 100644
--- a/tensorflow/compiler/xla/statusor_test.cc
+++ b/tensorflow/compiler/xla/statusor_test.cc
@@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) {
static_assert(std::is_same<StatusOr<char>::element_type, char>(), "");
}
+TEST(StatusOr, NullPointerStatusOr) {
+ // As a very special case, null-plain-pointer StatusOr used to be an
+ // error. Test that it no longer is.
+ StatusOr<int*> null_status(nullptr);
+ EXPECT_TRUE(null_status.ok());
+ EXPECT_EQ(null_status.ValueOrDie(), nullptr);
+}
+
TEST(StatusOr, TestNoDefaultConstructorInitialization) {
// Explicitly initialize it with an error code.
StatusOr<NoDefaultConstructor> statusor(tensorflow::errors::Cancelled(""));
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 40dc0730ce..156a06c596 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -440,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<OpMetadata> metadata;
attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
+ optional<string> backend_config;
+ attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
+ &backend_config};
+
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -1094,8 +1098,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
instruction->set_name(name);
- // Add common attrs (sharding, control predecessors) to the instruction, if
- // they were seen.
+ // Add shared attributes like metadata to the instruction, if they were seen.
if (sharding) {
instruction->set_sharding(
HloSharding::FromProto(sharding.value()).ValueOrDie());
@@ -1112,6 +1115,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (metadata) {
instruction->set_metadata(*metadata);
}
+ if (backend_config) {
+ instruction->set_backend_config(std::move(*backend_config));
+ }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index d38d8907a6..e100d8cda1 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
R"(HloModule constant_pred_module
ENTRY %constant_pred () -> pred[] {
- ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}
+ ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
}
)"
@@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] {
)"
},
-// f32 constant, but the value is not a decimal
+// f32 constant, but the value is not a decimal and there is a backend
+// configuration
{
"ConstantF32",
R"(HloModule ConstantF32_module
ENTRY %ConstantF32.v4 () -> f32[] {
- ROOT %constant = f32[] constant(42)
+ ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
}
)"
@@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] {
// but the constant names will not be exactly the same.
}
+TEST_F(HloParserTest, ConfigurationField) {
+ const string original = R"(HloModule AModule
+ENTRY %configuration_test() -> s32[] {
+ %constant = s32[] constant(42), backend_config="foo bar"
+})";
+ auto result = Parse(original);
+ TF_ASSERT_OK(result.status());
+ EXPECT_EQ("foo bar", result.ValueOrDie()
+ ->entry_computation()
+ ->root_instruction()
+ ->backend_config());
+}
+
TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
const string original = R"(HloModule some_2_module
@@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
}
)";