diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 515 |
1 files changed, 351 insertions, 164 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 2455925b96..097a762015 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -16,8 +16,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include <unistd.h> +#include <algorithm> +#include <atomic> #include <deque> +#include <map> +#include <memory> #include <string> +#include <tuple> +#include <unordered_map> +#include <vector> #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -27,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -37,13 +45,15 @@ limitations under the License. #include "tensorflow/core/platform/regexp.h" using ::tensorflow::Env; -using ::tensorflow::WriteStringToFile; +using ::tensorflow::gtl::nullopt; +using ::tensorflow::gtl::optional; using ::tensorflow::io::JoinPath; using ::tensorflow::strings::Appendf; using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; using ::tensorflow::str_util::Join; +using ::tensorflow::WriteStringToFile; namespace xla { namespace hlo_graph_dumper { @@ -63,11 +73,26 @@ enum ColorScheme { kRed, kWhite, kYellow, + + // Causes the node's border to be a dashed line, and its content to be gray + // text on a white background, suggesting that this is an "unimportant" node. + kDashedBorder, }; -// Used to indicate how we should treat a given HLOInstruction in the graph -- -// should we treat it like normal, hide it, or highlight it? -enum NodeFilterResult { kNormalNode, kHideNode, kHighlightNode }; +// Used to indicate how we should treat a given HLOInstruction in the graph. +// should we treat it like normal, hide it, and so on? +enum NodeFilterResult { + kNormalNode, + kHideNode, + // Make the node easy to find in the final graph. + kHighlightNode, + // "Gray out" the node to indicate that some of its operands have been + // omitted. + kSomeOperandsOmitted, + // Style the node the same as kSomeOperandsOmitted, but also don't connect it + // to its operands, even if they're present in the graph. + kOmitNodeOperands, +}; // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult. // It lets callers tell the graph-drawing routines which nodes they want to be @@ -86,51 +111,59 @@ class NodeFilter { bool Highlight(const HloInstruction* instr) const { return filter_(instr) == kHighlightNode; } + bool OmitOperands(const HloInstruction* instr) const { + return filter_(instr) == kOmitNodeOperands; + } + bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const { + auto result = filter_(instr); + return result == kOmitNodeOperands || result == kSomeOperandsOmitted; + } private: std::function<NodeFilterResult(const HloInstruction* instr)> filter_; }; // Given a ColorScheme, returns an attribute string for a node of that color. -// Sets the node's fill, stroke, and text colors. +// Sets the node's style and fill/stroke/text colors. // // Colors are from https://material.io/color. string NodeColorAttributes(ColorScheme color) { using std::make_tuple; - const char *fill_color, *stroke_color, *font_color; - std::tie(fill_color, stroke_color, font_color) = - [color]() -> std::tuple<const char*, const char*, const char*> { + const char *style, *fill_color, *stroke_color, *font_color; + std::tie(style, fill_color, stroke_color, font_color) = [color] { switch (color) { case kBlue: - return make_tuple("#bbdefb", "#8aacc8", "black"); + return make_tuple("filled", "#bbdefb", "#8aacc8", "black"); case kBrown: - return make_tuple("#bcaaa4", "#8c7b75", "black"); + return make_tuple("filled", "#bcaaa4", "#8c7b75", "black"); case kDarkBlue: - return make_tuple("#1565c0", "#003c8f", "white"); + return make_tuple("filled", "#1565c0", "#003c8f", "white"); case kDarkGreen: - return make_tuple("#2e7d32", "#005005", "white"); + return make_tuple("filled", "#2e7d32", "#005005", "white"); case kDarkRed: - return make_tuple("#b71c1c", "#7f0000", "white"); + return make_tuple("filled", "#b71c1c", "#7f0000", "white"); case kGray: - return make_tuple("#cfd8dc", "#9ea7aa", "black"); + return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black"); case kGreen: - return make_tuple("#c8e6c9", "#97b498", "black"); + return make_tuple("filled", "#c8e6c9", "#97b498", "black"); case kOrange: - return make_tuple("#ffe0b2", "#cbae82", "black"); + return make_tuple("filled", "#ffe0b2", "#cbae82", "black"); case kPurple: - return make_tuple("#e1bee7", "#af8eb5", "black"); + return make_tuple("filled", "#e1bee7", "#af8eb5", "black"); case kRed: - return make_tuple("#ffcdd2", "#cb9ca1", "black"); + return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black"); case kWhite: - return make_tuple("white", "black", "black"); + return make_tuple("filled", "white", "black", "black"); case kYellow: - return make_tuple("#fff9c4", "#cbc693", "black"); + return make_tuple("filled", "#fff9c4", "#cbc693", "black"); + case kDashedBorder: + return make_tuple("dashed", "white", "#757575", "#757575"); } }(); return Printf( - "style=filled, fontcolor=\"%s\", color=\"%s\", fillcolor=\"%s\"", + R"(style=%s, fontcolor="%s", color="%s", fillcolor="%s")", style, font_color, stroke_color, fill_color); } @@ -152,6 +185,70 @@ string ComputationId(const HloComputation* computation) { return Printf("%lld", reinterpret_cast<uint64>(computation)); } +// Tries to generates a human-readable one-word description of the given +// computation. +// +// Currently we support: +// +// "return param0 + param1;" --> "add" +// "return param0 * param1;" --> "multiply" +// "return min(param0, param1);" --> "min" +// "return max(param0, param1);" --> "max" +// +// where param0 and param1 are effective scalars. Since all of the ops above +// are commutative, we also support them with param0 and param1 swapped. +// +// This is useful primarily for reduce and map nodes. These take a +// subcomputation which is almost always one of the four above, and pattern +// matching it to a short string lets us tell the user what the subcomputation +// is without drawing it as a graph. +optional<string> MatchTrivialComputation(const HloComputation* computation) { + if (computation->instruction_count() != 3) { + return nullopt; + } + + HloInstruction* root = computation->root_instruction(); + if (root->operand_count() != 2) { + return nullopt; + } + + // Check that both of the operands to the root are parameters. + const HloInstruction* operand0 = root->operand(0); + const HloInstruction* operand1 = root->operand(1); + if (operand0->opcode() != HloOpcode::kParameter || + operand1->opcode() != HloOpcode::kParameter) { + return nullopt; + } + // Check that the two operands of root are param0 and param1. All of the + // opcodes we recognize are commutative, so we're OK with either order. + auto n0 = operand0->parameter_number(); + auto n1 = operand1->parameter_number(); + if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { + return nullopt; + } + + // Check that the root and params are all effective scalars. + if (!ShapeUtil::IsEffectiveScalar(root->shape()) || + !ShapeUtil::IsEffectiveScalar(operand0->shape()) || + !ShapeUtil::IsEffectiveScalar(operand1->shape())) { + return nullopt; + } + + // If we recognize the root's opcode, we've successfully pattern-matched! + switch (root->opcode()) { + case HloOpcode::kAdd: + return "add"; + case HloOpcode::kMultiply: + return "multiply"; + case HloOpcode::kMinimum: + return "min"; + case HloOpcode::kMaximum: + return "max"; + default: + return nullopt; + } +} + // Returns the dot graph edges and nodes for the given instruction sequence. // Edges which extend between computations are added to the vector // intercomputation_edges. This is necessary because graphviz does not render @@ -165,58 +262,49 @@ string InstructionSequenceGraph( const NodeFilter& filter) { string graph_body; - // Create a single "record" node for the parameters. This node is a - // partitioned rectangle with one partition per parameter node. The keeps - // all the parameter instructions together. - std::vector<HloInstruction*> param_instructions; for (auto& instruction : instructions) { - if (instruction->opcode() == HloOpcode::kParameter) { - size_t param_number = instruction->parameter_number(); - - if (param_instructions.size() < param_number + 1) { - param_instructions.resize(param_number + 1, nullptr); - } - param_instructions[param_number] = instruction.get(); - } - } - string param_node_name; - if (!param_instructions.empty()) { - std::vector<string> param_ports; - param_node_name = - StrCat("parameters_", InstructionId(param_instructions[0])); - for (auto& param : param_instructions) { - if (!filter.Show(param)) { - continue; - } - string label = StrCat(param->parameter_name(), "\\n", - ShapeUtil::HumanString(param->shape())); - if (show_addresses) { - Appendf(&label, "\\n[%p]", param); - } - if (show_layouts) { - StrAppend(&label, "\\nlayout=\\{", - Join(param->shape().layout().minor_to_major(), ","), "\\}"); - } - param_ports.push_back( - Printf("<%s> %s", InstructionId(param).c_str(), label.c_str())); + if (!filter.Show(instruction.get())) { + continue; } - // (If we wanted the word "parameters" to be bold like the other op names, - // we'd have to make this into an HTML-like table. It is possible but - // complicated; see http://www.graphviz.org/doc/info/shapes.html#html.) - StrAppend(&graph_body, param_node_name, " [shape=record ", - NodeColorAttributes(kOrange), "label=\"{parameters | {", - Join(param_ports, "|"), "}}\"];\n"); - } - for (auto& instruction : instructions) { - if (!filter.Show(instruction.get())) { + // We don't display constants as separate nodes; they're merged into their + // users. + if (instruction->opcode() == HloOpcode::kConstant) { continue; } + ColorScheme color = kYellow; string shape = "box"; - string name = - StrCat("<b>", HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()), - "</b> ", HtmlLikeStringSanitize(instruction->name())); + + // Build the first line or two of the node, containing its name and opcode + // (if the opcode isn't redundant with the name). + string name; + if (instruction->opcode() == HloOpcode::kParameter) { + // If we have a parameter, put the param number in the name. + name = StrCat("<b>Parameter ", instruction->parameter_number(), + "</b><br/>", HtmlLikeStringSanitize(instruction->name())); + } else if (tensorflow::StringPiece(instruction->name()) + .starts_with( + StrCat("%", instruction->ExtendedOpcodeStr()))) { + // The HLO instruction name contains usually the opcode, e.g. "%add.42" is + // an add instruction. In this case we render just the name. + name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()), "</b>"); + } else if (instruction->opcode() == HloOpcode::kFusion && + tensorflow::StringPiece(instruction->name()) + .starts_with( + StrCat("%", HloOpcodeString(instruction->opcode())))) { + // Fusion nodes are usually named e.g. "%fusion.5". We render these as + // e.g. "%fusion.5<br/>input fusion". + name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()), + "</b><br/>", + HtmlLikeStringSanitize(instruction->ToCategory())); + } else { + // If the name does not contain the opcode, render both. + name = StrCat("<b>", + HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()), + "</b><br/>", HtmlLikeStringSanitize(instruction->name())); + } + if (HloOpcode::kConvolution == instruction->opcode()) { StrAppend( &name, "<br/>", @@ -305,18 +393,13 @@ string InstructionSequenceGraph( case HloOpcode::kUpdate: color = kGreen; break; - case HloOpcode::kConstant: - color = kBlue; - break; case HloOpcode::kConvolution: case HloOpcode::kDot: color = kDarkBlue; break; case HloOpcode::kParameter: - // A single record node is created for all the parameter nodes with a - // port for each parameter instruction. No need to emit anything in this - // case. - continue; + color = kOrange; + break; case HloOpcode::kBatchNormTraining: StrAppend(&name, " feature_index=", instruction->feature_index()); color = kPurple; @@ -361,6 +444,8 @@ string InstructionSequenceGraph( // will be inserted as modifications to an existing graph. color = kRed; break; + case HloOpcode::kConstant: + LOG(FATAL) << "Constants don't get their own nodes in the graph."; } // Create instruction node with appropriate label, shape, and color. @@ -369,14 +454,6 @@ string InstructionSequenceGraph( string label = StrCat(name, "<br/>", ShapeUtil::HumanString(instruction->shape())); - if (instruction->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(instruction->shape())) { - auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( - instruction->shape(), /*linear_index=*/0); - StrAppend(&label, " = {", instruction->literal().GetAsString(elem_idx), - "}"); - } - if (show_addresses) { Appendf(&label, "<br/>[%p]", instruction.get()); } @@ -405,75 +482,116 @@ string InstructionSequenceGraph( } } + // If this node's operands are omitted, style it accordingly. + if (filter.SomeOrAllOperandsOmitted(instruction.get())) { + color = kDashedBorder; + } + // If this node is highlighted, override its formatting. if (filter.Highlight(instruction.get())) { shape = "diamond"; color = kDarkRed; } - Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", - InstructionId(instruction.get()).c_str(), label.c_str(), - shape.c_str(), NodeColorAttributes(color).c_str()); - // Create edges from the instruction's operands to the instruction. - int64 operand_number = 0; - for (auto* operand : instruction->operands()) { - if (!filter.Show(operand)) { + if (!filter.OmitOperands(instruction.get())) { + int64 operand_number = 0; + for (auto* operand : instruction->operands()) { + if (!filter.Show(operand) || + operand->opcode() == HloOpcode::kConstant) { + ++operand_number; + continue; + } + Appendf(&graph_body, "%s -> %s", InstructionId(operand).c_str(), + InstructionId(instruction.get()).c_str()); + if (instruction->operand_count() > 1) { + Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]", + operand_number); + } + StrAppend(&graph_body, ";\n"); ++operand_number; - continue; } - string src; - if (operand->opcode() == HloOpcode::kParameter) { - // If operand is a parameter, then select the proper partition (port) in - // the unified parameter node. - src = param_node_name + ":" + InstructionId(operand); + + // Fusion nodes are handled specially because they contain nested + // expressions. + if (instruction->opcode() == HloOpcode::kFusion) { + string cluster_name = + StrCat("cluster_", InstructionId(instruction.get())); + StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); + StrAppend(&graph_body, "label=<fused expression for <b>", + HtmlLikeStringSanitize(instruction->name()), + "</b>>;\nstyle=\"rounded,filled\";\n" + "color=lightgrey;\n"); + StrAppend(&graph_body, + InstructionSequenceGraph(instruction->fused_instructions(), + show_addresses, show_layouts, + intercomputation_edges, + hlo_execution_profile, NodeFilter()), + "}\n"); + string fusion_edge = StrCat( + InstructionId(instruction->fused_expression_root()), " -> ", + InstructionId(instruction.get()), + " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name, + " ];\n"); + intercomputation_edges->push_back(fusion_edge); } else { - src = InstructionId(operand); - } - Appendf(&graph_body, "%s -> %s", src.c_str(), - InstructionId(instruction.get()).c_str()); - if (instruction->operand_count() > 1) { - Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]", - operand_number); + // If instruction has just one computation and it's trivial (e.g. + // "return param0 + param1"), put the trivial computation type (e.g. + // "add") into instruction's label. Otherwise, add a dotted edge + // between the instruction and its subcomputations. + const auto& subcomputations = instruction->called_computations(); + + bool trivial_subcomputation = false; + if (subcomputations.size() == 1) { + optional<string> computation_type = + MatchTrivialComputation(subcomputations.front()); + if (computation_type) { + trivial_subcomputation = true; + StrAppend(&label, "<br/>Subcomputation: <b>", *computation_type, + "</b>"); + } + } + + if (!trivial_subcomputation) { + for (const HloComputation* computation : + instruction->called_computations()) { + string cluster_name = + StrCat("cluster_", ComputationId(computation)); + string call_edge = Printf( + "%s -> %s [ style=dashed; ltail=%s ];\n", + InstructionId(computation->root_instruction()).c_str(), + InstructionId(instruction.get()).c_str(), cluster_name.c_str()); + intercomputation_edges->push_back(call_edge); + } + } } - StrAppend(&graph_body, ";\n"); - ++operand_number; } - // Fusion nodes are handled specially because they contain nested - // expressions. - if (instruction->opcode() == HloOpcode::kFusion) { - string cluster_name = - StrCat("cluster_", InstructionId(instruction.get())); - StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); - StrAppend(&graph_body, - "label=<<b>fused expression</b>>;\nstyle=\"rounded,filled\";\n" - "color=lightgrey;\n"); - StrAppend(&graph_body, - InstructionSequenceGraph(instruction->fused_instructions(), - show_addresses, show_layouts, - intercomputation_edges, - hlo_execution_profile, NodeFilter()), - "}\n"); - string fusion_edge = - StrCat(InstructionId(instruction->fused_expression_root()), " -> ", - InstructionId(instruction.get()), - " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name, - " ];\n"); - intercomputation_edges->push_back(fusion_edge); - } else { - // Add a dotted edge between the instruction and any computations that the - // instruction calls. - for (const HloComputation* computation : - instruction->called_computations()) { - string cluster_name = StrCat("cluster_", ComputationId(computation)); - string call_edge = Printf( - "%s -> %s [ style=dashed; ltail=%s ];\n", - InstructionId(computation->root_instruction()).c_str(), - InstructionId(instruction.get()).c_str(), cluster_name.c_str()); - intercomputation_edges->push_back(call_edge); + // Inline constant operands into the node. + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const HloInstruction* operand = instruction->operand(i); + if (operand->opcode() != HloOpcode::kConstant) { + continue; + } + + StrAppend(&label, "<br/><b>operand ", i, "</b> = "); + if (ShapeUtil::IsEffectiveScalar(operand->shape())) { + auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( + operand->shape(), /*linear_index=*/0); + StrAppend(&label, ShapeUtil::HumanString(operand->shape()), "{", + operand->literal().GetAsString(elem_idx), "}"); + } else { + if (tensorflow::StringPiece(operand->name()).starts_with("%constant")) { + StrAppend(&label, operand->name()); + } else { + StrAppend(&label, "constant ", operand->name()); + } } } + + Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", + InstructionId(instruction.get()).c_str(), label.c_str(), + shape.c_str(), NodeColorAttributes(color).c_str()); } return graph_body; } @@ -513,16 +631,25 @@ stylesheet="%s" )", graph_label.c_str(), dot_stylesheet); + // Dump the subcomputations of each instruction that's shown and doesn't have + // its operands omitted. If an instruction has just one subcomputation and + // it's trivial, omit it: We'll display that subcomputation inlined into the + // instruction's node when we draw it. std::unordered_set<const HloComputation*> computations_to_dump; for (const auto& instr : computation.instructions()) { - if (!filter.Show(instr.get())) { + if (!filter.Show(instr.get()) || filter.OmitOperands(instr.get())) { continue; } if (instr->opcode() == HloOpcode::kFusion) { computations_to_dump.insert(instr->fused_instructions_computation()); } - for (const HloComputation* computation : instr->called_computations()) { - computations_to_dump.insert(computation); + + const auto& subcomputations = instr->called_computations(); + if (subcomputations.size() != 1 || + !MatchTrivialComputation(subcomputations.front())) { + for (const HloComputation* computation : instr->called_computations()) { + computations_to_dump.insert(computation); + } } } @@ -620,31 +747,41 @@ class FileGraphRenderer : public GraphRendererInterface { } }; -// Gets roughly all instructions whose distance from root is <= radius. -std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood( - const HloInstruction& root, int64 radius) { - std::unordered_set<const HloInstruction*> ret; - +// Gets a NodeFilter that includes roughly all instructions whose distance from +// root is <= radius. +// +// It's confusing to draw a node and include only some of its operands. So if +// some but not all of a node's operands are <= radius units away from the root, +// we include the other operands (unless there are a lot of them, as often in a +// tuple node). These additional operands may have as inputs other nodes +// already present in the graph, but we don't draw those edges unless *all* of +// the inputs are present. (Otherwise we'd have the same problem we were trying +// to solve in the first place!) +NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { + // First, find the neighborhood of nodes with distance from root <= radius. + // These nodes are our initial set of "normal" nodes. + std::unordered_map<const HloInstruction*, NodeFilterResult> nodes; std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist; - worklist.push_back({&root, 0}); - + worklist.push_back({root, 0}); while (!worklist.empty()) { const HloInstruction* instr; int64 depth; std::tie(instr, depth) = worklist.front(); worklist.pop_front(); - ret.insert(instr); + nodes[instr] = kNormalNode; if (depth == radius) { continue; } + // Traverse into instr's operands. + // // Don't traverse into tuples' operands unless the tuple is the root. // Usually a tuple is the bottommost node in the graph, and so its operands // are not interesting to the graph at hand. - if (instr == &root || instr->opcode() != HloOpcode::kTuple) { + if (instr == root || instr->opcode() != HloOpcode::kTuple) { for (const HloInstruction* operand : instr->operands()) { - if (ret.find(operand) == ret.end()) { + if (!nodes.count(operand)) { worklist.push_back({operand, depth + 1}); } } @@ -655,14 +792,77 @@ std::unordered_set<const HloInstruction*> GetInstructionsInNeighborhood( // constants' users. if (instr->opcode() != HloOpcode::kConstant) { for (const HloInstruction* user : instr->users()) { - if (ret.find(user) == ret.end()) { + if (!nodes.count(user)) { worklist.push_back({user, depth + 1}); } } } } - return ret; + auto is_displayed = [&](const HloInstruction* instr) { + return nodes.count(instr) > 0; + }; + + // If a node has some but not all of its operands omitted, add the operands to + // the map with type kOmitNodeOperands. Unless the node has a lot of + // operands, in which case just mark the node as "some operands omitted". + std::vector<const HloInstruction*> extra_operands; + for (auto& kv : nodes) { + const HloInstruction* instr = kv.first; + NodeFilterResult& filter_result = kv.second; + const auto& operands = instr->operands(); + + // Mark nodes with many operands and some omitted as "some operands omitted" + // and carry on -- don't add their omitted operands to extra_operands. + if (operands.size() > 4) { + if (std::any_of(operands.begin(), operands.end(), is_displayed) && + !std::all_of(operands.begin(), operands.end(), is_displayed)) { + filter_result = kSomeOperandsOmitted; + } + continue; + } + + if (std::any_of(operands.begin(), operands.end(), is_displayed)) { + for (const HloInstruction* operand : operands) { + if (!is_displayed(operand)) { + extra_operands.push_back(operand); + } + } + } + } + for (const HloInstruction* instr : extra_operands) { + nodes[instr] = kOmitNodeOperands; + } + + // Some of the nodes in extra_operands may now have all of their inputs + // present in nodes. We can promote these to normal nodes. + for (const HloInstruction* instr : extra_operands) { + const auto& operands = instr->operands(); + if (std::all_of(operands.begin(), operands.end(), is_displayed)) { + nodes[instr] = kNormalNode; + } + } + + // If none of a node's operands appear in nodes, mark it as type + // kOmitNodeOperands so it gets styled appropriately. + for (auto& kv : nodes) { + const auto& operands = kv.first->operands(); + if (!operands.empty() && + std::none_of(operands.begin(), operands.end(), is_displayed)) { + kv.second = kOmitNodeOperands; + } + } + + // Highlight the root node. + nodes[root] = kHighlightNode; + + return NodeFilter([=](const HloInstruction* instr) { + auto it = nodes.find(instr); + if (it != nodes.end()) { + return it->second; + } + return kHideNode; + }); } XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); @@ -699,22 +899,9 @@ string DumpGraph(const HloComputation& computation, const string& label, string DumpNeighborhoodAround(const HloInstruction& node, int radius) { auto debug_options = node.GetModule()->config().debug_options(); - - std::unordered_set<const HloInstruction*> neighborhood = - GetInstructionsInNeighborhood(node, radius); - - NodeFilter filter([&](const HloInstruction* instr) { - if (instr == &node) { - return kHighlightNode; - } - if (neighborhood.find(instr) != neighborhood.end()) { - return kNormalNode; - } - return kHideNode; - }); - string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); + NodeFilter filter = MakeNodeFilter(&node, radius); string graph = ComputationToDotGraph( *node.parent(), label, /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), |