aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-07-21 13:34:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-21 13:38:31 -0700
commit51cbb58ca5147218b3995dc124bd92927d93e913 (patch)
tree2bfd8551ae066f4acb8a0cc6b6fcd26e19ce2e56 /tensorflow/compiler/xla/service
parent307a2fc2adc2c6b717bb9713f81289f6bbb91b07 (diff)
Some improvements to HLO grapvhiz dumping.
- If all params are filtered out, don't show an empty "parameters" node. - If the node name contains the opcode name, omit the opcode. For example, show "%add.4" instead of "add %add.4". This greatly shrinks the width of our graphs. - (For nodes without a lot of operands), always show either none or all of the operands. - Show nodes with some or all operands elided as "grayed out" to make it clear that these are the "edges" of our neighborhood. - Don't show an out-of-line computation for e.g. "add reduce". Instead, simply show it as "%reduce.42<br>Subcomputation: add". - Split up parameter nodes. Previously all params were fused into one big node, but now each parameter is its own node. This is useful because otherwise graphviz has to route long edges from the top of the graph to nodes that use params at the bottom of the graph. - Inline constants into their users, instead of displaying them as separate nodes. This is particularly helpful when a constant (:cough: zero) is used many times, because otherwise we have to draw many long edges all over the graph. PiperOrigin-RevId: 162778619
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc515
1 files changed, 351 insertions, 164 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 2455925b96..097a762015 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -16,8 +16,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include <unistd.h>
+#include <algorithm>
+#include <atomic>
#include <deque>
+#include <map>
+#include <memory>
#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -27,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -37,13 +45,15 @@ limitations under the License.
#include "tensorflow/core/platform/regexp.h"
using ::tensorflow::Env;
-using ::tensorflow::WriteStringToFile;
+using ::tensorflow::gtl::nullopt;
+using ::tensorflow::gtl::optional;
using ::tensorflow::io::JoinPath;
using ::tensorflow::strings::Appendf;
using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
using ::tensorflow::str_util::Join;
+using ::tensorflow::WriteStringToFile;
namespace xla {
namespace hlo_graph_dumper {
@@ -63,11 +73,26 @@ enum ColorScheme {
kRed,
kWhite,
kYellow,
+
+ // Causes the node's border to be a dashed line, and its content to be gray
+ // text on a white background, suggesting that this is an "unimportant" node.
+ kDashedBorder,
};
-// Used to indicate how we should treat a given HLOInstruction in the graph --
-// should we treat it like normal, hide it, or highlight it?
-enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode };
+// Used to indicate how we should treat a given HLOInstruction in the graph.
+// should we treat it like normal, hide it, and so on?
+enum NodeFilterResult {
+ kNormalNode,
+ kHideNode,
+ // Make the node easy to find in the final graph.
+ kHighlightNode,
+ // "Gray out" the node to indicate that some of its operands have been
+ // omitted.
+ kSomeOperandsOmitted,
+ // Style the node the same as kSomeOperandsOmitted, but also don't connect it
+ // to its operands, even if they're present in the graph.
+ kOmitNodeOperands,
+};
// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
// It lets callers tell the graph-drawing routines which nodes they want to be
@@ -86,51 +111,59 @@ class NodeFilter {
bool Highlight(const HloInstruction* instr) const {
return filter_(instr) == kHighlightNode;
}
+ bool OmitOperands(const HloInstruction* instr) const {
+ return filter_(instr) == kOmitNodeOperands;
+ }
+ bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
+ auto result = filter_(instr);
+ return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
+ }
private:
std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};
// Given a ColorScheme, returns an attribute string for a node of that color.
-// Sets the node's fill, stroke, and text colors.
+// Sets the node's style and fill/stroke/text colors.
//
// Colors are from https://material.io/color.
string NodeColorAttributes(ColorScheme color) {
using std::make_tuple;
- const char *fill_color, *stroke_color, *font_color;
- std::tie(fill_color, stroke_color, font_color) =
- [color]() -> std::tuple<const char*, const char*, const char*> {
+ const char *style, *fill_color, *stroke_color, *font_color;
+ std::tie(style, fill_color, stroke_color, font_color) = [color] {
switch (color) {
case kBlue:
- return make_tuple("#bbdefb", "#8aacc8", "black");
+ return make_tuple("filled", "#bbdefb", "#8aacc8", "black");
case kBrown:
- return make_tuple("#bcaaa4", "#8c7b75", "black");
+ return make_tuple("filled", "#bcaaa4", "#8c7b75", "black");
case kDarkBlue:
- return make_tuple("#1565c0", "#003c8f", "white");
+ return make_tuple("filled", "#1565c0", "#003c8f", "white");
case kDarkGreen:
- return make_tuple("#2e7d32", "#005005", "white");
+ return make_tuple("filled", "#2e7d32", "#005005", "white");
case kDarkRed:
- return make_tuple("#b71c1c", "#7f0000", "white");
+ return make_tuple("filled", "#b71c1c", "#7f0000", "white");
case kGray:
- return make_tuple("#cfd8dc", "#9ea7aa", "black");
+ return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black");
case kGreen:
- return make_tuple("#c8e6c9", "#97b498", "black");
+ return make_tuple("filled", "#c8e6c9", "#97b498", "black");
case kOrange:
- return make_tuple("#ffe0b2", "#cbae82", "black");
+ return make_tuple("filled", "#ffe0b2", "#cbae82", "black");
case kPurple:
- return make_tuple("#e1bee7", "#af8eb5", "black");
+ return make_tuple("filled", "#e1bee7", "#af8eb5", "black");
case kRed:
- return make_tuple("#ffcdd2", "#cb9ca1", "black");
+ return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black");
case kWhite:
- return make_tuple("white", "black", "black");
+ return make_tuple("filled", "white", "black", "black");
case kYellow:
- return make_tuple("#fff9c4", "#cbc693", "black");
+ return make_tuple("filled", "#fff9c4", "#cbc693", "black");
+ case kDashedBorder:
+ return make_tuple("dashed", "white", "#757575", "#757575");
}
}();
return Printf(
- "style=filled, fontcolor=\"%s\", color=\"%s\", fillcolor=\"%s\"",
+ R"(style=%s, fontcolor="%s", color="%s", fillcolor="%s")", style,
font_color, stroke_color, fill_color);
}
@@ -152,6 +185,70 @@ string ComputationId(const HloComputation* computation) {
return Printf("%lld", reinterpret_cast<uint64>(computation));
}
+// Tries to generates a human-readable one-word description of the given
+// computation.
+//
+// Currently we support:
+//
+// "return param0 + param1;" --> "add"
+// "return param0 * param1;" --> "multiply"
+// "return min(param0, param1);" --> "min"
+// "return max(param0, param1);" --> "max"
+//
+// where param0 and param1 are effective scalars. Since all of the ops above
+// are commutative, we also support them with param0 and param1 swapped.
+//
+// This is useful primarily for reduce and map nodes. These take a
+// subcomputation which is almost always one of the four above, and pattern
+// matching it to a short string lets us tell the user what the subcomputation
+// is without drawing it as a graph.
+optional<string> MatchTrivialComputation(const HloComputation* computation) {
+ if (computation->instruction_count() != 3) {
+ return nullopt;
+ }
+
+ HloInstruction* root = computation->root_instruction();
+ if (root->operand_count() != 2) {
+ return nullopt;
+ }
+
+ // Check that both of the operands to the root are parameters.
+ const HloInstruction* operand0 = root->operand(0);
+ const HloInstruction* operand1 = root->operand(1);
+ if (operand0->opcode() != HloOpcode::kParameter ||
+ operand1->opcode() != HloOpcode::kParameter) {
+ return nullopt;
+ }
+ // Check that the two operands of root are param0 and param1. All of the
+ // opcodes we recognize are commutative, so we're OK with either order.
+ auto n0 = operand0->parameter_number();
+ auto n1 = operand1->parameter_number();
+ if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) {
+ return nullopt;
+ }
+
+ // Check that the root and params are all effective scalars.
+ if (!ShapeUtil::IsEffectiveScalar(root->shape()) ||
+ !ShapeUtil::IsEffectiveScalar(operand0->shape()) ||
+ !ShapeUtil::IsEffectiveScalar(operand1->shape())) {
+ return nullopt;
+ }
+
+ // If we recognize the root's opcode, we've successfully pattern-matched!
+ switch (root->opcode()) {
+ case HloOpcode::kAdd:
+ return "add";
+ case HloOpcode::kMultiply:
+ return "multiply";
+ case HloOpcode::kMinimum:
+ return "min";
+ case HloOpcode::kMaximum:
+ return "max";
+ default:
+ return nullopt;
+ }
+}
+
// Returns the dot graph edges and nodes for the given instruction sequence.
// Edges which extend between computations are added to the vector
// intercomputation_edges. This is necessary because graphviz does not render
@@ -165,58 +262,49 @@ string InstructionSequenceGraph(
const NodeFilter& filter) {
string graph_body;
- // Create a single "record" node for the parameters. This node is a
- // partitioned rectangle with one partition per parameter node. The keeps
- // all the parameter instructions together.
- std::vector<HloInstruction*> param_instructions;
for (auto& instruction : instructions) {
- if (instruction->opcode() == HloOpcode::kParameter) {
- size_t param_number = instruction->parameter_number();
-
- if (param_instructions.size() < param_number + 1) {
- param_instructions.resize(param_number + 1, nullptr);
- }
- param_instructions[param_number] = instruction.get();
- }
- }
- string param_node_name;
- if (!param_instructions.empty()) {
- std::vector<string> param_ports;
- param_node_name =
- StrCat("parameters_", InstructionId(param_instructions[0]));
- for (auto& param : param_instructions) {
- if (!filter.Show(param)) {
- continue;
- }
- string label = StrCat(param->parameter_name(), "\\n",
- ShapeUtil::HumanString(param->shape()));
- if (show_addresses) {
- Appendf(&label, "\\n[%p]", param);
- }
- if (show_layouts) {
- StrAppend(&label, "\\nlayout=\\{",
- Join(param->shape().layout().minor_to_major(), ","), "\\}");
- }
- param_ports.push_back(
- Printf("<%s> %s", InstructionId(param).c_str(), label.c_str()));
+ if (!filter.Show(instruction.get())) {
+ continue;
}
- // (If we wanted the word "parameters" to be bold like the other op names,
- // we'd have to make this into an HTML-like table. It is possible but
- // complicated; see http://www.graphviz.org/doc/info/shapes.html#html.)
- StrAppend(&graph_body, param_node_name, " [shape=record ",
- NodeColorAttributes(kOrange), "label=\"{parameters | {",
- Join(param_ports, "|"), "}}\"];\n");
- }
- for (auto& instruction : instructions) {
- if (!filter.Show(instruction.get())) {
+ // We don't display constants as separate nodes; they're merged into their
+ // users.
+ if (instruction->opcode() == HloOpcode::kConstant) {
continue;
}
+
ColorScheme color = kYellow;
string shape = "box";
- string name =
- StrCat("<b>", HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()),
- "</b> ", HtmlLikeStringSanitize(instruction->name()));
+
+ // Build the first line or two of the node, containing its name and opcode
+ // (if the opcode isn't redundant with the name).
+ string name;
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ // If we have a parameter, put the param number in the name.
+ name = StrCat("<b>Parameter ", instruction->parameter_number(),
+ "</b><br/>", HtmlLikeStringSanitize(instruction->name()));
+ } else if (tensorflow::StringPiece(instruction->name())
+ .starts_with(
+ StrCat("%", instruction->ExtendedOpcodeStr()))) {
+ // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
+ // an add instruction. In this case we render just the name.
+ name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()), "</b>");
+ } else if (instruction->opcode() == HloOpcode::kFusion &&
+ tensorflow::StringPiece(instruction->name())
+ .starts_with(
+ StrCat("%", HloOpcodeString(instruction->opcode())))) {
+ // Fusion nodes are usually named e.g. "%fusion.5". We render these as
+ // e.g. "%fusion.5<br/>input fusion".
+ name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()),
+ "</b><br/>",
+ HtmlLikeStringSanitize(instruction->ToCategory()));
+ } else {
+ // If the name does not contain the opcode, render both.
+ name = StrCat("<b>",
+ HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()),
+ "</b><br/>", HtmlLikeStringSanitize(instruction->name()));
+ }
+
if (HloOpcode::kConvolution == instruction->opcode()) {
StrAppend(
&name, "<br/>",
@@ -305,18 +393,13 @@ string InstructionSequenceGraph(
case HloOpcode::kUpdate:
color = kGreen;
break;
- case HloOpcode::kConstant:
- color = kBlue;
- break;
case HloOpcode::kConvolution:
case HloOpcode::kDot:
color = kDarkBlue;
break;
case HloOpcode::kParameter:
- // A single record node is created for all the parameter nodes with a
- // port for each parameter instruction. No need to emit anything in this
- // case.
- continue;
+ color = kOrange;
+ break;
case HloOpcode::kBatchNormTraining:
StrAppend(&name, " feature_index=", instruction->feature_index());
color = kPurple;
@@ -361,6 +444,8 @@ string InstructionSequenceGraph(
// will be inserted as modifications to an existing graph.
color = kRed;
break;
+ case HloOpcode::kConstant:
+ LOG(FATAL) << "Constants don't get their own nodes in the graph.";
}
// Create instruction node with appropriate label, shape, and color.
@@ -369,14 +454,6 @@ string InstructionSequenceGraph(
string label =
StrCat(name, "<br/>", ShapeUtil::HumanString(instruction->shape()));
- if (instruction->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsEffectiveScalar(instruction->shape())) {
- auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
- instruction->shape(), /*linear_index=*/0);
- StrAppend(&label, " = {", instruction->literal().GetAsString(elem_idx),
- "}");
- }
-
if (show_addresses) {
Appendf(&label, "<br/>[%p]", instruction.get());
}
@@ -405,75 +482,116 @@ string InstructionSequenceGraph(
}
}
+ // If this node's operands are omitted, style it accordingly.
+ if (filter.SomeOrAllOperandsOmitted(instruction.get())) {
+ color = kDashedBorder;
+ }
+
// If this node is highlighted, override its formatting.
if (filter.Highlight(instruction.get())) {
shape = "diamond";
color = kDarkRed;
}
- Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
- InstructionId(instruction.get()).c_str(), label.c_str(),
- shape.c_str(), NodeColorAttributes(color).c_str());
-
// Create edges from the instruction's operands to the instruction.
- int64 operand_number = 0;
- for (auto* operand : instruction->operands()) {
- if (!filter.Show(operand)) {
+ if (!filter.OmitOperands(instruction.get())) {
+ int64 operand_number = 0;
+ for (auto* operand : instruction->operands()) {
+ if (!filter.Show(operand) ||
+ operand->opcode() == HloOpcode::kConstant) {
+ ++operand_number;
+ continue;
+ }
+ Appendf(&graph_body, "%s -> %s", InstructionId(operand).c_str(),
+ InstructionId(instruction.get()).c_str());
+ if (instruction->operand_count() > 1) {
+ Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]",
+ operand_number);
+ }
+ StrAppend(&graph_body, ";\n");
++operand_number;
- continue;
}
- string src;
- if (operand->opcode() == HloOpcode::kParameter) {
- // If operand is a parameter, then select the proper partition (port) in
- // the unified parameter node.
- src = param_node_name + ":" + InstructionId(operand);
+
+ // Fusion nodes are handled specially because they contain nested
+ // expressions.
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ string cluster_name =
+ StrCat("cluster_", InstructionId(instruction.get()));
+ StrAppend(&graph_body, "subgraph ", cluster_name, " {\n");
+ StrAppend(&graph_body, "label=<fused expression for <b>",
+ HtmlLikeStringSanitize(instruction->name()),
+ "</b>>;\nstyle=\"rounded,filled\";\n"
+ "color=lightgrey;\n");
+ StrAppend(&graph_body,
+ InstructionSequenceGraph(instruction->fused_instructions(),
+ show_addresses, show_layouts,
+ intercomputation_edges,
+ hlo_execution_profile, NodeFilter()),
+ "}\n");
+ string fusion_edge = StrCat(
+ InstructionId(instruction->fused_expression_root()), " -> ",
+ InstructionId(instruction.get()),
+ " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name,
+ " ];\n");
+ intercomputation_edges->push_back(fusion_edge);
} else {
- src = InstructionId(operand);
- }
- Appendf(&graph_body, "%s -> %s", src.c_str(),
- InstructionId(instruction.get()).c_str());
- if (instruction->operand_count() > 1) {
- Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]",
- operand_number);
+ // If instruction has just one computation and it's trivial (e.g.
+ // "return param0 + param1"), put the trivial computation type (e.g.
+ // "add") into instruction's label. Otherwise, add a dotted edge
+ // between the instruction and its subcomputations.
+ const auto& subcomputations = instruction->called_computations();
+
+ bool trivial_subcomputation = false;
+ if (subcomputations.size() == 1) {
+ optional<string> computation_type =
+ MatchTrivialComputation(subcomputations.front());
+ if (computation_type) {
+ trivial_subcomputation = true;
+ StrAppend(&label, "<br/>Subcomputation: <b>", *computation_type,
+ "</b>");
+ }
+ }
+
+ if (!trivial_subcomputation) {
+ for (const HloComputation* computation :
+ instruction->called_computations()) {
+ string cluster_name =
+ StrCat("cluster_", ComputationId(computation));
+ string call_edge = Printf(
+ "%s -> %s [ style=dashed; ltail=%s ];\n",
+ InstructionId(computation->root_instruction()).c_str(),
+ InstructionId(instruction.get()).c_str(), cluster_name.c_str());
+ intercomputation_edges->push_back(call_edge);
+ }
+ }
}
- StrAppend(&graph_body, ";\n");
- ++operand_number;
}
- // Fusion nodes are handled specially because they contain nested
- // expressions.
- if (instruction->opcode() == HloOpcode::kFusion) {
- string cluster_name =
- StrCat("cluster_", InstructionId(instruction.get()));
- StrAppend(&graph_body, "subgraph ", cluster_name, " {\n");
- StrAppend(&graph_body,
- "label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n"
- "color=lightgrey;\n");
- StrAppend(&graph_body,
- InstructionSequenceGraph(instruction->fused_instructions(),
- show_addresses, show_layouts,
- intercomputation_edges,
- hlo_execution_profile, NodeFilter()),
- "}\n");
- string fusion_edge =
- StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
- InstructionId(instruction.get()),
- " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name,
- " ];\n");
- intercomputation_edges->push_back(fusion_edge);
- } else {
- // Add a dotted edge between the instruction and any computations that the
- // instruction calls.
- for (const HloComputation* computation :
- instruction->called_computations()) {
- string cluster_name = StrCat("cluster_", ComputationId(computation));
- string call_edge = Printf(
- "%s -> %s [ style=dashed; ltail=%s ];\n",
- InstructionId(computation->root_instruction()).c_str(),
- InstructionId(instruction.get()).c_str(), cluster_name.c_str());
- intercomputation_edges->push_back(call_edge);
+ // Inline constant operands into the node.
+ for (int64 i = 0; i < instruction->operand_count(); ++i) {
+ const HloInstruction* operand = instruction->operand(i);
+ if (operand->opcode() != HloOpcode::kConstant) {
+ continue;
+ }
+
+ StrAppend(&label, "<br/><b>operand ", i, "</b> = ");
+ if (ShapeUtil::IsEffectiveScalar(operand->shape())) {
+ auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
+ operand->shape(), /*linear_index=*/0);
+ StrAppend(&label, ShapeUtil::HumanString(operand->shape()), "{",
+ operand->literal().GetAsString(elem_idx), "}");
+ } else {
+ if (tensorflow::StringPiece(operand->name()).starts_with("%constant")) {
+ StrAppend(&label, operand->name());
+ } else {
+ StrAppend(&label, "constant ", operand->name());
+ }
}
}
+
+ Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n",
+ InstructionId(instruction.get()).c_str(), label.c_str(),
+ shape.c_str(), NodeColorAttributes(color).c_str());
}
return graph_body;
}
@@ -513,16 +631,25 @@ stylesheet="%s"
)",
graph_label.c_str(), dot_stylesheet);
+ // Dump the subcomputations of each instruction that's shown and doesn't have
+ // its operands omitted. If an instruction has just one subcomputation and
+ // it's trivial, omit it: We'll display that subcomputation inlined into the
+ // instruction's node when we draw it.
std::unordered_set<const HloComputation*> computations_to_dump;
for (const auto& instr : computation.instructions()) {
- if (!filter.Show(instr.get())) {
+ if (!filter.Show(instr.get()) || filter.OmitOperands(instr.get())) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
computations_to_dump.insert(instr->fused_instructions_computation());
}
- for (const HloComputation* computation : instr->called_computations()) {
- computations_to_dump.insert(computation);
+
+ const auto& subcomputations = instr->called_computations();
+ if (subcomputations.size() != 1 ||
+ !MatchTrivialComputation(subcomputations.front())) {
+ for (const HloComputation* computation : instr->called_computations()) {
+ computations_to_dump.insert(computation);
+ }
}
}
@@ -620,31 +747,41 @@ class FileGraphRenderer : public GraphRendererInterface {
}
};
-// Gets roughly all instructions whose distance from root is <= radius.
-std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
- const HloInstruction& root, int64 radius) {
- std::unordered_set<const HloInstruction*> ret;
-
+// Gets a NodeFilter that includes roughly all instructions whose distance from
+// root is <= radius.
+//
+// It's confusing to draw a node and include only some of its operands. So if
+// some but not all of a node's operands are <= radius units away from the root,
+// we include the other operands (unless there are a lot of them, as often in a
+// tuple node). These additional operands may have as inputs other nodes
+// already present in the graph, but we don't draw those edges unless *all* of
+// the inputs are present. (Otherwise we'd have the same problem we were trying
+// to solve in the first place!)
+NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
+ // First, find the neighborhood of nodes with distance from root <= radius.
+ // These nodes are our initial set of "normal" nodes.
+ std::unordered_map<const HloInstruction*, NodeFilterResult> nodes;
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
- worklist.push_back({&root, 0});
-
+ worklist.push_back({root, 0});
while (!worklist.empty()) {
const HloInstruction* instr;
int64 depth;
std::tie(instr, depth) = worklist.front();
worklist.pop_front();
- ret.insert(instr);
+ nodes[instr] = kNormalNode;
if (depth == radius) {
continue;
}
+ // Traverse into instr's operands.
+ //
// Don't traverse into tuples' operands unless the tuple is the root.
// Usually a tuple is the bottommost node in the graph, and so its operands
// are not interesting to the graph at hand.
- if (instr == &root || instr->opcode() != HloOpcode::kTuple) {
+ if (instr == root || instr->opcode() != HloOpcode::kTuple) {
for (const HloInstruction* operand : instr->operands()) {
- if (ret.find(operand) == ret.end()) {
+ if (!nodes.count(operand)) {
worklist.push_back({operand, depth + 1});
}
}
@@ -655,14 +792,77 @@ std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood(
// constants' users.
if (instr->opcode() != HloOpcode::kConstant) {
for (const HloInstruction* user : instr->users()) {
- if (ret.find(user) == ret.end()) {
+ if (!nodes.count(user)) {
worklist.push_back({user, depth + 1});
}
}
}
}
- return ret;
+ auto is_displayed = [&](const HloInstruction* instr) {
+ return nodes.count(instr) > 0;
+ };
+
+ // If a node has some but not all of its operands omitted, add the operands to
+ // the map with type kOmitNodeOperands. Unless the node has a lot of
+ // operands, in which case just mark the node as "some operands omitted".
+ std::vector<const HloInstruction*> extra_operands;
+ for (auto& kv : nodes) {
+ const HloInstruction* instr = kv.first;
+ NodeFilterResult& filter_result = kv.second;
+ const auto& operands = instr->operands();
+
+ // Mark nodes with many operands and some omitted as "some operands omitted"
+ // and carry on -- don't add their omitted operands to extra_operands.
+ if (operands.size() > 4) {
+ if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
+ !std::all_of(operands.begin(), operands.end(), is_displayed)) {
+ filter_result = kSomeOperandsOmitted;
+ }
+ continue;
+ }
+
+ if (std::any_of(operands.begin(), operands.end(), is_displayed)) {
+ for (const HloInstruction* operand : operands) {
+ if (!is_displayed(operand)) {
+ extra_operands.push_back(operand);
+ }
+ }
+ }
+ }
+ for (const HloInstruction* instr : extra_operands) {
+ nodes[instr] = kOmitNodeOperands;
+ }
+
+ // Some of the nodes in extra_operands may now have all of their inputs
+ // present in nodes. We can promote these to normal nodes.
+ for (const HloInstruction* instr : extra_operands) {
+ const auto& operands = instr->operands();
+ if (std::all_of(operands.begin(), operands.end(), is_displayed)) {
+ nodes[instr] = kNormalNode;
+ }
+ }
+
+ // If none of a node's operands appear in nodes, mark it as type
+ // kOmitNodeOperands so it gets styled appropriately.
+ for (auto& kv : nodes) {
+ const auto& operands = kv.first->operands();
+ if (!operands.empty() &&
+ std::none_of(operands.begin(), operands.end(), is_displayed)) {
+ kv.second = kOmitNodeOperands;
+ }
+ }
+
+ // Highlight the root node.
+ nodes[root] = kHighlightNode;
+
+ return NodeFilter([=](const HloInstruction* instr) {
+ auto it = nodes.find(instr);
+ if (it != nodes.end()) {
+ return it->second;
+ }
+ return kHideNode;
+ });
}
XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
@@ -699,22 +899,9 @@ string DumpGraph(const HloComputation& computation, const string& label,
string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
auto debug_options = node.GetModule()->config().debug_options();
-
- std::unordered_set<const HloInstruction*> neighborhood =
- GetInstructionsInNeighborhood(node, radius);
-
- NodeFilter filter([&](const HloInstruction* instr) {
- if (instr == &node) {
- return kHighlightNode;
- }
- if (neighborhood.find(instr) != neighborhood.end()) {
- return kNormalNode;
- }
- return kHideNode;
- });
-
string label =
StrCat("Neighborhood of ", radius, " nodes around ", node.name());
+ NodeFilter filter = MakeNodeFilter(&node, radius);
string graph = ComputationToDotGraph(
*node.parent(), label,
/*show_addresses=*/debug_options.xla_hlo_graph_addresses(),