aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc515
1 files changed, 351 insertions, 164 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 2455925b96..097a762015 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -16,8 +16,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include <unistd.h>
+#include <algorithm>
+#include <atomic>
#include <deque>
+#include <map>
+#include <memory>
#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -27,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -37,13 +45,15 @@ limitations under the License.
#include "tensorflow/core/platform/regexp.h"
using ::tensorflow::Env;
-using ::tensorflow::WriteStringToFile;
+using ::tensorflow::gtl::nullopt;
+using ::tensorflow::gtl::optional;
using ::tensorflow::io::JoinPath;
using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
using ::tensorflow::str_util::Join;
+using ::tensorflow::WriteStringToFile;
namespace xla {
namespace hlo_graph_dumper {
@@ -63,11 +73,26 @@ enum ColorScheme {
kRed,
kWhite,
kYellow,
+
+ // Causes the node's border to be a dashed line, and its content to be gray
+ // text on a white background, suggesting that this is an "unimportant" node.
+ kDashedBorder,
};
-// Used to indicate how we should treat a given HLOInstruction in the graph --
-// should we treat it like normal, hide it, or highlight it?
-enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode };
+// Used to indicate how we should treat a given HLOInstruction in the graph.
+// should we treat it like normal, hide it, and so on?
+enum NodeFilterResult {
+ kNormalNode,
+ kHideNode,
+ // Make the node easy to find in the final graph.
+ kHighlightNode,
+ // "Gray out" the node to indicate that some of its operands have been
+ // omitted.
+ kSomeOperandsOmitted,
+ // Style the node the same as kSomeOperandsOmitted, but also don't connect it
+ // to its operands, even if they're present in the graph.
+ kOmitNodeOperands,
+};
// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
// It lets callers tell the graph-drawing routines which nodes they want to be
@@ -86,51 +111,59 @@ class NodeFilter {
bool Highlight(const HloInstruction* instr) const {
return filter_(instr) == kHighlightNode;
}
+ bool OmitOperands(const HloInstruction* instr) const {
+ return filter_(instr) == kOmitNodeOperands;
+ }
+ bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
+ auto result = filter_(instr);
+ return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
+ }
private:
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};
// Given a ColorScheme, returns an attribute string for a node of that color.
-// Sets the node's fill, stroke, and text colors.
+// Sets the node's style and fill/stroke/text colors.
//
// Colors are from https://material.io/color.
string NodeColorAttributes(ColorScheme color) {
using std::make_tuple;
- const char *fill_color, *stroke_color, *font_color;
- std::tie(fill_color, stroke_color, font_color) =
- [color]() -> std::tuple<const char*, const char*, const char*> {
+ const char *style, *fill_color, *stroke_color, *font_color;
+ std::tie(style, fill_color, stroke_color, font_color) = [color] {
switch (color) {
case kBlue:
- return make_tuple("#bbdefb", "#8aacc8", "black");
+ return make_tuple("filled", "#bbdefb", "#8aacc8", "black");
case kBrown:
- return make_tuple("#bcaaa4", "#8c7b75", "black");
+ return make_tuple("filled", "#bcaaa4", "#8c7b75", "black");
case kDarkBlue:
- return make_tuple("#1565c0", "#003c8f", "white");
+ return make_tuple("filled", "#1565c0", "#003c8f", "white");
case kDarkGreen:
- return make_tuple("#2e7d32", "#005005", "white");
+ return make_tuple("filled", "#2e7d32", "#005005", "white");
case kDarkRed:
- return make_tuple("#b71c1c", "#7f0000", "white");
+ return make_tuple("filled", "#b71c1c", "#7f0000", "white");
case kGray:
- return make_tuple("#cfd8dc", "#9ea7aa", "black");
+ return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black");
case kGreen:
- return make_tuple("#c8e6c9", "#97b498", "black");
+ return make_tuple("filled", "#c8e6c9", "#97b498", "black");
case kOrange:
- return make_tuple("#ffe0b2", "#cbae82", "black");
+ return make_tuple("filled", "#ffe0b2", "#cbae82", "black");
case kPurple:
- return make_tuple("#e1bee7", "#af8eb5", "black");
+ return make_tuple("filled", "#e1bee7", "#af8eb5", "black");
case kRed:
- return make_tuple("#ffcdd2", "#cb9ca1", "black");
+ return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black");
case kWhite:
- return make_tuple("white", "black", "black");
+ return make_tuple("filled", "white", "black", "black");
case kYellow:
- return make_tuple("#fff9c4", "#cbc693", "black");
+ return make_tuple("filled", "#fff9c4", "#cbc693", "black");
+ case kDashedBorder:
+ return make_tuple("dashed", "white", "#757575", "#757575");
}
}();
return Printf(
- "style=filled, fontcolor=\"%s\", color=\"%s\", fillcolor=\"%s\"",
+ R"(style=%s, fontcolor="%s", color="%s", fillcolor="%s")", style,
font_color, stroke_color, fill_color);
}
@@ -152,6 +185,70 @@ string ComputationId(const HloComputation* computation) {
return Printf("%lld", reinterpret_cast<uint64>(computation));
}
+// Tries to generates a human-readable one-word description of the given
+// computation.
+//
+// Currently we support:
+//
+// "return param0 + param1;" --> "add"
+// "return param0 * param1;" --> "multiply"
+// "return min(param0, param1);" --> "min"
+// "return max(param0, param1);" --> "max"
+//
+// where param0 and param1 are effective scalars. Since all of the ops above
+// are commutative, we also support them with param0 and param1 swapped.
+//
+// This is useful primarily for reduce and map nodes. These take a
+// subcomputation which is almost always one of the four above, and pattern
+// matching it to a short string lets us tell the user what the subcomputation
+// is without drawing it as a graph.
+optional<string> MatchTrivialComputation(const HloComputation* computation) {
+ if (computation->instruction_count() != 3) {
+ return nullopt;
+ }
+
+ HloInstruction* root = computation->root_instruction();
+ if (root->operand_count() != 2) {
+ return nullopt;
+ }
+
+ // Check that both of the operands to the root are parameters.
+ const HloInstruction* operand0 = root->operand(0);
+ const HloInstruction* operand1 = root->operand(1);
+ if (operand0->opcode() != HloOpcode::kParameter ||
+ operand1->opcode() != HloOpcode::kParameter) {
+ return nullopt;
+ }
+ // Check that the two operands of root are param0 and param1. All of the
+ // opcodes we recognize are commutative, so we're OK with either order.
+ auto n0 = operand0->parameter_number();
+ auto n1 = operand1->parameter_number();
+ if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) {
+ return nullopt;
+ }
+
+ // Check that the root and params are all effective scalars.
+ if (!ShapeUtil::IsEffectiveScalar(root->shape()) ||
+ !ShapeUtil::IsEffectiveScalar(operand0->shape()) ||
+ !ShapeUtil::IsEffectiveScalar(operand1->shape())) {
+ return nullopt;
+ }
+
+ // If we recognize the root's opcode, we've successfully pattern-matched!
+ switch (root->opcode()) {
+ case HloOpcode::kAdd:
+ return "add";
+ case HloOpcode::kMultiply:
+ return "multiply";
+ case HloOpcode::kMinimum:
+ return "min";
+ case HloOpcode::kMaximum:
+ return "max";
+ default:
+ return nullopt;
+ }
+}
+
// Returns the dot graph edges and nodes for the given instruction sequence.
// Edges which extend between computations are added to the vector
// intercomputation_edges. This is necessary because graphviz does not render
@@ -165,58 +262,49 @@ string InstructionSequenceGraph(
const NodeFilter& filter) {
string graph_body;
- // Create a single "record" node for the parameters. This node is a
- // partitioned rectangle with one partition per parameter node. The keeps
- // all the parameter instructions together.
- std::vector<HloInstruction*> param_instructions;
for (auto& instruction : instructions) {
- if (instruction->opcode() == HloOpcode::kParameter) {
- size_t param_number = instruction->parameter_number();
-
- if (param_instructions.size() < param_number + 1) {
- param_instructions.resize(param_number + 1, nullptr);
- }
- param_instructions[param_number] = instruction.get();
- }
- }
- string param_node_name;
- if (!param_instructions.empty()) {
- std::vector<string> param_ports;
- param_node_name =
- StrCat("parameters_", InstructionId(param_instructions[0]));
- for (auto& param : param_instructions) {
- if (!filter.Show(param)) {
- continue;
- }
- string label = StrCat(param->parameter_name(), "\\n",
- ShapeUtil::HumanString(param->shape()));
- if (show_addresses) {
- Appendf(&label, "\\n[%p]", param);
- }
- if (show_layouts) {
- StrAppend(&label, "\\nlayout=\\{",
- Join(param->shape().layout().minor_to_major(), ","), "\\}");
- }
- param_ports.push_back(
- Printf("<%s> %s", InstructionId(param).c_str(), label.c_str()));
+ if (!filter.Show(instruction.get())) {
+ continue;
}
- // (If we wanted the word "parameters" to be bold like the other op names,
- // we'd have to make this into an HTML-like table. It is possible but
- // complicated; see http://www.graphviz.org/doc/info/shapes.html#html.)
- StrAppend(&graph_body, param_node_name, " [shape=record ",
- NodeColorAttributes(kOrange), "label=\"{parameters | {",
- Join(param_ports, "|"), "}}\"];\n");
- }
- for (auto& instruction : instructions) {
- if (!filter.Show(instruction.get())) {
+ // We don't display constants as separate nodes; they're merged into their
+ // users.
+ if (instruction->opcode() == HloOpcode::kConstant) {
continue;
}
+
ColorScheme color = kYellow;
string shape = "box";
- string name =
- StrCat("<b>", HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()),
- "</b> ", HtmlLikeStringSanitize(instruction->name()));
+
+ // Build the first line or two of the node, containing its name and opcode
+ // (if the opcode isn't redundant with the name).
+ string name;
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ // If we have a parameter, put the param number in the name.
+ name = StrCat("<b>Parameter ", instruction->parameter_number(),
+ "</b><br/>", HtmlLikeStringSanitize(instruction->name()));
+ } else if (tensorflow::StringPiece(instruction->name())
+ .starts_with(
+ StrCat("%", instruction->ExtendedOpcodeStr()))) {
+ // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
+ // an add instruction. In this case we render just the name.
+ name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()), "</b>");
+ } else if (instruction->opcode() == HloOpcode::kFusion &&
+ tensorflow::StringPiece(instruction->name())
+ .starts_with(
+ StrCat("%", HloOpcodeString(instruction->opcode())))) {
+ // Fusion nodes are usually named e.g. "%fusion.5". We render these as
+ // e.g. "%fusion.5<br/>input fusion".
+ name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()),
+ "</b><br/>",
+ HtmlLikeStringSanitize(instruction->ToCategory()));
+ } else {
+ // If the name does not contain the opcode, render both.
+ name = StrCat("<b>",
+ HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()),
+ "</b><br/>", HtmlLikeStringSanitize(instruction->name()));
+ }
+
if (HloOpcode::kConvolution == instruction->opcode()) {
StrAppend(
&name, "<br/>",
@@ -305,18 +393,13 @@ string InstructionSequenceGraph(
case HloOpcode::kUpdate:
color = kGreen;
break;
- case HloOpcode::kConstant:
- color = kBlue;
- break;
case HloOpcode::kConvolution:
case HloOpcode::kDot:
color = kDarkBlue;
break;
case HloOpcode::kParameter:
- // A single record node is created for all the parameter nodes with a
- // port for each parameter instruction. No need to emit anything in this
- // case.
- continue;
+ color = kOrange;
+ break;
case HloOpcode::kBatchNormTraining:
StrAppend(&name, " feature_index=", instruction->feature_index());
color = kPurple;
@@ -361,6 +444,8 @@ string InstructionSequenceGraph(
// will be inserted as modifications to an existing graph.
color = kRed;
break;
+ case HloOpcode::kConstant:
+ LOG(FATAL) << "Constants don't get their own nodes in the graph.";
}
// Create instruction node with appropriate label, shape, and color.
@@ -369,14 +454,6 @@ string InstructionSequenceGraph(
string label =
StrCat(name, "<br/>", ShapeUtil::HumanString(instruction->shape()));
- if (instruction->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsEffectiveScalar(instruction->shape())) {
- auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
- instruction->shape(), /*linear_index=*/0);
- StrAppend(&label, " = {", instruction->literal().GetAsString(elem_idx),
- "}");
- }
-
if (show_addresses) {
Appendf(&label, "<br/>[%p]", instruction.get());
}
@@ -405,75 +482,116 @@ string InstructionSequenceGraph(
}
}
+ // If this node's operands are omitted, style it accordingly.
+ if (filter.SomeOrAllOperandsOmitted(instruction.get())) {
+ color = kDashedBorder;
+ }
+
// If this node is highlighted, override its formatting.
if (filter.Highlight(instruction.get())) {
shape = "diamond";
color = kDarkRed;
}
- Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
- InstructionId(instruction.get()).c_str(), label.c_str(),
- shape.c_str(), NodeColorAttributes(color).c_str());
-
// Create edges from the instruction's operands to the instruction.
- int64 operand_number = 0;
- for (auto* operand : instruction->operands()) {
- if (!filter.Show(operand)) {
+ if (!filter.OmitOperands(instruction.get())) {
+ int64 operand_number = 0;
+ for (auto* operand : instruction->operands()) {
+ if (!filter.Show(operand) ||
+ operand->opcode() == HloOpcode::kConstant) {
+ ++operand_number;
+ continue;
+ }
+ Appendf(&graph_body, "%s -> %s", InstructionId(operand).c_str(),
+ InstructionId(instruction.get()).c_str());
+ if (instruction->operand_count() > 1) {
+ Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]",
+ operand_number);
+ }
+ StrAppend(&graph_body, ";\n");
++operand_number;
- continue;
}
- string src;
- if (operand->opcode() == HloOpcode::kParameter) {
- // If operand is a parameter, then select the proper partition (port) in
- // the unified parameter node.
- src = param_node_name + ":" + InstructionId(operand);
+
+ // Fusion nodes are handled specially because they contain nested
+ // expressions.
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ string cluster_name =
+ StrCat("cluster_", InstructionId(instruction.get()));
+ StrAppend(&graph_body, "subgraph ", cluster_name, " {\n");
+ StrAppend(&graph_body, "label=<fused expression for <b>",
+ HtmlLikeStringSanitize(instruction->name()),
+ "</b>>;\nstyle=\"rounded,filled\";\n"
+ "color=lightgrey;\n");
+ StrAppend(&graph_body,
+ InstructionSequenceGraph(instruction->fused_instructions(),
+ show_addresses, show_layouts,
+ intercomputation_edges,
+ hlo_execution_profile, NodeFilter()),
+ "}\n");
+ string fusion_edge = StrCat(
+ InstructionId(instruction->fused_expression_root()), " -> ",
+ InstructionId(instruction.get()),
+ " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name,
+ " ];\n");
+ intercomputation_edges->push_back(fusion_edge);
} else {
- src = InstructionId(operand);
- }
- Appendf(&graph_body, "%s -> %s", src.c_str(),
- InstructionId(instruction.get()).c_str());
- if (instruction->operand_count() > 1) {
- Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]",
- operand_number);
+ // If instruction has just one computation and it's trivial (e.g.
+ // "return param0 + param1"), put the trivial computation type (e.g.
+ // "add") into instruction's label. Otherwise, add a dotted edge
+ // between the instruction and its subcomputations.
+ const auto& subcomputations = instruction->called_computations();
+
+ bool trivial_subcomputation = false;
+ if (subcomputations.size() == 1) {
+ optional<string> computation_type =
+ MatchTrivialComputation(subcomputations.front());
+ if (computation_type) {
+ trivial_subcomputation = true;
+ StrAppend(&label, "<br/>Subcomputation: <b>", *computation_type,
+ "</b>");
+ }
+ }
+
+ if (!trivial_subcomputation) {
+ for (const HloComputation* computation :
+ instruction->called_computations()) {
+ string cluster_name =
+ StrCat("cluster_", ComputationId(computation));
+ string call_edge = Printf(
+ "%s -> %s [ style=dashed; ltail=%s ];\n",
+ InstructionId(computation->root_instruction()).c_str(),
+ InstructionId(instruction.get()).c_str(), cluster_name.c_str());
+ intercomputation_edges->push_back(call_edge);
+ }
+ }
}
- StrAppend(&graph_body, ";\n");
- ++operand_number;
}
- // Fusion nodes are handled specially because they contain nested
- // expressions.
- if (instruction->opcode() == HloOpcode::kFusion) {
- string cluster_name =
- StrCat("cluster_", InstructionId(instruction.get()));
- StrAppend(&graph_body, "subgraph ", cluster_name, " {\n");
- StrAppend(&graph_body,
- "label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n"
- "color=lightgrey;\n");
- StrAppend(&graph_body,
- InstructionSequenceGraph(instruction->fused_instructions(),
- show_addresses, show_layouts,
- intercomputation_edges,
- hlo_execution_profile, NodeFilter()),
- "}\n");
- string fusion_edge =
- StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
- InstructionId(instruction.get()),
- " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name,
- " ];\n");
- intercomputation_edges->push_back(fusion_edge);
- } else {
- // Add a dotted edge between the instruction and any computations that the
- // instruction calls.
- for (const HloComputation* computation :
- instruction->called_computations()) {
- string cluster_name = StrCat("cluster_", ComputationId(computation));
- string call_edge = Printf(
- "%s -> %s [ style=dashed; ltail=%s ];\n",
- InstructionId(computation->root_instruction()).c_str(),
- InstructionId(instruction.get()).c_str(), cluster_name.c_str());
- intercomputation_edges->push_back(call_edge);
+ // Inline constant operands into the node.
+ for (int64 i = 0; i < instruction->operand_count(); ++i) {
+ const HloInstruction* operand = instruction->operand(i);
+ if (operand->opcode() != HloOpcode::kConstant) {
+ continue;
+ }
+
+ StrAppend(&label, "<br/><b>operand ", i, "</b> = ");
+ if (ShapeUtil::IsEffectiveScalar(operand->shape())) {
+ auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
+ operand->shape(), /*linear_index=*/0);
+ StrAppend(&label, ShapeUtil::HumanString(operand->shape()), "{",
+ operand->literal().GetAsString(elem_idx), "}");
+ } else {
+ if (tensorflow::StringPiece(operand->name()).starts_with("%constant")) {
+ StrAppend(&label, operand->name());
+ } else {
+ StrAppend(&label, "constant ", operand->name());
+ }
}
}
+
+ Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
+ InstructionId(instruction.get()).c_str(), label.c_str(),
+ shape.c_str(), NodeColorAttributes(color).c_str());
}
return graph_body;
}
@@ -513,16 +631,25 @@ stylesheet="%s"
)",
graph_label.c_str(), dot_stylesheet);
+ // Dump the subcomputations of each instruction that's shown and doesn't have
+ // its operands omitted. If an instruction has just one subcomputation and
+ // it's trivial, omit it: We'll display that subcomputation inlined into the
+ // instruction's node when we draw it.
std::unordered_set<const HloComputation*> computations_to_dump;
for (const auto& instr : computation.instructions()) {
- if (!filter.Show(instr.get())) {
+ if (!filter.Show(instr.get()) || filter.OmitOperands(instr.get())) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
computations_to_dump.insert(instr->fused_instructions_computation());
}
- for (const HloComputation* computation : instr->called_computations()) {
- computations_to_dump.insert(computation);
+
+ const auto& subcomputations = instr->called_computations();
+ if (subcomputations.size() != 1 ||
+ !MatchTrivialComputation(subcomputations.front())) {
+ for (const HloComputation* computation : instr->called_computations()) {
+ computations_to_dump.insert(computation);
+ }
}
}
@@ -620,31 +747,41 @@ class FileGraphRenderer : public GraphRendererInterface {
}
};
-// Gets roughly all instructions whose distance from root is <= radius.
-std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
- const HloInstruction& root, int64 radius) {
- std::unordered_set<const HloInstruction*> ret;
-
+// Gets a NodeFilter that includes roughly all instructions whose distance from
+// root is <= radius.
+//
+// It's confusing to draw a node and include only some of its operands. So if
+// some but not all of a node's operands are <= radius units away from the root,
+// we include the other operands (unless there are a lot of them, as often in a
+// tuple node). These additional operands may have as inputs other nodes
+// already present in the graph, but we don't draw those edges unless *all* of
+// the inputs are present. (Otherwise we'd have the same problem we were trying
+// to solve in the first place!)
+NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
+ // First, find the neighborhood of nodes with distance from root <= radius.
+ // These nodes are our initial set of "normal" nodes.
+ std::unordered_map<const HloInstruction*, NodeFilterResult> nodes;
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
- worklist.push_back({&root, 0});
-
+ worklist.push_back({root, 0});
while (!worklist.empty()) {
const HloInstruction* instr;
int64 depth;
std::tie(instr, depth) = worklist.front();
worklist.pop_front();
- ret.insert(instr);
+ nodes[instr] = kNormalNode;
if (depth == radius) {
continue;
}
+ // Traverse into instr's operands.
+ //
// Don't traverse into tuples' operands unless the tuple is the root.
// Usually a tuple is the bottommost node in the graph, and so its operands
// are not interesting to the graph at hand.
- if (instr == &root || instr->opcode() != HloOpcode::kTuple) {
+ if (instr == root || instr->opcode() != HloOpcode::kTuple) {
for (const HloInstruction* operand : instr->operands()) {
- if (ret.find(operand) == ret.end()) {
+ if (!nodes.count(operand)) {
worklist.push_back({operand, depth + 1});
}
}
@@ -655,14 +792,77 @@ std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
// constants' users.
if (instr->opcode() != HloOpcode::kConstant) {
for (const HloInstruction* user : instr->users()) {
- if (ret.find(user) == ret.end()) {
+ if (!nodes.count(user)) {
worklist.push_back({user, depth + 1});
}
}
}
}
- return ret;
+ auto is_displayed = [&](const HloInstruction* instr) {
+ return nodes.count(instr) > 0;
+ };
+
+ // If a node has some but not all of its operands omitted, add the operands to
+ // the map with type kOmitNodeOperands. Unless the node has a lot of
+ // operands, in which case just mark the node as "some operands omitted".
+ std::vector<const HloInstruction*> extra_operands;
+ for (auto& kv : nodes) {
+ const HloInstruction* instr = kv.first;
+ NodeFilterResult& filter_result = kv.second;
+ const auto& operands = instr->operands();
+
+ // Mark nodes with many operands and some omitted as "some operands omitted"
+ // and carry on -- don't add their omitted operands to extra_operands.
+ if (operands.size() > 4) {
+ if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
+ !std::all_of(operands.begin(), operands.end(), is_displayed)) {
+ filter_result = kSomeOperandsOmitted;
+ }
+ continue;
+ }
+
+ if (std::any_of(operands.begin(), operands.end(), is_displayed)) {
+ for (const HloInstruction* operand : operands) {
+ if (!is_displayed(operand)) {
+ extra_operands.push_back(operand);
+ }
+ }
+ }
+ }
+ for (const HloInstruction* instr : extra_operands) {
+ nodes[instr] = kOmitNodeOperands;
+ }
+
+ // Some of the nodes in extra_operands may now have all of their inputs
+ // present in nodes. We can promote these to normal nodes.
+ for (const HloInstruction* instr : extra_operands) {
+ const auto& operands = instr->operands();
+ if (std::all_of(operands.begin(), operands.end(), is_displayed)) {
+ nodes[instr] = kNormalNode;
+ }
+ }
+
+ // If none of a node's operands appear in nodes, mark it as type
+ // kOmitNodeOperands so it gets styled appropriately.
+ for (auto& kv : nodes) {
+ const auto& operands = kv.first->operands();
+ if (!operands.empty() &&
+ std::none_of(operands.begin(), operands.end(), is_displayed)) {
+ kv.second = kOmitNodeOperands;
+ }
+ }
+
+ // Highlight the root node.
+ nodes[root] = kHighlightNode;
+
+ return NodeFilter([=](const HloInstruction* instr) {
+ auto it = nodes.find(instr);
+ if (it != nodes.end()) {
+ return it->second;
+ }
+ return kHideNode;
+ });
}
XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
@@ -699,22 +899,9 @@ string DumpGraph(const HloComputation& computation, const string& label,
string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
auto debug_options = node.GetModule()->config().debug_options();
-
- std::unordered_set<const HloInstruction*> neighborhood =
- GetInstructionsInNeighborhood(node, radius);
-
- NodeFilter filter([&](const HloInstruction* instr) {
- if (instr == &node) {
- return kHighlightNode;
- }
- if (neighborhood.find(instr) != neighborhood.end()) {
- return kNormalNode;
- }
- return kHideNode;
- });
-
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
+ NodeFilter filter = MakeNodeFilter(&node, radius);
string graph = ComputationToDotGraph(
*node.parent(), label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),