diff options
Diffstat (limited to 'tensorflow/core/graph/dot.cc')
-rw-r--r-- | tensorflow/core/graph/dot.cc | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/tensorflow/core/graph/dot.cc b/tensorflow/core/graph/dot.cc new file mode 100644 index 0000000000..6d6e46ce61 --- /dev/null +++ b/tensorflow/core/graph/dot.cc @@ -0,0 +1,289 @@ +#include "tensorflow/core/graph/dot.h" + +#include <map> +#include <unordered_map> +#include <unordered_set> + +#include "tensorflow/core/graph/colors.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +static string GraphNodeName(const DotOptions& opts, const Node* n) { + return strings::StrCat("N", n->id()); +} + +bool ShoulDisplayOpType(const Node* n) { + if (n->type_string() == "NoOp") { + return false; + } + const string& op_name = n->def().name(); + if (op_name.find(n->type_string() + "_") == 0) { + return false; + } + return true; +} + +string DotGraph(const Graph& g, const DotOptions& opts) { + RegexpStringPiece flag(opts.prefix_collapse_regexp); + if (flag == "all") { + flag = "."; + } else if (flag == "none") { + flag = "^$"; + } + RE2 cluster_name_pattern(flag); + string result; + strings::StrAppend(&result, "digraph G {\n"); + strings::StrAppend(&result, "rankdir=\"BT\"\n"); + + std::map<string, int> device_index; // Map from device name to index. + std::unordered_set<Node*> visible_nodes; // Nodes to display. + // Cluster name => set of nodes. + std::unordered_map<string, std::unordered_set<Node*> > clusters; + // Node* => Cluster + std::unordered_map<Node*, string> node_cluster; + for (Node* src : g.nodes()) { + if (opts.include_node_function != nullptr && + !opts.include_node_function(src)) { + continue; + } + // Do not display source and sink nodes + if (src->IsSource() || src->IsSink()) { + continue; + } + visible_nodes.insert(src); + const string name_prefix = NodeNamePrefix(src->def().name()).ToString(); + if (!name_prefix.empty()) { + clusters[name_prefix].insert(src); + node_cluster[src] = name_prefix; + } + // Record device if present. + if (src->IsOp()) { + const string& d = src->assigned_device_name(); + if (!d.empty()) { + device_index[d] = -1; // Assigned later + } + } + } + + // Add nodes whose name is exactly a cluster name to the cluster itself. + for (Node* src : g.nodes()) { + if (node_cluster.count(src) == 0) { + const string name = src->def().name(); + auto it = clusters.find(name); + if (it != clusters.end()) { + it->second.insert(src); + node_cluster[src] = name; + } + } + } + + auto node_in_collapsed_cluster = [&node_cluster, + &cluster_name_pattern](Node* n) { + return node_cluster.count(n) > 0 && + RE2::PartialMatch(node_cluster[n], cluster_name_pattern); + }; + + // Assign device indices in sorted order. + int num = 0; + for (auto& e : device_index) { + e.second = num++; + } + + double total_node_cost = 0; + double avg_node_cost = 1; + if (opts.node_cost) { + int node_count = 0; + for (const Node* n : g.nodes()) { + total_node_cost += opts.node_cost(n); + ++node_count; + } + if (total_node_cost > 0) avg_node_cost = total_node_cost / node_count; + } + + for (Node* src : g.nodes()) { + if (visible_nodes.count(src) == 0 || node_in_collapsed_cluster(src)) { + continue; + } + string label = src->name(); + if (ShoulDisplayOpType(src)) { + // Append the op type if it is not directly deducible from the op name. + strings::StrAppend(&label, "\\n(", src->type_string(), ")"); + } + const char* shape = "box"; + const char* color = nullptr; + if (src->IsSource()) { + shape = "oval"; + } else if (src->IsSink()) { + shape = "oval"; + } else { + const string& d = src->assigned_device_name(); + const int dindex = (!d.empty()) ? device_index[d] : -1; + if (dindex >= 0) { + color = ColorFor(dindex); + } + + shape = "box"; + } + + if (opts.node_label) { + string extra = opts.node_label(src); + if (!extra.empty()) { + strings::StrAppend(&label, "\\n", extra); + } + } + + strings::StrAppend(&result, GraphNodeName(opts, src), "[shape=", shape, + ", label=\"", label, "\""); + if (opts.node_cost && total_node_cost > 0) { + // Pick fontsize in range [8..40] so that area is proportional to cost. + const double cost = opts.node_cost(src); + const double relcost = fabs(cost / avg_node_cost); + // Average cost node has font size of 12. + const int fs = 8 + static_cast<int>(4.0 * std::min(sqrt(relcost), 8.0)); + strings::StrAppend(&result, ", width=0, height=0, fontsize=", fs); + VLOG(2) << "Node: " << cost << " => " << relcost << " => " << fs; + } + if (color != nullptr) { + strings::StrAppend(&result, ", fillcolor=\"", color, + "\", fontcolor=\"white\", style=\"filled\""); + } + strings::StrAppend(&result, "]\n"); + } + + for (auto c : clusters) { + const string& cluster_name = c.first; + const std::unordered_set<Node*> nodes = c.second; + std::unordered_map<string, int> node_colors; + for (auto n : nodes) { + const string& d = n->assigned_device_name(); + const int dindex = (!d.empty()) ? device_index[d] : -1; + if (dindex >= 0) { + ++node_colors[ColorFor(dindex)]; + } + } + + string majority_color; + if (node_colors.empty()) { + majority_color = ColorFor(0); + } else { + majority_color = std::max_element(node_colors.begin(), node_colors.end(), + [](const std::pair<string, int>& x, + const std::pair<string, int>& y) { + return x.second < y.second; + }) + ->first; + } + + if (!RE2::PartialMatch(cluster_name, cluster_name_pattern)) { + strings::StrAppend(&result, "subgraph cluster_", cluster_name, "{\n"); + for (auto n : nodes) { + strings::StrAppend(&result, GraphNodeName(opts, n), ";\n"); + } + strings::StrAppend(&result, "}\n"); + } else { + strings::StrAppend(&result, cluster_name, " [shape=oval, fillcolor=\"", + majority_color, "\", label=\"", cluster_name, + "\", style=\"filled\", fontcolor=\"white\"]\n"); + } + } + + std::unordered_set<string> edge_drawn; + + double max_edge_cost = 0; + double total_edge_cost = 0; + double avg_edge_cost = 1; + if (opts.edge_cost && g.edges().size()) { + for (const Edge* e : g.edges()) { + auto cost = opts.edge_cost(e); + total_edge_cost += cost; + max_edge_cost = std::max(max_edge_cost, cost); + } + avg_edge_cost = total_edge_cost / g.edges().size(); + } + VLOG(2) << "Edge cost tot/max/avg: " << total_edge_cost << "/" + << max_edge_cost << "/" << avg_edge_cost; + + for (const Edge* e : g.edges()) { + Node* src = e->src(); + Node* dst = e->dst(); + // If either endpoint isn't drawn in the graph, don't draw the edge + if (visible_nodes.count(src) == 0 || visible_nodes.count(dst) == 0) { + continue; + } + + const string src_name = node_in_collapsed_cluster(src) + ? node_cluster[src] + : GraphNodeName(opts, src); + const string dst_name = node_in_collapsed_cluster(dst) + ? node_cluster[dst] + : GraphNodeName(opts, dst); + // Don't draw self edges + if (src_name == dst_name) { + continue; + } + // And previously drawn edges. + const string& edge_name = strings::StrCat(src_name, ":", dst_name); + if (edge_drawn.count(edge_name) > 0) { + continue; + } + edge_drawn.insert(edge_name); + + strings::StrAppend(&result, src_name, " -> ", dst_name, "["); + string label; + if (e->IsControlEdge()) { + strings::StrAppend(&result, " style=dotted"); + } + if (opts.edge_label) { + string label = opts.edge_label(e); + if (!label.empty()) { + strings::StrAppend(&result, " label=<", label, ">"); + } + } + // Make edge widths proportional to amount of data transferred. + if (opts.edge_cost && max_edge_cost > 0) { + const double cost = opts.edge_cost(e); + const double relcost = fabs(cost / avg_edge_cost); + // Pick penwidth in range [1..6] so that width is proportional to cost. + const int pw = 1 + std::min(5, static_cast<int>(2.0 * relcost)); + strings::StrAppend(&result, " penwidth=", pw); + // Use weight attributes [1..100] to keep heavier edges more vertical. + const int weight = 1 + std::min(99, static_cast<int>(100.0 * relcost)); + strings::StrAppend(&result, " weight=", weight); + VLOG(2) << "Edge: " << cost << " => " << relcost << " => " << pw << "/" + << weight; + } + + strings::StrAppend(&result, "]\n"); + } + // Compute some statistics + int op_nodes = 0; + for (Node* n : g.nodes()) { + if (n->IsOp()) { + op_nodes++; + } + } + + // Emit legend + strings::StrAppend(&result, + "{ rank = source; Legend [shape=box, margin=0, label=<", + "<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\" ", + "CELLPADDING=\"4\">", "<TR><TD COLSPAN=\"2\">op_nodes: ", + op_nodes, "</TD></TR>\n"); + for (const auto& e : device_index) { + const int dindex = e.second; + strings::StrAppend(&result, "<TR><TD BGCOLOR=\"", ColorFor(dindex), + "\"><FONT COLOR=\"white\">", dindex, "</FONT></TD><TD>", + e.first, "</TD></TR>\n"); + } + strings::StrAppend(&result, "</TABLE>>]}\n"); + + strings::StrAppend(&result, "}\n"); // End digraph + return result; +} + +} // namespace tensorflow |