author	A. Unique TensorFlower <gardener@tensorflow.org>	2017-11-13 11:37:37 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2017-11-13 11:45:43 -0800
commit	3db96abfc5432c190d3afa62ebfad3c1d82cd818 (patch)
tree	a5926f485e408ad968a77f1e540dbcdeff2f23bb /tensorflow/compiler
parent	58f7858601b72aa3c5854571f2152b91d1795e29 (diff)
Allow assigning colors based on HLO sharding information when generating Graphviz HLO graphs, via a new --xla_hlo_graph_sharding_color option.
When generating TF graphs, a new --xla_hlo_tfgraph_device_scopes option allows prefixing the instruction names with a device scope. This helps the TF graph viewer better isolate the parts of the graph that target different devices, and allows rendering graphs that would otherwise be too large to render. Changed the TF/XLA broadcast lowering to propagate the request metadata into the HLO broadcast instructions.
PiperOrigin-RevId: 175563052
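As a quick orientation before the diff: a minimal sketch of driving the dumper with the new options programmatically (the DumpGraph signature is taken from the change below; the `computation` variable is assumed to be in scope):

  xla::DebugOptions opts;
  // Graphviz path: color nodes by their sharding assignment.
  opts.set_xla_hlo_graph_sharding_color(true);
  // Alternatively, the TF GraphDef path with "devX" name-scope prefixes:
  // opts.set_xla_hlo_dump_as_graphdef(true);
  // opts.set_xla_hlo_tfgraph_device_scopes(true);
  string url = xla::hlo_graph_dumper::DumpGraph(
      computation, /*label=*/"example", opts,
      /*hlo_execution_profile=*/nullptr, /*show_metadata=*/false);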
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--	tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc	16
-rw-r--r--	tensorflow/compiler/xla/service/BUILD	1
-rw-r--r--	tensorflow/compiler/xla/service/hlo_graph_dumper.cc	201
-rw-r--r--	tensorflow/compiler/xla/service/hlo_graph_dumper.h	5
-rw-r--r--	tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc	2
-rw-r--r--	tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc	29
-rw-r--r--	tensorflow/compiler/xla/service/hlo_tfgraph_builder.h	7
-rw-r--r--	tensorflow/compiler/xla/service/user_computation.cc	3
-rw-r--r--	tensorflow/compiler/xla/xla.proto	8
9 files changed, 169 insertions, 103 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index f2cdd9669c..bfafef0a40 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -31,7 +31,6 @@ std::vector<tensorflow::Flag>* flag_objects;
std::once_flag flags_init;
void SetDebugOptionsDefaults(DebugOptions* flags) {
- flags->set_xla_hlo_graph_path("/tmp/");
flags->set_xla_enable_fast_math(true);
flags->set_xla_llvm_enable_alias_scope_metadata(true);
flags->set_xla_llvm_enable_noalias_metadata(true);
@@ -118,8 +117,21 @@ void AllocateFlags() {
flag_values->xla_hlo_dump_as_graphdef(),
"Dump HLO graphs as TensorFlow GraphDefs."),
tensorflow::Flag(
+ "xla_hlo_graph_sharding_color",
+ bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
+ flag_values->xla_hlo_graph_sharding_color(),
+ "Assign colors based on sharding assignments when generating the "
+ "HLO graphs."),
+ tensorflow::Flag(
+ "xla_hlo_tfgraph_device_scopes",
+ bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes),
+ flag_values->xla_hlo_tfgraph_device_scopes(),
+ "When generating TensorFlow HLO graphs, if the HLO instructions "
+ "are assigned to a specific device, prefix the name scope with "
+ "\"devX\" with X being the device ordinal."),
+ tensorflow::Flag(
"xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(),
- "HLO modules matching this regex will be dumped to LOG(INFO). "),
+ "HLO modules matching this regex will be dumped to LOG(INFO)."),
tensorflow::Flag(
"xla_generate_hlo_text_to",
flag_values->mutable_xla_generate_hlo_text_to(),
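For context, a hedged sketch of how a tool surfaces these flags (AppendDebugOptionsFlags and GetDebugOptionsFromFlags are assumed from this same legacy_flags module; verify against its header):

  #include <vector>
  #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
  #include "tensorflow/core/util/command_line_flags.h"

  int main(int argc, char** argv) {
    std::vector<tensorflow::Flag> flag_list;
    xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);  // adds the flags above
    tensorflow::Flags::Parse(&argc, argv, flag_list);
    xla::DebugOptions opts = xla::legacy_flags::GetDebugOptionsFromFlags();
    // opts.xla_hlo_graph_sharding_color() and
    // opts.xla_hlo_tfgraph_device_scopes() now reflect the command line.
    return 0;
  }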
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7cf24641b5..c163a5f837 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1985,6 +1985,7 @@ cc_library(
":hlo",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 04b3059fb1..e4c89cd8c1 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -312,11 +312,11 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
class HloDotDumper {
public:
HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
- bool show_addresses, bool show_metadata,
+ const DebugOptions& debug_options, bool show_metadata,
const HloExecutionProfile* profile, NodeFilter filter)
: computation_(computation),
label_(label.ToString()),
- show_addresses_(show_addresses),
+ debug_options_(debug_options),
show_metadata_(show_metadata),
profile_(profile),
filter_(std::move(filter)) {}
@@ -382,7 +382,7 @@ class HloDotDumper {
const HloComputation* computation_; // never null
const string label_; // overall name for the graph
- const bool show_addresses_;
+ const DebugOptions& debug_options_;
const bool show_metadata_;
const HloExecutionProfile* profile_; // may be null
const NodeFilter filter_;
@@ -414,6 +414,11 @@ class HloDotDumper {
// appears before both the inner computation and the destination node are
// defined.
std::vector<string> edges_;
+
+ // When coloring by sharding information, we track the mapping from the
+ // sharding string representation to its color, assigning color schemes
+ // in round-robin order.
+ std::unordered_map<string, ColorScheme> sharding_colors_;
+ int64 next_shard_color_ = 0;
};
string HloDotDumper::Dump() {
@@ -734,15 +739,16 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
AddInstructionIncomingEdges(instr);
- // Override the node's styling if it should be (de-)emphasized.
- if (filter_.Deemphasized(instr)) {
- color = kDashedBorder;
- }
- if (filter_.Highlight(instr)) {
- node_shape = "diamond";
- color = kDarkRed;
+ if (!debug_options_.xla_hlo_graph_sharding_color()) {
+ // Override the node's styling if it should be (de-)emphasized.
+ if (filter_.Deemphasized(instr)) {
+ color = kDashedBorder;
+ }
+ if (filter_.Highlight(instr)) {
+ node_shape = "diamond";
+ color = kDarkRed;
+ }
}
-
// Build the text that will be displayed inside the node.
string node_body = node_label;
for (const string& s :
@@ -827,6 +833,20 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
}
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
+ if (debug_options_.xla_hlo_graph_sharding_color()) {
+ if (!instr->has_sharding()) {
+ return kDashedBorder;
+ }
+ string shard_str = instr->sharding().ToString();
+ auto it = sharding_colors_.find(shard_str);
+ if (it != sharding_colors_.end()) {
+ return it->second;
+ }
+ ColorScheme color = static_cast<ColorScheme>(
+ kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
+ sharding_colors_.emplace(shard_str, color);
+ return color;
+ }
const auto kParameterColor = kOrange;
// Special case: If this instruction has a parameter merged into it, paint it
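The added branch above maps each distinct sharding string to the next scheme in [kBlue, kDashedBorder), wrapping once the palette is exhausted; unsharded instructions get a dashed border. A standalone sketch of the same round-robin mapping (the enum values here are placeholders, not the real ColorScheme):

  #include <cstdint>
  #include <string>
  #include <unordered_map>

  enum ColorScheme { kBlue, kBrown, kDarkBlue, kDashedBorder };  // placeholder palette

  ColorScheme ColorForSharding(const std::string& shard_str) {
    static std::unordered_map<std::string, ColorScheme> colors;
    static int64_t next = 0;
    auto it = colors.find(shard_str);
    if (it != colors.end()) return it->second;  // reuse the color already assigned
    ColorScheme color =
        static_cast<ColorScheme>(kBlue + (next++ % (kDashedBorder - kBlue)));
    colors.emplace(shard_str, color);
    return color;
  }
  // With the three schemes above, a fourth distinct sharding wraps back to kBlue.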
@@ -1079,8 +1099,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
}
lines.push_back(instr_shape);
}
-
- if (show_addresses_) {
+ if (debug_options_.xla_hlo_graph_addresses()) {
lines.push_back(Printf("[%p]", instr));
}
if (profile_ != nullptr) {
@@ -1177,70 +1196,36 @@ const HloInstruction* HloDotDumper::GetNodeForEdge(
return instr;
}
-tensorflow::mutex& RendererMutex() {
- static tensorflow::mutex* mu = new tensorflow::mutex;
- return *mu;
-}
+class GraphRendererRegistry {
+ public:
+ void AddRenderer(GraphRendererInterface* graph_renderer) {
+ tensorflow::mutex_lock lock(mu_);
+ graph_renderer_ = graph_renderer;
+ }
-std::map<int, GraphRendererInterface*>* GraphRenderers() {
- static auto* graph_renderers = new std::map<int, GraphRendererInterface*>();
- return graph_renderers;
-}
+ GraphRendererInterface* GetDefaultRenderer() {
+ tensorflow::mutex_lock lock(mu_);
+ return graph_renderer_;
+ }
-GraphRendererInterface* GetGraphRenderer() {
- tensorflow::mutex_lock lock(RendererMutex());
- auto* graph_renderers = GraphRenderers();
- auto it = graph_renderers->rbegin();
- CHECK(it != graph_renderers->rend()) << "No registered graph dumpers";
- return it->second;
-}
+ static GraphRendererRegistry* Default() {
+ static GraphRendererRegistry* registry = new GraphRendererRegistry();
+ return registry;
+ }
+
+ private:
+ tensorflow::mutex mu_;
+ GraphRendererInterface* graph_renderer_ = nullptr;
+};
} // namespace
-Registrar::Registrar(GraphRendererInterface* dumper, int priority) {
- tensorflow::mutex_lock lock(RendererMutex());
- auto* graph_renderers = GraphRenderers();
- graph_renderers->emplace(priority, dumper);
+Registrar::Registrar(GraphRendererInterface* dumper) {
+ GraphRendererRegistry::Default()->AddRenderer(dumper);
}
namespace {
-class FileGraphRenderer : public GraphRendererInterface {
- public:
- string RenderGraph(const string& graph, GraphKind graph_kind,
- const DebugOptions& debug_options) override {
- static std::atomic<int> output_num(0);
- string file_extension;
- switch (graph_kind) {
- case DOT_GRAPH:
- file_extension = ".dot";
- break;
- case TF_GRAPHDEF:
- file_extension = ".pbtxt";
- break;
- }
- string path =
- JoinPath(debug_options.xla_hlo_graph_path(),
- StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension));
- auto status = Status::OK();
- int fd = mkstemps(&path[0], file_extension.length());
- if (fd < 0) {
- status =
- Status(tensorflow::error::Code::UNKNOWN,
- StrCat("Failed to create temporary file to dump HLO graph: ",
- strerror(errno)));
- } else {
- status = tensorflow::WriteStringToFile(tensorflow::Env::Default(), path,
- graph);
- close(fd);
- }
- if (!status.ok()) {
- LOG(WARNING) << "Saving HLO graph failed: " << status;
- }
- return path;
- }
-};
-
// Gets a NodeFilter that includes roughly all instructions whose distance from
// root is <= radius.
NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
@@ -1350,7 +1335,54 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
});
}
-XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
+string SaveGraph(const string& graph,
+ GraphRendererInterface::GraphKind graph_kind,
+ const string& dest_path) {
+ static std::atomic<int> output_num(0);
+ string file_extension;
+ switch (graph_kind) {
+ case GraphRendererInterface::DOT_GRAPH:
+ file_extension = ".dot";
+ break;
+ case GraphRendererInterface::TF_GRAPHDEF:
+ file_extension = ".pbtxt";
+ break;
+ }
+ string path = JoinPath(
+ dest_path, StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension));
+ auto status = Status::OK();
+ int fd = mkstemps(&path[0], file_extension.length());
+ if (fd < 0) {
+ status =
+ Status(tensorflow::error::Code::UNKNOWN,
+ StrCat("Failed to create temporary file to dump HLO graph: ",
+ strerror(errno)));
+ } else {
+ status =
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph);
+ close(fd);
+ }
+ if (!status.ok()) {
+ LOG(WARNING) << "Saving HLO graph failed: " << status;
+ }
+ return path;
+}
+
+string ExportGraph(const string& graph,
+ GraphRendererInterface::GraphKind graph_kind,
+ const DebugOptions& debug_options) {
+ string path = debug_options.xla_hlo_graph_path();
+ if (!path.empty()) {
+ return SaveGraph(graph, graph_kind, path);
+ } else {
+ auto graph_renderer =
+ GraphRendererRegistry::Default()->GetDefaultRenderer();
+ CHECK(graph_renderer != nullptr)
+ << "No registered renderer for the HLO graph. "
+ "Use --xla_hlo_graph_path=PATH to export to local file system";
+ return graph_renderer->RenderGraph(graph, graph_kind, debug_options);
+ }
+}
} // namespace
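With the "/tmp/" default removed above, the dispatch is now: if --xla_hlo_graph_path is set, graphs are written to files via SaveGraph; otherwise the single registered renderer is used. A hedged sketch of plugging in a custom renderer under the new one-argument macro (the class body is illustrative, not part of this commit):

  class UrlRenderer : public xla::hlo_graph_dumper::GraphRendererInterface {
   public:
    std::string RenderGraph(const std::string& graph, GraphKind graph_kind,
                            const xla::DebugOptions& debug_options) override {
      // Illustrative only: a real renderer would publish `graph` somewhere
      // and return a URL for the LOG(INFO) message in DumpGraph.
      return "hlo-viewer://graph/latest";
    }
  };
  XLA_REGISTER_GRAPH_RENDERER(UrlRenderer);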
@@ -1358,27 +1390,22 @@ string DumpGraph(const HloComputation& computation, const string& label,
const DebugOptions& debug_options,
const HloExecutionProfile* hlo_execution_profile,
bool show_metadata) {
+ GraphRendererInterface::GraphKind graph_kind;
string graph;
- string graph_url;
if (debug_options.xla_hlo_dump_as_graphdef()) {
- HloTfGraphBuilder builder;
+ HloTfGraphBuilder builder(debug_options);
TF_CHECK_OK(builder.AddComputation(computation));
CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(),
&graph));
- // TODO(b/37198616): Use the default registered renderers when all
- // renderers support rendering GraphDefs. Always dump GraphDefs to files
- // for now.
- graph_url = FileGraphRenderer().RenderGraph(
- graph, GraphRendererInterface::TF_GRAPHDEF, debug_options);
+ graph_kind = GraphRendererInterface::TF_GRAPHDEF;
} else {
- graph =
- HloDotDumper(&computation, label,
- /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
- show_metadata, hlo_execution_profile, NodeFilter())
- .Dump();
- graph_url = GetGraphRenderer()->RenderGraph(
- graph, GraphRendererInterface::DOT_GRAPH, debug_options);
+ graph = HloDotDumper(&computation, label, debug_options, show_metadata,
+ hlo_execution_profile, NodeFilter())
+ .Dump();
+ graph_kind = GraphRendererInterface::DOT_GRAPH;
}
+
+ string graph_url = ExportGraph(graph, graph_kind, debug_options);
LOG(INFO) << "computation " << computation.name() << " [" << label
<< "]: " << graph_url;
return graph_url;
@@ -1391,12 +1418,10 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius,
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
NodeFilter filter = MakeNodeFilter(&node, radius);
string graph =
- HloDotDumper(node.parent(), label,
- /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
- show_metadata, /*profile=*/nullptr, filter)
+ HloDotDumper(node.parent(), label, debug_options, show_metadata,
+ /*profile=*/nullptr, filter)
.Dump();
- return GetGraphRenderer()->RenderGraph(
- graph, GraphRendererInterface::DOT_GRAPH, debug_options);
+ return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}
void DumpText(const HloModule& module, const string& label,
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
index dd304ec76c..2704aae1e3 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
@@ -84,11 +84,10 @@ void DumpText(const HloModule& module, const string& label,
// Internal implementation details below this point.
-// Class that registers a graph renderer. Higher-priority renders are chosen
-// first.
+// Class that registers a graph renderer.
class Registrar {
public:
- Registrar(GraphRendererInterface* dumper, int priority);
+ Registrar(GraphRendererInterface* dumper);
};
#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 7b0f937f38..8e1531c87f 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -45,7 +45,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
string last_graph_;
};
-XLA_REGISTER_GRAPH_RENDERER(DotRenderer, std::numeric_limits<int>::max());
+XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
TEST(HloGraphDumperTest, NestedFusion) {
HloComputation::Builder b("b");
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index 06abe00747..101a710d1c 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -58,8 +58,6 @@ TensorShapeProto GetTensorShape(const HloInstruction* instruction) {
string GetDeviceName(int device) { return StrCat("/device/XLA:", device); }
-} // namespace
-
void CleanNodeName(string* name) {
name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
const string chars_to_replace = "<>[]";
@@ -70,6 +68,11 @@ void CleanNodeName(string* name) {
std::replace_if(name->begin(), name->end(), pred, '_');
}
+} // namespace
+
+HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options)
+ : debug_options_(debug_options) {}
+
Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
VLOG(2) << "Adding computation " << computation.name();
for (auto embedded : computation.MakeEmbeddedComputationsList()) {
@@ -90,24 +93,38 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
if (ContainsKey(instruction_to_node_name_, instruction)) {
return instruction_to_node_name_[instruction];
}
+ auto append = [](string* str, const string& other) {
+ if (str->empty()) {
+ *str = other;
+ } else if (!other.empty()) {
+ StrAppend(str, "/", other);
+ }
+ };
string node_name;
+ if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
+ instruction->has_sharding() &&
+ instruction->sharding().HasUniqueDevice()) {
+ node_name = StrCat(
+ "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
+ }
// If an instruction is fused, put it in the subgraph of the fusion;
// otherwise, put it in the computation subgraph.
const HloComputation* computation = instruction->parent();
if (computation->IsFusionComputation()) {
- node_name = GetNodeNameForInstruction(computation->FusionInstruction());
+ append(&node_name,
+ GetNodeNameForInstruction(computation->FusionInstruction()));
} else {
- node_name = computation->name();
+ append(&node_name, computation->name());
if (!instruction->metadata().op_name().empty()) {
// Always make computations contain TF ops but not the other way around.
- StrAppend(&node_name, "/", instruction->metadata().op_name());
+ append(&node_name, instruction->metadata().op_name());
}
}
string instruction_name = instruction->name();
if (instruction->opcode() == HloOpcode::kParameter) {
StrAppend(&instruction_name, ".", instruction->parameter_number());
}
- StrAppend(&node_name, "/", instruction_name);
+ append(&node_name, instruction_name);
CleanNodeName(&node_name);
auto ret =
instruction_to_node_name_.insert(std::make_pair(instruction, node_name));
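The new append helper keeps node names well-formed whether or not the "devN" prefix is present. A standalone sketch, followed by the resulting name shape (the computation and instruction names are made up for illustration):

  #include <string>

  // Join path segments with '/', skipping empty pieces so a missing
  // "devN" prefix does not leave a leading slash.
  void Append(std::string* str, const std::string& other) {
    if (str->empty()) {
      *str = other;
    } else if (!other.empty()) {
      str->append("/").append(other);
    }
  }

  // With --xla_hlo_tfgraph_device_scopes and a unique device of 2:
  //   "dev2" -> "dev2/cluster_1" -> "dev2/cluster_1/MatMul"
  //          -> "dev2/cluster_1/MatMul/dot.3"
  // Without device sharding, the same calls yield "cluster_1/MatMul/dot.3".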
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
index b2c578af91..9aa3e501d5 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
@@ -17,6 +17,7 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -26,6 +27,8 @@ namespace hlo_graph_dumper {
// This constructs a tensorflow graph for HLO computations.
class HloTfGraphBuilder {
public:
+ HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions());
+
// Adds a computation to the graph.
Status AddComputation(const HloComputation& computation);
@@ -42,6 +45,7 @@ class HloTfGraphBuilder {
Status AddInstruction(const HloInstruction* instruction);
+ DebugOptions debug_options_;
tensorflow::GraphDef graph_def_;
// This records instructions that have been visited.
std::unordered_set<const HloInstruction*> visited_instructions_;
@@ -49,9 +53,6 @@ class HloTfGraphBuilder {
std::unordered_map<const HloInstruction*, string> instruction_to_node_name_;
};
-// Cleans the node name to make it a valid name in a tensorflow graph.
-void CleanNodeName(string* name);
-
} // namespace hlo_graph_dumper
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 8d5bb08e51..8f63c92e5b 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -2538,6 +2538,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
if (ShapeUtil::IsScalar(operand->shape())) {
HloInstruction* broadcast = hlo_builder_.AddInstruction(
HloInstruction::CreateBroadcast(broadcast_shape, operand, {}));
+ broadcast->set_metadata(operand->metadata());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
@@ -2558,6 +2559,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
ShapeUtil::MakeShape(operand->shape().element_type(),
reshaped_dimensions),
operand));
+ reshaped_operand->set_metadata(operand->metadata());
if (operand->has_sharding()) {
reshaped_operand->set_sharding(operand->sharding());
}
@@ -2565,6 +2567,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
HloInstruction* broadcast =
hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
broadcast_shape, reshaped_operand, broadcast_dimensions));
+ broadcast->set_metadata(operand->metadata());
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
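These user_computation.cc hunks implement the last paragraph of the commit message: every instruction synthesized during implicit-broadcast lowering now inherits the operand's request metadata, alongside the sharding it already copied. A hypothetical helper capturing the repeated pattern (not part of this commit):

  void CopyMetadataAndSharding(const xla::HloInstruction* from,
                               xla::HloInstruction* to) {
    to->set_metadata(from->metadata());  // keep the originating TF op name
    if (from->has_sharding()) {
      to->set_sharding(from->sharding());
    }
  }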
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 710bb6ff25..127e5e81ac 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -167,6 +167,14 @@ message DebugOptions {
// computation will run 2! * 4! times.
bool xla_test_all_input_layouts = 91;
+ // Assign colors based on sharding information when generating the Graphviz
+ // HLO graph.
+ bool xla_hlo_graph_sharding_color = 92;
+
+ // Prefix the name scopes of the TF graph exports with "devX" device
+ // assignments, if available.
+ bool xla_hlo_tfgraph_device_scopes = 93;
+
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;