diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_graph_dumper.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 1117 |
1 files changed, 626 insertions, 491 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index acd26c4e31..c6202548f1 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -48,37 +48,37 @@ using ::tensorflow::Env; using ::tensorflow::gtl::nullopt; using ::tensorflow::gtl::optional; using ::tensorflow::io::JoinPath; -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::StringReplace; using ::tensorflow::WriteStringToFile; namespace xla { namespace hlo_graph_dumper { namespace { -// Node color schemes, used by NodeColorAttributes. -enum ColorScheme { - kBlue, - kBrown, - kDarkBlue, - kDarkGreen, - kDarkRed, - kGray, - kGreen, - kOrange, - kPurple, - kRed, - kWhite, - kYellow, - - // Causes the node's border to be a dashed line, and its content to be gray - // text on a white background, suggesting that this is an "unimportant" node. - kDashedBorder, +// Helpers for Printf and Appendf. +template <typename T> +struct PrintfConvert { + const T& operator()(const T& t) const { return t; } +}; +template <> +struct PrintfConvert<string> { + const char* operator()(const string& s) const { return s.c_str(); } }; +// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() +// on strings. +template <typename... Ts> +string Printf(const char* fmt, const Ts&... ts) { + return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...); +} +template <typename... Ts> +void Appendf(string* s, const char* fmt, const Ts&... ts) { + tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...); +} + // Used to indicate how we should treat a given HLOInstruction in the graph. // should we treat it like normal, hide it, and so on? enum NodeFilterResult { @@ -92,6 +92,9 @@ enum NodeFilterResult { // Style the node the same as kSomeOperandsOmitted, but also don't connect it // to its operands, even if they're present in the graph. kOmitNodeOperands, + // Same style as kSomeOperandsOmitted, but used to indicate that some of the + // node's *users* have been omitted. + kSomeUsersOmitted, }; // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult. @@ -118,11 +121,41 @@ class NodeFilter { auto result = filter_(instr); return result == kOmitNodeOperands || result == kSomeOperandsOmitted; } + bool Deemphasized(const HloInstruction* instr) const { + auto result = filter_(instr); + return result == kOmitNodeOperands || result == kSomeOperandsOmitted || + result == kSomeUsersOmitted; + } + + bool ShowFusionSubcomputation(const HloInstruction* instr) const { + CHECK_EQ(instr->opcode(), HloOpcode::kFusion); + return Show(instr) && !SomeOrAllOperandsOmitted(instr); + } private: std::function<NodeFilterResult(const HloInstruction* instr)> filter_; }; +// Node color schemes, used by NodeColorAttributes. +enum ColorScheme { + kBlue, + kBrown, + kDarkBlue, + kDarkGreen, + kDarkRed, + kGray, + kGreen, + kOrange, + kPurple, + kRed, + kWhite, + kYellow, + + // Causes the node's border to be a dashed line, and its content to be gray + // text on a white background, suggesting that this is an "unimportant" node. + kDashedBorder, +}; + // Given a ColorScheme, returns an attribute string for a node of that color. // Sets the node's style and fill/stroke/text colors. // @@ -170,19 +203,8 @@ string NodeColorAttributes(ColorScheme color) { // Replaces <> with <>, so that this string is safe(er) for use in a // graphviz HTML-like string. string HtmlLikeStringSanitize(tensorflow::StringPiece s) { - return tensorflow::str_util::StringReplace( - tensorflow::str_util::StringReplace(s, "<", "<", /*replace_all=*/true), - ">", ">", /*replace_all=*/true); -} - -// Returns the dot graph identifier for the given instruction. -string InstructionId(const HloInstruction* instruction) { - return Printf("%lld", reinterpret_cast<uint64>(instruction)); -} - -// Returns the dot graph identifier for the given computation. -string ComputationId(const HloComputation* computation) { - return Printf("%lld", reinterpret_cast<uint64>(computation)); + return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", + ">", /*replace_all=*/true); } // Tries to generates a human-readable one-word description of the given @@ -194,9 +216,15 @@ string ComputationId(const HloComputation* computation) { // "return param0 * param1;" --> "multiply" // "return min(param0, param1);" --> "min" // "return max(param0, param1);" --> "max" +// "return param0 <= param1;" --> "less-or-equal" +// "return param0 >= param1;" --> "greater-or-equal" +// "return param0 > param1;" --> "greater-than" +// "return param0 < param1;" --> "less-than" +// "return param0 == param1;" --> "equal-to" +// "return param0 != param1;" --> "not-equal-to" // -// where param0 and param1 are effective scalars. Since all of the ops above -// are commutative, we also support them with param0 and param1 swapped. +// where param0 and param1 are effective scalars. For the ops that are +// commutative, we also support them with param0 and param1 swapped. // // This is useful primarily for reduce and map nodes. These take a // subcomputation which is almost always one of the four above, and pattern @@ -219,6 +247,7 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) { operand1->opcode() != HloOpcode::kParameter) { return nullopt; } + // Check that the two operands of root are param0 and param1. All of the // opcodes we recognize are commutative, so we're OK with either order. auto n0 = operand0->parameter_number(); @@ -227,6 +256,20 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) { return nullopt; } + // If the params are reversed, check that the operation being performed is + // commutative. + if (n0 == 1) { + switch (root->opcode()) { + case HloOpcode::kLe: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLt: + return nullopt; + default: + break; + } + } + // Check that the root and params are all effective scalars. if (!ShapeUtil::IsEffectiveScalar(root->shape()) || !ShapeUtil::IsEffectiveScalar(operand0->shape()) || @@ -244,444 +287,542 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) { return "min"; case HloOpcode::kMaximum: return "max"; + case HloOpcode::kLe: + return "less-or-equal"; + case HloOpcode::kGe: + return "greater-or-equal"; + case HloOpcode::kGt: + return "greater-than"; + case HloOpcode::kLt: + return "less-than"; + case HloOpcode::kEq: + return "equal-to"; + case HloOpcode::kNe: + return "not-equal-to"; default: return nullopt; } } -// Returns the dot graph edges and nodes for the given instruction sequence. -// Edges which extend between computations are added to the vector -// intercomputation_edges. This is necessary because graphviz does not render -// the graph properly unless these inter-computation edges appear after all -// subgraph statements. -string InstructionSequenceGraph( - const std::list<std::unique_ptr<HloInstruction>>& instructions, - bool show_addresses, bool show_layouts, - std::vector<string>* intercomputation_edges, - const HloExecutionProfile* hlo_execution_profile, - const NodeFilter& filter) { - string graph_body; - - for (auto& instruction : instructions) { - if (!filter.Show(instruction.get())) { - continue; - } +class HloDotDumper { + public: + HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, + bool show_addresses, bool show_layouts, + const HloExecutionProfile* profile, NodeFilter filter) + : computation_(computation), + label_(label.ToString()), + show_addresses_(show_addresses), + show_layouts_(show_layouts), + profile_(profile), + filter_(std::move(filter)) {} + + string Dump(); - // We don't display constants as separate nodes; they're merged into their - // users. - if (instruction->opcode() == HloOpcode::kConstant) { + private: + // Returns the dot graph identifier for the given instruction. + string InstructionId(const HloInstruction* instruction) { + return StrCat(reinterpret_cast<uint64>(instruction)); + } + + // Returns the dot graph identifier for the given computation. + string SubcomputationId(const HloComputation* computation) { + return StrCat("cluster_", reinterpret_cast<uint64>(computation)); + } + + string Header(); + string Footer(); + + // Maps HloComputations we should dump to their parent instruction in the + // outer computation. + std::unordered_map<const HloComputation*, const HloInstruction*> + SubcomputationsToDump(); + + string DumpSubcomputation(const HloComputation* subcomp, + const HloInstruction* parent_instr); + string DumpComputation(const HloComputation* comp); + string DumpInstruction(const HloInstruction* instr); + ColorScheme GetInstructionColor(const HloInstruction* instr); + string GetInstructionNodeShape(const HloInstruction* instr); + string GetInstructionNodeLabel(const HloInstruction* instr); + string GetInstructionNodeExtraInfo(const HloInstruction* instr); + string GetInstructionNodeInlinedConstants(const HloInstruction* instr); + void AddInstructionIncomingEdges(const HloInstruction* instr); + + // If instr has just one computation and it's trivial (e.g. "return param0 + + // param1"), returns a string you can put into the node's body that names the + // subcomputation, e.g. "Subcomputation: <b>add</b>". + string GetInstructionTrivialComputationStr(const HloInstruction* instr); + + const HloComputation* computation_; // never null + const string label_; // overall name for the graph + const bool show_addresses_; + const bool show_layouts_; + const HloExecutionProfile* profile_; // may be null + const NodeFilter filter_; + + // Edges to print from Footer(). Edges come at the end because graphviz is + // unhappy if an edge from a subcomputation to a node in the outer computation + // appears before both the inner computation and the destination node are + // defined. + std::vector<string> edges_; +}; + +string HloDotDumper::Dump() { + string g = Header(); + for (const auto& kv : SubcomputationsToDump()) { + const HloComputation* subcomp = kv.first; + const HloInstruction* parent = kv.second; + StrAppend(&g, DumpSubcomputation(subcomp, parent)); + } + StrAppend(&g, DumpComputation(computation_)); + StrAppend(&g, Footer()); + return g; +} + +string HloDotDumper::Header() { + // DOT graphs accept a stylesheet as a URI. So naturally, an inline + // stylesheet is a data URI! + const char* fmt = R"(digraph G { +rankdir = TB; +compound = true; +label = <<b>%s</b>>; +labelloc = t; +stylesheet=" + data:text/css, + @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); + svg text { + font-family: 'Roboto'; + font-size: 12px; + } +" + +)"; + + string graph_label = StrCat(label_, "<br/>", computation_->name()); + if (profile_ != nullptr) { + auto cycles = profile_->total_cycles_executed(*computation_); + Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles)); + } + return Printf(fmt, graph_label); +} + +string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } + +std::unordered_map<const HloComputation*, const HloInstruction*> +HloDotDumper::SubcomputationsToDump() { + // Dump the subcomputations of each instruction that's shown and doesn't have + // its operands omitted. If an instruction has just one subcomputation and + // it's trivial, omit it: We'll display that subcomputation inlined into the + // instruction's node when we draw it. + std::unordered_map<const HloComputation*, const HloInstruction*> to_dump; + for (const auto& instr : computation_->instructions()) { + if (!filter_.Show(instr.get()) || + filter_.SomeOrAllOperandsOmitted(instr.get())) { continue; } + if (instr->opcode() == HloOpcode::kFusion) { + to_dump[instr->fused_instructions_computation()] = instr.get(); + } - ColorScheme color = kYellow; - string shape = "box"; - - // Build the first line or two of the node, containing its name and opcode - // (if the opcode isn't redundant with the name). - string name; - if (instruction->opcode() == HloOpcode::kParameter) { - // If we have a parameter, put the param number in the name. - name = StrCat("<b>Parameter ", instruction->parameter_number(), - "</b><br/>", HtmlLikeStringSanitize(instruction->name())); - } else if (tensorflow::StringPiece(instruction->name()) - .starts_with( - StrCat("%", instruction->ExtendedOpcodeStr()))) { - // The HLO instruction name contains usually the opcode, e.g. "%add.42" is - // an add instruction. In this case we render just the name. - name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()), "</b>"); - } else if (instruction->opcode() == HloOpcode::kFusion && - tensorflow::StringPiece(instruction->name()) - .starts_with( - StrCat("%", HloOpcodeString(instruction->opcode())))) { - // Fusion nodes are usually named e.g. "%fusion.5". We render these as - // e.g. "%fusion.5<br/>input fusion". - name = StrCat("<b>", HtmlLikeStringSanitize(instruction->name()), - "</b><br/>", - HtmlLikeStringSanitize(instruction->ToCategory())); - } else { - // If the name does not contain the opcode, render both. - name = StrCat("<b>", - HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()), - "</b><br/>", HtmlLikeStringSanitize(instruction->name())); - } - - if (HloOpcode::kConvolution == instruction->opcode()) { - StrAppend( - &name, "<br/>", - HtmlLikeStringSanitize( - instruction->ConvolutionDimensionNumbersToString()), - "<br/>", - HtmlLikeStringSanitize(window_util::ToString(instruction->window()))); - } - - if (!instruction->metadata().op_name().empty()) { - StrAppend(&name, "<br/>", - HtmlLikeStringSanitize(instruction->metadata().op_name())); - } - if (!instruction->metadata().source_file().empty() && - instruction->metadata().source_line() != 0) { - StrAppend(&name, "<br/>", instruction->metadata().source_file(), ":", - instruction->metadata().source_line()); - } - - // Pick different colors or shapes for instructions which are particularly - // expensive (eg, dot) and those which are unusual in some way or unique - // (eg, parameter). - switch (instruction->opcode()) { - // "Normal" instructions. Mostly cheap and elementwise. No call to - // embedded computations. In this case, use default color, shape and - // label. - case HloOpcode::kAbs: - case HloOpcode::kAdd: - case HloOpcode::kCeil: - case HloOpcode::kClamp: - case HloOpcode::kConvert: - case HloOpcode::kCos: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kExp: - case HloOpcode::kFloor: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kIndex: - case HloOpcode::kIsFinite: - case HloOpcode::kLe: - case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kNegate: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSelect: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kSlice: - case HloOpcode::kSort: - case HloOpcode::kSubtract: - case HloOpcode::kTanh: - break; - case HloOpcode::kRng: - StrAppend(&name, "<br/>", - RandomDistribution_Name(instruction->random_distribution())); - break; - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: - StrAppend(&name, "<br/>", "dims={", - Join(instruction->dimensions(), ","), "}"); - break; - case HloOpcode::kBitcast: - case HloOpcode::kTuple: - case HloOpcode::kTrace: - color = kWhite; - break; - case HloOpcode::kGetTupleElement: - color = kWhite; - StrAppend(&name, "<br/>index=", instruction->tuple_index()); - break; - case HloOpcode::kConcatenate: - case HloOpcode::kCopy: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kPad: - case HloOpcode::kReshape: - case HloOpcode::kReverse: - case HloOpcode::kUpdate: - color = kGreen; - break; - case HloOpcode::kConvolution: - case HloOpcode::kDot: - color = kDarkBlue; - break; - case HloOpcode::kParameter: - color = kOrange; - break; - case HloOpcode::kBatchNormTraining: - StrAppend(&name, " feature_index=", instruction->feature_index()); - color = kPurple; - break; - case HloOpcode::kBatchNormGrad: - StrAppend(&name, " feature_index=", instruction->feature_index()); - color = kPurple; - break; - case HloOpcode::kReduce: - StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); - color = kPurple; - break; - case HloOpcode::kSelectAndScatter: - case HloOpcode::kReduceWindow: - color = kPurple; - break; - case HloOpcode::kWhile: - shape = "ellipse"; - color = kDarkGreen; - break; - case HloOpcode::kMap: - case HloOpcode::kFusion: - color = kGray; - break; - case HloOpcode::kSend: - case HloOpcode::kRecv: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: - color = kBrown; - break; - case HloOpcode::kCall: - color = kDarkGreen; - break; - case HloOpcode::kCustomCall: - color = kDarkGreen; - StrAppend(&name, "<br/>", - "custom_call_target=", instruction->custom_call_target()); - break; - case HloOpcode::kReducePrecision: - // Make ReducePrecision ops a bit more visible, since typically they - // will be inserted as modifications to an existing graph. - color = kRed; - break; - case HloOpcode::kConstant: - LOG(FATAL) << "Constants don't get their own nodes in the graph."; - } - - // Create instruction node with appropriate label, shape, and color. - // label is interpreted as an HTML-like string, so newlines must be - // delimited with <br/>, rather than \n. - string label = - StrCat(name, "<br/>", ShapeUtil::HumanString(instruction->shape())); - - if (show_addresses) { - Appendf(&label, "<br/>[%p]", instruction.get()); - } - if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) { - string layout_string; - if (ShapeUtil::IsTuple(instruction->shape())) { - // For tuples, emit the full shape because the layout of a tuple is not - // represented in a single Layout field. - layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); - } else { - layout_string = - Join(instruction->shape().layout().minor_to_major(), ","); - } - StrAppend(&label, "<br/>layout={", layout_string, "}"); - } - if (hlo_execution_profile != nullptr) { - auto hlo_cycles_executed = - hlo_execution_profile->GetProfileResult(*instruction); - auto total_cycles_executed = - hlo_execution_profile->total_cycles_executed(*instruction->parent()); - if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { - Appendf(&label, "<br/>%% of cycles executed=%.2f", - (static_cast<double>(hlo_cycles_executed) / - static_cast<double>(total_cycles_executed)) * - 100); + for (const HloComputation* comp : instr->called_computations()) { + if (!MatchTrivialComputation(comp)) { + to_dump[comp] = instr.get(); } } + } + return to_dump; +} - // If this node's operands are omitted, style it accordingly. - if (filter.SomeOrAllOperandsOmitted(instruction.get())) { - color = kDashedBorder; - } +string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, + const HloInstruction* parent_instr) { + const char* computation_fmt = R"(subgraph %s { +%s; +label = <%s>; +labelloc = t; +%s +} // %s - // If this node is highlighted, override its formatting. - if (filter.Highlight(instruction.get())) { - shape = "diamond"; - color = kDarkRed; +)"; + + string id = SubcomputationId(subcomp); + + string subcomp_label, style; + if (parent_instr->opcode() == HloOpcode::kFusion) { + subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(parent_instr->ToCategory())); + + // Subcomputation's fill/stroke color is light/dark red/gray, depending on + // whether or not the subcomputation's fusion node is highlighted. + bool highlight = filter_.Highlight(parent_instr); + const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; + const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + style = Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s")", + fillcolor, strokecolor); + } else { + subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(subcomp->name())); + style = "style=rounded; color=black;"; + } + + string comp_body = DumpComputation(subcomp); + string computation = + Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, so + // there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + const char* edge_fmt = R"(%s -> %s [ltail="%s", style="dashed"];)"; + edges_.push_back( + Printf(edge_fmt, InstructionId(subcomp->root_instruction()), + InstructionId(parent_instr), SubcomputationId(subcomp))); + } + + return computation; +} + +string HloDotDumper::DumpComputation(const HloComputation* comp) { + string g; + for (const auto& instr : comp->instructions()) { + if (!filter_.Show(instr.get())) { + continue; } + StrAppend(&g, DumpInstruction(instr.get())); + } + return g; +} - // Create edges from the instruction's operands to the instruction. - if (!filter.OmitOperands(instruction.get())) { - int64 operand_number = 0; - for (auto* operand : instruction->operands()) { - if (!filter.Show(operand) || - operand->opcode() == HloOpcode::kConstant) { - ++operand_number; - continue; - } - Appendf(&graph_body, "%s -> %s", InstructionId(operand).c_str(), - InstructionId(instruction.get()).c_str()); - if (instruction->operand_count() > 1) { - Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]", - operand_number); - } - StrAppend(&graph_body, ";\n"); - ++operand_number; - } +string HloDotDumper::DumpInstruction(const HloInstruction* instr) { + // We don't display constants as separate nodes; they're merged into their + // users. + if (instr->opcode() == HloOpcode::kConstant) { + return ""; + } + // Omit the fusion node if its subcomputation is drawn, since the + // subcomputation will be drawn inline. + if (instr->opcode() == HloOpcode::kFusion && + filter_.ShowFusionSubcomputation(instr)) { + return ""; + } - // Fusion nodes are handled specially because they contain nested - // expressions. - if (instruction->opcode() == HloOpcode::kFusion) { - string cluster_name = - StrCat("cluster_", InstructionId(instruction.get())); - StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); - StrAppend(&graph_body, "label=<fused expression for <b>", - HtmlLikeStringSanitize(instruction->name()), - "</b>>;\nstyle=\"rounded,filled\";\n" - "color=lightgrey;\n"); - StrAppend(&graph_body, - InstructionSequenceGraph(instruction->fused_instructions(), - show_addresses, show_layouts, - intercomputation_edges, - hlo_execution_profile, NodeFilter()), - "}\n"); - string fusion_edge = StrCat( - InstructionId(instruction->fused_expression_root()), " -> ", - InstructionId(instruction.get()), - " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name, - " ];\n"); - intercomputation_edges->push_back(fusion_edge); - } else { - // If instruction has just one computation and it's trivial (e.g. - // "return param0 + param1"), put the trivial computation type (e.g. - // "add") into instruction's label. Otherwise, add a dotted edge - // between the instruction and its subcomputations. - const auto& subcomputations = instruction->called_computations(); - - bool trivial_subcomputation = false; - if (subcomputations.size() == 1) { - optional<string> computation_type = - MatchTrivialComputation(subcomputations.front()); - if (computation_type) { - trivial_subcomputation = true; - StrAppend(&label, "<br/>Subcomputation: <b>", *computation_type, - "</b>"); - } - } + ColorScheme color = GetInstructionColor(instr); + string node_shape = GetInstructionNodeShape(instr); + string node_label = GetInstructionNodeLabel(instr); + string extra_info = GetInstructionNodeExtraInfo(instr); + string inlined_constants = GetInstructionNodeInlinedConstants(instr); + string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); + AddInstructionIncomingEdges(instr); + + // Override the node's styling if it should be (de-)emphasized. + if (filter_.Deemphasized(instr)) { + color = kDashedBorder; + } + if (filter_.Highlight(instr)) { + node_shape = "diamond"; + color = kDarkRed; + } - if (!trivial_subcomputation) { - for (const HloComputation* computation : - instruction->called_computations()) { - string cluster_name = - StrCat("cluster_", ComputationId(computation)); - string call_edge = Printf( - "%s -> %s [ style=dashed; ltail=%s ];\n", - InstructionId(computation->root_instruction()).c_str(), - InstructionId(instruction.get()).c_str(), cluster_name.c_str()); - intercomputation_edges->push_back(call_edge); - } - } - } + // Build the text that will be displayed inside the node. + string node_body = node_label; + for (const string& s : + {trivial_subcomputation, extra_info, inlined_constants}) { + if (!s.empty()) { + StrAppend(&node_body, "<br/>", s); } + } - // Inline constant operands into the node. - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const HloInstruction* operand = instruction->operand(i); - if (operand->opcode() != HloOpcode::kConstant) { - continue; - } + return Printf("%s [label=<%s>, shape=%s, %s];\n", InstructionId(instr), + node_body, node_shape, NodeColorAttributes(color)); +} - StrAppend(&label, "<br/><b>operand ", i, "</b> = "); - if (ShapeUtil::IsEffectiveScalar(operand->shape())) { - auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( - operand->shape(), /*linear_index=*/0); - StrAppend(&label, ShapeUtil::HumanString(operand->shape()), "{", - operand->literal().GetAsString(elem_idx), "}"); - } else { - if (tensorflow::StringPiece(operand->name()).starts_with("%constant")) { - StrAppend(&label, operand->name()); - } else { - StrAppend(&label, "constant ", operand->name()); - } - } +string HloDotDumper::GetInstructionNodeInlinedConstants( + const HloInstruction* instr) { + auto stringify_constant = [](const HloInstruction* constant) { + if (ShapeUtil::IsEffectiveScalar(constant->shape())) { + auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( + constant->shape(), /*linear_index=*/0); + return Printf("%s{%s}", ShapeUtil::HumanString(constant->shape()), + constant->literal().GetAsString(elem_idx)); + } + if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { + return constant->name(); } + return StrCat("constant ", constant->name()); + }; - Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", - InstructionId(instruction.get()).c_str(), label.c_str(), - shape.c_str(), NodeColorAttributes(color).c_str()); + // Special case: If instr is a parameter to a fusion node, check whether the + // corresponding operand to the fusion node is a constant. + if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { + const HloInstruction* fusion = instr->fusion_instruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + if (operand->opcode() != HloOpcode::kConstant) { + return ""; + } + return stringify_constant(operand); } - return graph_body; + + std::vector<string> lines; + for (int64 i = 0; i < instr->operand_count(); ++i) { + const HloInstruction* operand = instr->operand(i); + if (operand->opcode() != HloOpcode::kConstant) { + continue; + } + lines.push_back( + Printf("<b>operand %lld</b> = %s", i, stringify_constant(operand))); + } + return Join(lines, "<br/>"); } -// DOT graphs accept a stylesheet as a URL. So naturally, an inline stylesheet -// is a data URI! -// -// We don't perform any escaping on this string, so be careful not to use double -// quotes inside. -static const char* dot_stylesheet = R"( -data:text/css, -@import url(https://fonts.googleapis.com/css?family=Roboto:400,700); -svg text { - font-family: 'Roboto'; - font-size: 12px; +ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { + // Pick different colors or shapes for instructions which are particularly + // expensive (eg, dot) and those which are unusual in some way or unique + // (eg, parameter). + switch (instr->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kConvert: + case HloOpcode::kCos: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kIndex: + case HloOpcode::kIsFinite: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLogicalAnd: + case HloOpcode::kLogicalNot: + case HloOpcode::kLogicalOr: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSelect: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kRng: + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + return kYellow; + case HloOpcode::kBitcast: + case HloOpcode::kTuple: + case HloOpcode::kTrace: + case HloOpcode::kGetTupleElement: + return kWhite; + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kUpdate: + return kGreen; + case HloOpcode::kConvolution: + case HloOpcode::kDot: + return kDarkBlue; + case HloOpcode::kReducePrecision: + return kRed; + case HloOpcode::kParameter: + return kOrange; + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kReduce: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReduceWindow: + return kPurple; + case HloOpcode::kMap: + case HloOpcode::kFusion: + return kGray; + case HloOpcode::kSend: + case HloOpcode::kRecv: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kCrossReplicaSum: + return kBrown; + case HloOpcode::kCustomCall: + case HloOpcode::kWhile: + case HloOpcode::kCall: + return kDarkGreen; + case HloOpcode::kConstant: + LOG(FATAL) << "Constants don't get their own nodes in the graph."; + } } -)"; -string ComputationToDotGraph(const HloComputation& computation, - const string& label, bool show_addresses, - bool show_layouts, - const HloExecutionProfile* hlo_execution_profile, - const NodeFilter& filter) { - string graph_label = StrCat(label, "<br/>", computation.name()); - if (hlo_execution_profile != nullptr) { - auto cycles = hlo_execution_profile->total_cycles_executed(computation); - Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles, - tensorflow::strings::HumanReadableNum(cycles).c_str()); - } - string graph = Printf( - R"(digraph G { -rankdir=TB; -compound=true; -label=<<b>%s</b>>; -labelloc=t; -stylesheet="%s" -)", - graph_label.c_str(), dot_stylesheet); +string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { + // Give while loops a different shape so they're easier to pick out. + switch (instr->opcode()) { + case HloOpcode::kWhile: + return "ellipse"; + default: + return "rect"; + } +} - // Dump the subcomputations of each instruction that's shown and doesn't have - // its operands omitted. If an instruction has just one subcomputation and - // it's trivial, omit it: We'll display that subcomputation inlined into the - // instruction's node when we draw it. - std::unordered_set<const HloComputation*> computations_to_dump; - for (const auto& instr : computation.instructions()) { - if (!filter.Show(instr.get()) || filter.OmitOperands(instr.get())) { - continue; +string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { + // If we have a parameter, put the param number in the name. + if (instr->opcode() == HloOpcode::kParameter) { + return Printf("<b>Parameter %lld</b>", instr->parameter_number()); + } + + // The HLO instruction name contains usually the opcode, e.g. "%add.42" is + // an add instruction. In this case we render just the name. + if (tensorflow::StringPiece(instr->name()) + .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { + return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name())); + } + + // If the name does not contain the opcode, render both. + return Printf("<b>%s</b><br/>%s", + HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()), + HtmlLikeStringSanitize(instr->name())); +} + +string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { + string opcode_specific_info = [&]() -> string { + switch (instr->opcode()) { + case HloOpcode::kRng: + return RandomDistribution_Name(instr->random_distribution()); + case HloOpcode::kConvolution: + return StrCat( + HtmlLikeStringSanitize( + instr->ConvolutionDimensionNumbersToString()), + "<br/>", + HtmlLikeStringSanitize(window_util::ToString(instr->window()))); + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + case HloOpcode::kReduce: + return Printf("dims={%s}", Join(instr->dimensions(), ",")); + case HloOpcode::kGetTupleElement: + return Printf("index=%lld", instr->tuple_index()); + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormGrad: + return Printf("feature_index=%lld", instr->feature_index()); + case HloOpcode::kCustomCall: + return Printf("custom_call_target=%s", instr->custom_call_target()); + default: + return ""; } - if (instr->opcode() == HloOpcode::kFusion) { - computations_to_dump.insert(instr->fused_instructions_computation()); + }(); + + std::vector<string> lines; + if (!opcode_specific_info.empty()) { + lines.push_back(opcode_specific_info); + } + + // Some instructions have giant tuples as their shapes, so truncate the HLO's + // shape to kMaxShapeLen characters. + constexpr int kMaxShapeLen = 64; + string instr_shape = ShapeUtil::HumanString(instr->shape()); + if (instr_shape.length() > kMaxShapeLen) { + instr_shape = + StrCat(tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), + "..."); + } + lines.push_back(instr_shape); + + if (show_addresses_) { + lines.push_back(Printf("[%p]", instr)); + } + if (show_layouts_ && LayoutUtil::HasLayout(instr->shape())) { + string layout_str; + if (ShapeUtil::IsTuple(instr->shape())) { + // For tuples, emit the full shape because the layout of a tuple is not + // represented in a single Layout field. + layout_str = ShapeUtil::HumanStringWithLayout(instr->shape()); + } else { + layout_str = Join(instr->shape().layout().minor_to_major(), ","); + } + lines.push_back(Printf("layout={%s}", layout_str)); + } + if (profile_ != nullptr) { + double hlo_cycles_executed = profile_->GetProfileResult(*instr); + double total_cycles_executed = + profile_->total_cycles_executed(*instr->parent()); + if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { + lines.push_back( + Printf("%% of cycles executed=%.2f", + 100 * hlo_cycles_executed / total_cycles_executed)); } + } + return Join(lines, "<br/>"); +} - const auto& subcomputations = instr->called_computations(); - if (subcomputations.size() != 1 || - !MatchTrivialComputation(subcomputations.front())) { - for (const HloComputation* computation : instr->called_computations()) { - computations_to_dump.insert(computation); - } +void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { + auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, + int64 operand_num) { + // Fusion nodes' subcomputations are displayed inline, so if 'from' is a + // fusion node and the node's subcomputation is shown, we draw our edge + // starting at the fusion node's root instead of at the fusion node itself. + if (from->opcode() == HloOpcode::kFusion && + filter_.ShowFusionSubcomputation(from)) { + from = from->fused_expression_root(); + } + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { + return; } + string edge = Printf("%s -> %s", InstructionId(from), InstructionId(to)); + if (instr->operand_count() > 1) { + Appendf(&edge, R"( [headlabel="%lld",labeldistance=2])", operand_num); + } + StrAppend(&edge, ";"); + edges_.push_back(edge); + }; + + // Add edges from instr's operands to instr. Parameters within fusion + // expressions are handled specially -- we draw an edge from the corresponding + // operand on the fusion node itself to the parameter. + if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { + const HloInstruction* fusion = instr->fusion_instruction(); + add_edge(fusion->operand(instr->parameter_number()), instr, + /*operand_num=*/0); + } else { + for (int64 i = 0; i < instr->operand_count(); ++i) { + add_edge(instr->operand(i), instr, i); + } + } +} + +string HloDotDumper::GetInstructionTrivialComputationStr( + const HloInstruction* instr) { + // called_computations() on a fusion node "inherits" any called computations + // of the fused root, which isn't what we want. Just ignore fusion nodes + // here; they're handled separately. + if (instr->opcode() == HloOpcode::kFusion) { + return ""; } - // Emit embedded computations as subgraph clusters. - std::vector<string> intercomputation_edges; - for (const HloComputation* embedded : - computation.MakeEmbeddedComputationsList()) { - if (!computations_to_dump.count(embedded)) { + std::vector<string> lines; + for (int64 i = 0; i < instr->called_computations().size(); ++i) { + optional<string> computation_type = + MatchTrivialComputation(instr->called_computations()[i]); + if (!computation_type) { continue; } - // Don't pass our filter down into the subcomputation -- always render the - // whole thing. - string graph_body = InstructionSequenceGraph( - embedded->instructions(), show_addresses, show_layouts, - &intercomputation_edges, hlo_execution_profile, NodeFilter()); - Appendf(&graph, - "subgraph cluster_%s " - "{\nstyle=rounded;label=<<b>%s</b>>;labelloc=t;\n%s}\n", - ComputationId(embedded).c_str(), embedded->name().c_str(), - graph_body.c_str()); - } - StrAppend(&graph, - InstructionSequenceGraph(computation.instructions(), show_addresses, - show_layouts, &intercomputation_edges, - hlo_execution_profile, filter)); - - // Edges between computations (subgraph clusters) must be emitted last for the - // graph to be rendered properly for some reason. - StrAppend(&graph, Join(intercomputation_edges, "\n"), "}\n"); - - return graph; + if (instr->called_computations().size() == 1) { + lines.push_back(Printf("Subcomputation: <b>%s</b>", + HtmlLikeStringSanitize(*computation_type))); + } else { + lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i, + HtmlLikeStringSanitize(*computation_type))); + } + } + return Join(lines, "<br/>"); } tensorflow::mutex& RendererMutex() { @@ -750,14 +891,6 @@ class FileGraphRenderer : public GraphRendererInterface { // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. -// -// It's confusing to draw a node and include only some of its operands. So if -// some but not all of a node's operands are <= radius units away from the root, -// we include the other operands (unless there are a lot of them, as often in a -// tuple node). These additional operands may have as inputs other nodes -// already present in the graph, but we don't draw those edges unless *all* of -// the inputs are present. (Otherwise we'd have the same problem we were trying -// to solve in the first place!) NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. @@ -788,14 +921,25 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { } } - // If you're looking at node X, it's probably not interesting that node Y - // also happens to use the same constant, so we don't traverse into - // constants' users. - if (instr->opcode() != HloOpcode::kConstant) { - for (const HloInstruction* user : instr->users()) { - if (!nodes.count(user)) { - worklist.push_back({user, depth + 1}); - } + // Traverse into instr's users, unless: + // + // - there are a ton of them, in which case they're probably not + // interesting (and anyway, rendering them all would make the graph + // unreadable), or + // - instr is a constant, in which case its users are probably not + // interesting. + if (instr->opcode() == HloOpcode::kConstant) { + continue; + } + constexpr int kMaxUsersToRender = 16; + if (instr->user_count() > kMaxUsersToRender) { + // If we're going to skip this node's users, style it as such. + nodes[instr] = kSomeUsersOmitted; + continue; + } + for (const HloInstruction* user : instr->users()) { + if (!nodes.count(user)) { + worklist.push_back({user, depth + 1}); } } } @@ -804,43 +948,27 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { return nodes.count(instr) > 0; }; - // If a node has some but not all of its operands omitted, add the operands to - // the map with type kOmitNodeOperands. Unless the node has a lot of - // operands, in which case just mark the node as "some operands omitted". - std::vector<const HloInstruction*> extra_operands; + // Mark nodes which don't have all of their operands present as "some operands + // omitted". for (auto& kv : nodes) { const HloInstruction* instr = kv.first; NodeFilterResult& filter_result = kv.second; const auto& operands = instr->operands(); - // Mark nodes with many operands and some omitted as "some operands omitted" - // and carry on -- don't add their omitted operands to extra_operands. - if (operands.size() > 4) { - if (std::any_of(operands.begin(), operands.end(), is_displayed) && - !std::all_of(operands.begin(), operands.end(), is_displayed)) { - filter_result = kSomeOperandsOmitted; - } - continue; - } - - if (std::any_of(operands.begin(), operands.end(), is_displayed)) { - for (const HloInstruction* operand : operands) { - if (!is_displayed(operand)) { - extra_operands.push_back(operand); - } - } + // Mark nodes with some omitted as "some operands omitted". + if (std::any_of(operands.begin(), operands.end(), is_displayed) && + !std::all_of(operands.begin(), operands.end(), is_displayed)) { + filter_result = kSomeOperandsOmitted; } } - for (const HloInstruction* instr : extra_operands) { - nodes[instr] = kOmitNodeOperands; - } - // Some of the nodes in extra_operands may now have all of their inputs - // present in nodes. We can promote these to normal nodes. - for (const HloInstruction* instr : extra_operands) { - const auto& operands = instr->operands(); - if (std::all_of(operands.begin(), operands.end(), is_displayed)) { - nodes[instr] = kNormalNode; + // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their + // users made it into the graph by other means. + for (auto& kv : nodes) { + const auto& users = kv.first->users(); + if (kv.second == kSomeUsersOmitted && + std::all_of(users.begin(), users.end(), is_displayed)) { + kv.second = kNormalNode; } } @@ -862,6 +990,10 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { if (it != nodes.end()) { return it->second; } + // Show all nodes in subcomputations. + if (instr->parent() != root->parent()) { + return kNormalNode; + } return kHideNode; }); } @@ -886,10 +1018,12 @@ string DumpGraph(const HloComputation& computation, const string& label, graph_url = FileGraphRenderer().RenderGraph( graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); } else { - graph = ComputationToDotGraph(computation, label, - debug_options.xla_hlo_graph_addresses(), - debug_options.xla_hlo_graph_layout(), - hlo_execution_profile, NodeFilter()); + graph = + HloDotDumper(&computation, label, + /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), + /*show_layouts=*/debug_options.xla_hlo_graph_layout(), + hlo_execution_profile, NodeFilter()) + .Dump(); graph_url = GetGraphRenderer()->RenderGraph( graph, GraphRendererInterface::DOT_GRAPH, debug_options); } @@ -903,11 +1037,12 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius) { string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = ComputationToDotGraph( - *node.parent(), label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - /*show_layouts=*/debug_options.xla_hlo_graph_layout(), - /*hlo_execution_profile=*/nullptr, filter); + string graph = + HloDotDumper(node.parent(), label, + /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), + /*show_layouts=*/debug_options.xla_hlo_graph_layout(), + /*profile=*/nullptr, filter) + .Dump(); return GetGraphRenderer()->RenderGraph( graph, GraphRendererInterface::DOT_GRAPH, debug_options); } |