Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- tensorflow/core/graph/algorithm.cc | 107
-rw-r--r-- tensorflow/core/graph/algorithm.h | 40
-rw-r--r-- tensorflow/core/graph/algorithm_test.cc | 103
-rw-r--r-- tensorflow/core/graph/colors.cc | 25
-rw-r--r-- tensorflow/core/graph/colors.h | 14
-rw-r--r-- tensorflow/core/graph/costmodel.cc | 308
-rw-r--r-- tensorflow/core/graph/costmodel.h | 123
-rw-r--r-- tensorflow/core/graph/costutil.cc | 22
-rw-r--r-- tensorflow/core/graph/costutil.h | 19
-rw-r--r-- tensorflow/core/graph/default_device.h | 25
-rw-r--r-- tensorflow/core/graph/dot.cc | 289
-rw-r--r-- tensorflow/core/graph/dot.h | 43
-rw-r--r-- tensorflow/core/graph/edgeset.cc | 56
-rw-r--r-- tensorflow/core/graph/edgeset.h | 216
-rw-r--r-- tensorflow/core/graph/edgeset_test.cc | 95
-rw-r--r-- tensorflow/core/graph/equal_graph_def.cc | 176
-rw-r--r-- tensorflow/core/graph/equal_graph_def.h | 32
-rw-r--r-- tensorflow/core/graph/equal_graph_def_test.cc | 279
-rw-r--r-- tensorflow/core/graph/graph.cc | 319
-rw-r--r-- tensorflow/core/graph/graph.h | 440
-rw-r--r-- tensorflow/core/graph/graph_constructor.cc | 385
-rw-r--r-- tensorflow/core/graph/graph_constructor.h | 43
-rw-r--r-- tensorflow/core/graph/graph_constructor_test.cc | 190
-rw-r--r-- tensorflow/core/graph/graph_def_builder.cc | 121
-rw-r--r-- tensorflow/core/graph/graph_def_builder.h | 181
-rw-r--r-- tensorflow/core/graph/graph_partition.cc | 1050
-rw-r--r-- tensorflow/core/graph/graph_partition.h | 77
-rw-r--r-- tensorflow/core/graph/graph_partition_test.cc | 316
-rw-r--r-- tensorflow/core/graph/graph_test.cc | 252
-rw-r--r-- tensorflow/core/graph/node_builder.cc | 115
-rw-r--r-- tensorflow/core/graph/node_builder.h | 146
-rw-r--r-- tensorflow/core/graph/node_builder_test.cc | 59
-rw-r--r-- tensorflow/core/graph/optimizer_cse.cc | 220
-rw-r--r-- tensorflow/core/graph/optimizer_cse.h | 19
-rw-r--r-- tensorflow/core/graph/optimizer_cse_test.cc | 365
-rw-r--r-- tensorflow/core/graph/subgraph.cc | 258
-rw-r--r-- tensorflow/core/graph/subgraph.h | 49
-rw-r--r-- tensorflow/core/graph/subgraph_test.cc | 305
-rw-r--r-- tensorflow/core/graph/tensor_id.cc | 41
-rw-r--r-- tensorflow/core/graph/tensor_id.h | 28
-rw-r--r-- tensorflow/core/graph/tensor_id_test.cc | 77
-rw-r--r-- tensorflow/core/graph/testlib.cc | 299
-rw-r--r-- tensorflow/core/graph/testlib.h | 141
-rw-r--r-- tensorflow/core/graph/types.h | 17
44 files changed, 7485 insertions, 0 deletions
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
new file mode 100644
index 0000000000..fd79ead0b1
--- /dev/null
+++ b/tensorflow/core/graph/algorithm.cc
@@ -0,0 +1,107 @@
+#include "tensorflow/core/graph/algorithm.h"
+
+#include <algorithm>
+#include <deque>
+#include <vector>
+
+namespace tensorflow {
+
+void DFS(const Graph& g, std::function<void(Node*)> enter,
+ std::function<void(Node*)> leave) {
+ // Stack of work to do.
+ struct Work {
+ Node* node;
+ bool leave; // Are we entering or leaving n?
+ };
+ std::vector<Work> stack;
+ stack.push_back(Work{g.source_node(), false});
+
+ std::vector<bool> visited(g.num_node_ids(), false);
+ while (!stack.empty()) {
+ Work w = stack.back();
+ stack.pop_back();
+
+ Node* n = w.node;
+ if (w.leave) {
+ leave(n);
+ continue;
+ }
+
+ if (visited[n->id()]) continue;
+ visited[n->id()] = true;
+ if (enter) enter(n);
+
+ // Arrange to call leave(n) when all done with descendants.
+ if (leave) stack.push_back(Work{n, true});
+
+ // Arrange to work on descendants.
+ for (Node* out : n->out_nodes()) {
+ if (!visited[out->id()]) {
+        // Note: we must not mark it visited until we actually process it.
+ stack.push_back(Work{out, false});
+ }
+ }
+ }
+}
+
+void GetPostOrder(const Graph& g, std::vector<Node*>* order) {
+ order->clear();
+ DFS(g, nullptr, [order](Node* n) { order->push_back(n); });
+}
+
+void GetReversePostOrder(const Graph& g, std::vector<Node*>* order) {
+ GetPostOrder(g, order);
+ std::reverse(order->begin(), order->end());
+}
+
+void PruneForReverseReachability(Graph* g,
+ const std::unordered_set<const Node*>& nodes) {
+ std::unordered_set<const Node*> visited;
+
+ // Compute set of nodes that we need to traverse in order to reach
+ // the nodes in "nodes" by performing a breadth-first search from those
+ // nodes, and accumulating the visited nodes.
+ std::deque<const Node*> queue;
+ for (const Node* n : nodes) {
+ queue.push_back(n);
+ }
+ while (!queue.empty()) {
+ const Node* n = queue.front();
+ queue.pop_front();
+ if (visited.insert(n).second) {
+ for (const Node* in : n->in_nodes()) {
+ queue.push_back(in);
+ }
+ }
+ }
+
+ // Make a pass over the graph to remove nodes not in "visited"
+ std::vector<Node*> all_nodes;
+ for (Node* n : g->nodes()) {
+ all_nodes.push_back(n);
+ }
+
+ for (Node* n : all_nodes) {
+ if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
+ g->RemoveNode(n);
+ }
+ }
+
+  // Reconnect nodes left without incoming/outgoing edges to source/sink.
+ FixupSourceAndSinkEdges(g);
+}
+
+void FixupSourceAndSinkEdges(Graph* g) {
+ // Connect all nodes with no incoming edges to source.
+ // Connect all nodes with no outgoing edges to sink.
+ for (Node* n : g->nodes()) {
+ if (!n->IsSource() && n->in_edges().empty()) {
+ g->AddControlEdge(g->source_node(), n);
+ }
+ if (!n->IsSink() && n->out_edges().empty()) {
+ g->AddControlEdge(n, g->sink_node());
+ }
+ }
+}
+
+} // namespace tensorflow
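
The DFS above avoids recursion by pushing each node twice: once with leave == false to visit it, and once with leave == true after its out-edges have been queued, so the leave callback fires only after every descendant has been processed. The following standalone sketch shows the same pattern on a hypothetical adjacency-list graph; it illustrates the technique and is not TensorFlow code.

```cpp
#include <iostream>
#include <vector>

// Post-order DFS with an explicit stack: a node is pushed once to be
// entered, and once more (leave = true) to be emitted after its children.
void PostOrder(const std::vector<std::vector<int>>& adj, int root,
               std::vector<int>* order) {
  struct Work {
    int node;
    bool leave;  // Entering or leaving this node?
  };
  std::vector<Work> stack = {{root, false}};
  std::vector<bool> visited(adj.size(), false);
  while (!stack.empty()) {
    Work w = stack.back();
    stack.pop_back();
    if (w.leave) {  // All descendants already emitted.
      order->push_back(w.node);
      continue;
    }
    if (visited[w.node]) continue;
    visited[w.node] = true;
    stack.push_back({w.node, true});  // Schedule the leave visit.
    for (int out : adj[w.node]) {
      if (!visited[out]) stack.push_back({out, false});
    }
  }
}

int main() {
  // A small diamond: 0 -> {1, 2}, 1 -> 3, 2 -> 3.
  std::vector<std::vector<int>> adj = {{1, 2}, {3}, {3}, {}};
  std::vector<int> order;
  PostOrder(adj, 0, &order);
  for (int n : order) std::cout << n << ' ';  // Prints: 3 2 1 0
  std::cout << '\n';
}
```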
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
new file mode 100644
index 0000000000..58b74a0ace
--- /dev/null
+++ b/tensorflow/core/graph/algorithm.h
@@ -0,0 +1,40 @@
+#ifndef TENSORFLOW_GRAPH_ALGORITHM_H_
+#define TENSORFLOW_GRAPH_ALGORITHM_H_
+
+#include <functional>
+#include <unordered_set>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Perform a depth-first-search on g starting at the source node.
+// If enter is not empty, calls enter(n) before visiting any children of n.
+// If leave is not empty, calls leave(n) after visiting all children of n.
+extern void DFS(const Graph& g, std::function<void(Node*)> enter,
+ std::function<void(Node*)> leave);
+
+// Stores in *order the post-order numbering of all nodes in the
+// graph found via a depth-first search starting at the source node.
+//
+// Note that reversing this ordering yields a topological sort of the
+// graph when it does not have cycles.
+//
+// REQUIRES: order is not NULL.
+void GetPostOrder(const Graph& g, std::vector<Node*>* order);
+
+// Stores in *order the reverse post-order numbering of all nodes,
+// i.e. GetPostOrder() reversed (a topological order for acyclic graphs).
+void GetReversePostOrder(const Graph& g, std::vector<Node*>* order);
+
+// Prune nodes in "g" that are not in some path from the source node
+// to any node in 'nodes'.
+void PruneForReverseReachability(Graph* g,
+ const std::unordered_set<const Node*>& nodes);
+
+// Connect all nodes with no incoming edges to source.
+// Connect all nodes with no outgoing edges to sink.
+void FixupSourceAndSinkEdges(Graph* g);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_ALGORITHM_H_
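
A minimal usage sketch of these declarations, assuming a TensorFlow build environment and an already-constructed Graph: for an acyclic graph, GetReversePostOrder() visits every node after all of its inputs, which is the order most forward passes over the graph want.

```cpp
#include <vector>

#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Logs nodes in a producers-before-consumers order (topological for DAGs).
void LogTopologicalOrder(const Graph& g) {
  std::vector<Node*> order;
  GetReversePostOrder(g, &order);
  for (const Node* n : order) {
    LOG(INFO) << n->id() << ": " << n->name();
  }
}

}  // namespace tensorflow
```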
diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc
new file mode 100644
index 0000000000..48f2e1ebd7
--- /dev/null
+++ b/tensorflow/core/graph/algorithm_test.cc
@@ -0,0 +1,103 @@
+#include "tensorflow/core/graph/algorithm.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+// TODO(josh11b): Test setting the "device" field of a NodeDef.
+// TODO(josh11b): Test that feeding won't prune targets.
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("TestParams").Output("o: float");
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+
+// Checks that the order of nodes in 'inputs' respects the
+// pairwise orderings described in 'ordered_pairs'.
+bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs,
+ const std::vector<Node*>& inputs, string* error) {
+ for (const std::pair<string, string>& pair : ordered_pairs) {
+ const string& before_node = pair.first;
+ const string& after_node = pair.second;
+ bool seen_before = false;
+ bool seen_both = false;
+ for (const Node* node : inputs) {
+ if (!seen_before && after_node == node->name()) {
+ *error = strings::StrCat("Saw ", after_node, " before ", before_node);
+ return false;
+ }
+
+ if (before_node == node->name()) {
+ seen_before = true;
+ } else if (after_node == node->name()) {
+ seen_both = seen_before;
+ break;
+ }
+ }
+ if (!seen_both) {
+ *error = strings::StrCat("didn't see either ", before_node, " or ",
+ after_node);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+TEST(AlgorithmTest, ReversePostOrder) {
+ RequireDefaultOps();
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
+ Node* w2 = SourceOp("TestParams", b.opts().WithName("W2"));
+ Node* input =
+ SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
+ Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1"));
+ BinaryOp("TestMul", w1, {input, 1},
+ b.opts().WithName("t2").WithControlInput(t1));
+ BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3"));
+
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+ std::vector<Node*> order;
+
+ // Test reverse post order:
+ GetReversePostOrder(g, &order);
+
+ // Check that the order respects the dependencies correctly.
+ std::vector<std::pair<string, string>> reverse_orders = {
+ {"W1", "input"}, {"W1", "t1"}, {"W1", "t2"}, {"W1", "t3"},
+ {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}};
+ string error;
+ EXPECT_TRUE(ExpectBefore(reverse_orders, order, &error)) << error;
+
+ // A false ordering should fail the check.
+ reverse_orders = {{"input", "W1"}};
+ EXPECT_FALSE(ExpectBefore(reverse_orders, order, &error));
+
+ // Test post order:
+ GetPostOrder(g, &order);
+
+ // Check that the order respects the dependencies correctly.
+ std::vector<std::pair<string, string>> orders = {
+ {"input", "W1"}, {"t1", "W1"}, {"t2", "W1"}, {"t3", "W1"},
+ {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}};
+ EXPECT_TRUE(ExpectBefore(orders, order, &error)) << error;
+
+ // A false ordering should fail the check.
+ orders = {{"W1", "t3"}};
+ EXPECT_FALSE(ExpectBefore(orders, order, &error));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/colors.cc b/tensorflow/core/graph/colors.cc
new file mode 100644
index 0000000000..0eb2fc3740
--- /dev/null
+++ b/tensorflow/core/graph/colors.cc
@@ -0,0 +1,25 @@
+#include "tensorflow/core/graph/colors.h"
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Color palette
+// http://www.mulinblog.com/a-color-palette-optimized-for-data-visualization/
+static const char* kColors[] = {
+ "#F15854", // red
+ "#5DA5DA", // blue
+ "#FAA43A", // orange
+ "#60BD68", // green
+ "#F17CB0", // pink
+ "#B2912F", // brown
+ "#B276B2", // purple
+ "#DECF3F", // yellow
+ "#4D4D4D", // gray
+};
+
+const char* ColorFor(int dindex) {
+ return kColors[dindex % TF_ARRAYSIZE(kColors)];
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/colors.h b/tensorflow/core/graph/colors.h
new file mode 100644
index 0000000000..150c8dc025
--- /dev/null
+++ b/tensorflow/core/graph/colors.h
@@ -0,0 +1,14 @@
+#ifndef TENSORFLOW_GRAPH_COLORS_H_
+#define TENSORFLOW_GRAPH_COLORS_H_
+
+namespace tensorflow {
+
+// Return a color drawn from a palette to represent an entity
+// identified by "dindex". The return value has the form "#RRGGBB". Note
+// that the palette has a limited set of colors and therefore colors
+// will be reused eventually.
+const char* ColorFor(int dindex);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_COLORS_H_
diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
new file mode 100644
index 0000000000..89bc41acfd
--- /dev/null
+++ b/tensorflow/core/graph/costmodel.cc
@@ -0,0 +1,308 @@
+#include "tensorflow/core/graph/costmodel.h"
+
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace {
+const Microseconds kDefaultTimeEstimate(1);
+const Microseconds kMinTimeEstimate(1);
+} // namespace
+
+void CostModel::SuppressInfrequent() {
+ // Find the median of the non-zero counts, and use half of its value
+ // as the cutoff for a "normal" execution mode node.
+ if (count_.empty()) return;
+ std::vector<int32> non_zero;
+ for (auto v : count_) {
+ if (v > 0) non_zero.push_back(v);
+ }
+ const size_t sz = non_zero.size();
+ if (sz > 0) {
+ std::nth_element(non_zero.begin(), non_zero.begin() + sz / 2,
+ non_zero.end());
+ int32 median_value = non_zero[sz / 2];
+ min_count_ = median_value / 2;
+ VLOG(1) << "num non_zero vals: " << non_zero.size() << " median_value "
+ << median_value;
+ } else {
+ min_count_ = 1;
+ }
+}
+
+void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) {
+ CHECK(is_global_);
+ CHECK(!cm.is_global());
+ for (const Node* n : g.nodes()) {
+ const int local_id = cm.Id(n);
+ const int global_id = Id(n);
+ if (local_id < 0 || global_id < 0) continue;
+ Ensure(global_id);
+ count_[global_id] += cm.count_[local_id];
+ time_[global_id] += cm.time_[local_id];
+ int num_slots = cm.slot_bytes_[local_id].size();
+ if (num_slots > 0) {
+ if (slot_bytes_[global_id].size() == 0) {
+ slot_bytes_[global_id].resize(num_slots);
+ } else {
+ CHECK_EQ(num_slots, slot_bytes_[global_id].size());
+ }
+ for (int s = 0; s < num_slots; ++s) {
+ slot_bytes_[global_id][s] += cm.slot_bytes_[local_id][s];
+ }
+ }
+ }
+}
+
+void CostModel::MergeFromGlobal(const CostModel& cm) {
+ CHECK(is_global_);
+  CHECK(cm.is_global());
+ const int num_nodes = cm.count_.size();
+ Ensure(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ count_[i] += cm.count_[i];
+ time_[i] += cm.time_[i];
+ int num_slots = cm.slot_bytes_[i].size();
+ if (num_slots > 0) {
+ if (slot_bytes_[i].size() == 0) {
+ slot_bytes_[i].resize(num_slots);
+ } else {
+ CHECK_EQ(num_slots, slot_bytes_[i].size());
+ }
+ for (int s = 0; s < num_slots; ++s) {
+ slot_bytes_[i][s] += cm.slot_bytes_[i][s];
+ }
+ }
+ }
+}
+
+void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
+ const StepStats& ss) {
+ CHECK(is_global_);
+ for (auto& ds : ss.dev_stats()) {
+ for (auto& ns : ds.node_stats()) {
+ NodeNameToCostIdMap::const_iterator iter = map.find(ns.node_name());
+ // We don't keep stats for nodes not in the global graph, i.e.
+ // copy/send/recv nodes, feed/fetch, etc.
+ if (iter == map.end()) continue;
+ int32 global_id = iter->second;
+ Ensure(global_id);
+ int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros();
+ count_[global_id]++;
+ time_[global_id] += elapsed_micros;
+ for (auto& no : ns.output()) {
+ int si = no.slot();
+ if (static_cast<size_t>(si) >= slot_bytes_[global_id].size()) {
+ slot_bytes_[global_id].resize(1 + si);
+ }
+ slot_bytes_[global_id][si] +=
+ no.tensor_description().allocation_description().requested_bytes();
+ }
+ }
+ }
+}
+
+void CostModel::Ensure(int id) {
+ if (slot_bytes_.size() <= static_cast<size_t>(id)) {
+ slot_bytes_.resize(id + 1);
+ count_.resize(id + 1);
+ time_.resize(id + 1);
+ }
+}
+
+void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
+ const int id = Id(node);
+ if (id < 0) return;
+ Ensure(id);
+ auto perslot = &slot_bytes_[id];
+ if (perslot->size() > 0) {
+ CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node="
+ << node->name();
+ } else {
+ perslot->resize(num_outputs, Bytes(-1));
+ }
+}
+
+void CostModel::RecordCount(const Node* node, int count) {
+ const int id = Id(node);
+ if (id < 0) return;
+ CHECK_LT(id, slot_bytes_.size());
+ count_[id] += count;
+}
+
+int32 CostModel::TotalCount(const Node* node) const {
+ const int id = Id(node);
+ if (id < 0) return 0;
+ return (static_cast<size_t>(id) < slot_bytes_.size()) ? count_[id] : 0;
+}
+
+void CostModel::RecordSize(const Node* node, int slot, Bytes bytes) {
+ const int id = Id(node);
+ if (id < 0) return;
+ CHECK_LT(id, slot_bytes_.size());
+ auto perslot = &slot_bytes_[id];
+ CHECK_LT(slot, perslot->size());
+ auto v = &(*perslot)[slot];
+ if (*v >= 0) {
+ *v += bytes;
+ } else {
+ *v = bytes;
+ }
+}
+
+Bytes CostModel::TotalBytes(const Node* node, int slot) const {
+ const int id = Id(node);
+ if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() ||
+ slot_bytes_[id].size() <= static_cast<size_t>(slot)) {
+ return Bytes(0);
+ }
+ return slot_bytes_[id][slot];
+}
+
+Bytes CostModel::SizeEstimate(const Node* node, int slot) const {
+ int32 count = TotalCount(node);
+ if (count < min_count_) return Bytes(0);
+  return TotalBytes(node, slot) / std::max(1, count);
+}
+
+void CostModel::RecordTime(const Node* node, Microseconds time) {
+ const int id = Id(node);
+ if (id < 0) return;
+ DCHECK(node->IsOp()) << node->DebugString();
+ Ensure(id);
+ time_[id] += time;
+}
+
+Microseconds CostModel::TotalTime(const Node* node) const {
+ DCHECK(node->IsOp()) << node->DebugString();
+ const int id = Id(node);
+ if (id < 0 || static_cast<size_t>(id) >= time_.size() ||
+ time_[id] < Microseconds(0)) {
+ return Microseconds(0);
+ }
+ return time_[id];
+}
+
+Microseconds CostModel::TimeEstimate(const Node* node) const {
+ int32 count = TotalCount(node);
+ if (count <= min_count_) return kMinTimeEstimate;
+ return std::max(kMinTimeEstimate, TotalTime(node) / std::max(1, count));
+}
+
+void CostModel::CheckInitialized(const Graph& graph) const {
+ for (const Node* n : graph.nodes()) {
+ if (n->IsOp()) {
+ CHECK(static_cast<size_t>(n->id()) < time_.size() &&
+ time_[n->id()] >= Microseconds(0))
+ << ": no time estimate for " << n->DebugString();
+
+ CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
+ << ": no size estimate for " << n->DebugString();
+ const auto& perslot = slot_bytes_[n->id()];
+ for (size_t i = 0; i < perslot.size(); i++) {
+ CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
+ << " of " << n->DebugString();
+ }
+ }
+ }
+}
+
+Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis,
+ double estimated_gbps) {
+ // TODO(jeff,sanjay): estimate cost based on bandwidth along the
+ // communication path and the type of transport we are using between
+ // devices.
+ //
+ // We assume the copy time follows a linear model:
+ // copy_time = copy_bytes / rate + min_time
+ int64 copy_bytes = b.value();
+ const double bytes_per_usec = estimated_gbps * 1000.0 / 8;
+ const double min_micros = network_latency_millis * 1000.0;
+ return Microseconds(
+ static_cast<int64>(copy_bytes / bytes_per_usec + min_micros));
+}
+
+Microseconds CostModel::ComputationTimeEstimate(int64 math_ops) {
+ // TODO(jeff,sanjay): Eventually we should pass in the type of device
+ // (GPU vs. CPU) and use that to affect the estimate.
+
+  // We estimate the microseconds from "math_ops": dividing by 1000
+  // converts the madd count into microseconds (assuming roughly
+  // 1000 madds per microsecond, i.e. ~1 GHz for one core).
+ return Microseconds(math_ops / 1000);
+}
+
+// ----------------------------------------------------------------------------
+// InitCostModel
+// ----------------------------------------------------------------------------
+
+namespace {
+
+static void AddNodesToCostModel(const Graph& g, CostModel* cost_model) {
+ for (Node* n : g.nodes()) {
+ const int num_outputs = n->num_outputs();
+ cost_model->SetNumOutputs(n, num_outputs);
+ for (int output = 0; output < num_outputs; output++) {
+ // Set up an initial bogus estimate for the node's outputs
+ cost_model->RecordSize(n, output, Bytes(1));
+ }
+ }
+}
+
+static void AssignSizes(const Graph& g, CostModel* cost_model) {
+ for (const Edge* e : g.edges()) {
+ // Skip if it is a control edge.
+ if (e->IsControlEdge()) {
+ continue;
+ }
+ Node* src = e->src();
+
+ // TODO(josh11b): Get an estimate from the Op
+ Bytes size(1);
+ cost_model->RecordSize(src, e->src_output(), size);
+ }
+}
+
+// This generates an extremely simple initial guess for the
+// computation cost of each node. For ordinary Ops, its value should quickly
+// be wiped out by the real runtime measurements. For other Ops we don't
+// actually generate measurements, so suppression of infrequent Ops ends up
+// giving them 0 costs. So, this is not of much consequence except perhaps
+// in tests.
+static Microseconds TimeEstimateForNode(CostModel* cost_model, Node* n) {
+ CHECK(n->IsOp());
+ VLOG(2) << "Node " << n->id() << ": " << n->name()
+ << " type_string: " << n->type_string();
+ if (IsConstant(n) || IsVariable(n)) {
+ return Microseconds(0);
+ }
+ return kDefaultTimeEstimate;
+}
+
+static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) {
+ for (Node* n : g.nodes()) {
+ if (!n->IsOp()) continue;
+ cost_model->RecordTime(n, TimeEstimateForNode(cost_model, n));
+ }
+}
+
+} // namespace
+
+void CostModel::InitFromGraph(const Graph& g) {
+ AddNodesToCostModel(g, this);
+ AssignSizes(g, this);
+ EstimateComputationCosts(g, this);
+ CheckInitialized(g);
+}
+
+void CostModel::WriteToLog() {
+ LOG(INFO) << " min_count_=" << min_count_;
+ for (size_t i = 0; i < count_.size(); ++i) {
+ LOG(INFO) << "Node " << i << " count " << count_[i] << " total time "
+ << time_[i] << " avg time "
+ << (time_[i] / (std::max(1, count_[i])));
+ }
+}
+
+} // namespace tensorflow
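
To make the CopyTimeEstimate() model concrete, here is a small hedged check (assuming the headers above are available): copying 1 MB over an estimated 1 Gbps link with 1 ms of latency gives bytes_per_usec = 1 * 1000 / 8 = 125, a transfer term of 1e6 / 125 = 8000 us, and 9000 us in total.

```cpp
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void CopyEstimateExample() {
  // 1 MB, 1 ms latency, 1 Gbps: 1e6 / 125 bytes-per-usec + 1000 us = 9000 us.
  Microseconds t = CostModel::CopyTimeEstimate(
      Bytes(1000000), /*network_latency_millis=*/1.0, /*estimated_gbps=*/1.0);
  LOG(INFO) << t.value();  // Expected: 9000
}

}  // namespace tensorflow
```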
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
new file mode 100644
index 0000000000..4d7dd65f5a
--- /dev/null
+++ b/tensorflow/core/graph/costmodel.h
@@ -0,0 +1,123 @@
+#ifndef TENSORFLOW_GRAPH_COSTMODEL_H_
+#define TENSORFLOW_GRAPH_COSTMODEL_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+typedef std::unordered_map<string, int32> NodeNameToCostIdMap;
+
+class StepStats;
+
+// CostModel keeps track of the following runtime statistics for nodes
+// of a single Graph:
+// * The total number of times a node has executed.
+// * The accumulated execution time (in microseconds) of a node.
+// * The accumulated size (in bytes) of each node's output.
+//
+// This class is NOT thread-safe.
+class CostModel {
+ public:
+  // If "is_global" is true, maintains costs based on Node::cost_id, otherwise
+ // maintains costs based on Node::id.
+ explicit CostModel(bool is_global) : is_global_(is_global) {}
+
+ // Assigns min_count_ as a function of the median count for a Node.
+ // This value is then used for suppressing the time/size costs of
+ // infrequent operations.
+ // NOTE(tucker): Maybe this should move to a subclass of CostModel.
+ void SuppressInfrequent();
+
+ bool is_global() const { return is_global_; }
+
+ // Initializes cost model for 'g'.
+ void InitFromGraph(const Graph& g);
+
+ // Merges costs from cm.
+ // REQUIRES: is_global_ is true for this and for "cm"
+ void MergeFromGlobal(const CostModel& cm);
+
+ // Merges costs from "cm", which has been computed relative to "g".
+ // REQUIRES: is_global_ is true for this, and false for "cm".
+ void MergeFromLocal(const Graph& g, const CostModel& cm);
+
+ void MergeFromStats(const NodeNameToCostIdMap& map, const StepStats& ss);
+
+ // Sets the number of outputs of "node".
+ void SetNumOutputs(const Node* node, int num_outputs);
+
+ // Records that "node" has executed "num_count" more times.
+ void RecordCount(const Node* node, int num_count);
+
+ // Returns how many times "node" has been executed.
+ int32 TotalCount(const Node* node) const;
+
+ // Records that "output_slot" of "node" has produced tensors of
+ // aggregated "bytes".
+ void RecordSize(const Node* node, int output_slot, Bytes bytes);
+
+ // Returns total bytes of tensors produced by "node"s output slot.
+ Bytes TotalBytes(const Node* node, int output_slot) const;
+
+ // Returns a prediction for the size of the tensor at the
+ // output_slot produced by one execution of "node".
+ Bytes SizeEstimate(const Node* node, int output_slot) const;
+
+ // Records that Executions of "node" have taken "time" microseconds.
+ void RecordTime(const Node* node, Microseconds time);
+
+ // Returns the total execution time for "node".
+ Microseconds TotalTime(const Node* node) const;
+
+ // Returns a prediction for one execution of "node".
+ Microseconds TimeEstimate(const Node* node) const;
+
+ // Check that an estimate is available for every OP node in graph.
+ void CheckInitialized(const Graph& graph) const;
+
+  // Helper routines to encapsulate static estimation heuristics.
+
+ // Compute an estimate of the time to copy "b" bytes over the network,
+ // given a fixed cost of "network_latency_millis" milliseconds and
+ // an estimated bandwidth of "estimated_gbps" gigabits per second (note that
+ // this value is in gigabits, not gigabytes).
+ static Microseconds CopyTimeEstimate(Bytes b, double network_latency_millis,
+ double estimated_gbps);
+ static Microseconds ComputationTimeEstimate(int64 mathops);
+
+ // Write the contents of the CostModel to the INFO log.
+ void WriteToLog();
+
+ private:
+ const bool is_global_;
+ inline int Id(const Node* n) const {
+ if (is_global_) {
+ return n->cost_id();
+ } else {
+ return n->id();
+ }
+ }
+ // Resizes vectors so that they are large enough for "id".
+ void Ensure(int id);
+
+  // Nodes and Edges whose count is < this value
+  // get time/byte estimates of 0.
+ int32 min_count_ = 0;
+
+ // Number of times each Node has been executed.
+ std::vector<int32> count_;
+ // Cumulative execution time.
+ std::vector<Microseconds> time_;
+ // Cumulative Bytes output on each channel.
+ std::vector<gtl::InlinedVector<Bytes, 2> > slot_bytes_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CostModel);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_COSTMODEL_H_
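
A hedged sketch of the recording/estimation flow, assuming a TensorFlow build environment and a valid op node n; the numbers are placeholders.

```cpp
#include "tensorflow/core/graph/costmodel.h"

namespace tensorflow {

void RecordOneExecution(CostModel* cm, const Node* n) {
  cm->SetNumOutputs(n, 1);              // Must precede RecordCount/RecordSize.
  cm->RecordCount(n, 1);                // One more execution observed.
  cm->RecordTime(n, Microseconds(50));  // It took 50 microseconds.
  cm->RecordSize(n, 0, Bytes(4096));    // Output slot 0 produced 4 KB.

  // Per-execution predictions are roughly totals divided by the count.
  Microseconds t = cm->TimeEstimate(n);
  Bytes b = cm->SizeEstimate(n, 0);
  (void)t;
  (void)b;
}

}  // namespace tensorflow
```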
diff --git a/tensorflow/core/graph/costutil.cc b/tensorflow/core/graph/costutil.cc
new file mode 100644
index 0000000000..f8e2d9fe68
--- /dev/null
+++ b/tensorflow/core/graph/costutil.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/graph/costutil.h"
+
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/costmodel.h"
+
+namespace tensorflow {
+
+std::vector<int64> LongestOutgoingPathCost(const Graph& graph,
+ const CostModel& cm) {
+ std::vector<int64> result(graph.num_node_ids());
+ DFS(graph, nullptr, [&result, &cm](Node* n) {
+ int64 max_child = 0;
+ for (const Node* out : n->out_nodes()) {
+ max_child = std::max(max_child, result[out->id()]);
+ }
+ result[n->id()] = max_child + (n->IsOp() ? cm.TimeEstimate(n).value() : 0);
+ });
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/costutil.h b/tensorflow/core/graph/costutil.h
new file mode 100644
index 0000000000..46e5215132
--- /dev/null
+++ b/tensorflow/core/graph/costutil.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_GRAPH_COSTUTIL_H_
+#define TENSORFLOW_GRAPH_COSTUTIL_H_
+
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class CostModel;
+class Graph;
+
+// result[i] is an estimate of the longest execution path from
+// the node with id i to the sink node.
+std::vector<int64> LongestOutgoingPathCost(const Graph& graph,
+ const CostModel& cm);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_COSTUTIL_H_
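
LongestOutgoingPathCost() is essentially a dynamic program over the DFS leave order: each node's value is its own time estimate plus the maximum over its out-neighbors. A minimal usage sketch, assuming a TensorFlow build environment and an initialized CostModel, might use it to rank nodes by remaining critical-path cost:

```cpp
#include <vector>

#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/costutil.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Estimated longest execution path from n to the sink node.
int64 RemainingPathCost(const Graph& g, const CostModel& cm, const Node* n) {
  std::vector<int64> cost = LongestOutgoingPathCost(g, cm);
  return cost[n->id()];
}

}  // namespace tensorflow
```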
diff --git a/tensorflow/core/graph/default_device.h b/tensorflow/core/graph/default_device.h
new file mode 100644
index 0000000000..30cd4e8a57
--- /dev/null
+++ b/tensorflow/core/graph/default_device.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+#define TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace tensorflow {
+namespace graph {
+
+// Sets the default device for all nodes in graph_def to "device",
+// only if not already set.
+inline void SetDefaultDevice(const string& device, GraphDef* graph_def) {
+ for (int i = 0; i < graph_def->node_size(); ++i) {
+ auto node = graph_def->mutable_node(i);
+ if (node->device().empty()) {
+ node->set_device(device);
+ }
+ }
+}
+
+} // namespace graph
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
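
A short hedged example of the helper above: only nodes whose device field is empty pick up the default.

```cpp
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void DefaultDeviceExample() {
  GraphDef gdef;
  gdef.add_node()->set_name("a");  // Device left unset.
  NodeDef* b = gdef.add_node();
  b->set_name("b");
  b->set_device("/gpu:0");         // Explicitly assigned.

  graph::SetDefaultDevice("/cpu:0", &gdef);
  LOG(INFO) << gdef.node(0).device();  // "/cpu:0" (defaulted)
  LOG(INFO) << gdef.node(1).device();  // "/gpu:0" (unchanged)
}

}  // namespace tensorflow
```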
diff --git a/tensorflow/core/graph/dot.cc b/tensorflow/core/graph/dot.cc
new file mode 100644
index 0000000000..6d6e46ce61
--- /dev/null
+++ b/tensorflow/core/graph/dot.cc
@@ -0,0 +1,289 @@
+#include "tensorflow/core/graph/dot.h"
+
+#include <map>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/graph/colors.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+static string GraphNodeName(const DotOptions& opts, const Node* n) {
+ return strings::StrCat("N", n->id());
+}
+
+bool ShouldDisplayOpType(const Node* n) {
+ if (n->type_string() == "NoOp") {
+ return false;
+ }
+  const string& node_name = n->def().name();
+  if (node_name.find(n->type_string() + "_") == 0) {
+ return false;
+ }
+ return true;
+}
+
+string DotGraph(const Graph& g, const DotOptions& opts) {
+ RegexpStringPiece flag(opts.prefix_collapse_regexp);
+ if (flag == "all") {
+ flag = ".";
+ } else if (flag == "none") {
+ flag = "^$";
+ }
+ RE2 cluster_name_pattern(flag);
+ string result;
+ strings::StrAppend(&result, "digraph G {\n");
+ strings::StrAppend(&result, "rankdir=\"BT\"\n");
+
+ std::map<string, int> device_index; // Map from device name to index.
+ std::unordered_set<Node*> visible_nodes; // Nodes to display.
+ // Cluster name => set of nodes.
+ std::unordered_map<string, std::unordered_set<Node*> > clusters;
+ // Node* => Cluster
+ std::unordered_map<Node*, string> node_cluster;
+ for (Node* src : g.nodes()) {
+ if (opts.include_node_function != nullptr &&
+ !opts.include_node_function(src)) {
+ continue;
+ }
+ // Do not display source and sink nodes
+ if (src->IsSource() || src->IsSink()) {
+ continue;
+ }
+ visible_nodes.insert(src);
+ const string name_prefix = NodeNamePrefix(src->def().name()).ToString();
+ if (!name_prefix.empty()) {
+ clusters[name_prefix].insert(src);
+ node_cluster[src] = name_prefix;
+ }
+ // Record device if present.
+ if (src->IsOp()) {
+ const string& d = src->assigned_device_name();
+ if (!d.empty()) {
+ device_index[d] = -1; // Assigned later
+ }
+ }
+ }
+
+ // Add nodes whose name is exactly a cluster name to the cluster itself.
+ for (Node* src : g.nodes()) {
+ if (node_cluster.count(src) == 0) {
+ const string name = src->def().name();
+ auto it = clusters.find(name);
+ if (it != clusters.end()) {
+ it->second.insert(src);
+ node_cluster[src] = name;
+ }
+ }
+ }
+
+ auto node_in_collapsed_cluster = [&node_cluster,
+ &cluster_name_pattern](Node* n) {
+ return node_cluster.count(n) > 0 &&
+ RE2::PartialMatch(node_cluster[n], cluster_name_pattern);
+ };
+
+ // Assign device indices in sorted order.
+ int num = 0;
+ for (auto& e : device_index) {
+ e.second = num++;
+ }
+
+ double total_node_cost = 0;
+ double avg_node_cost = 1;
+ if (opts.node_cost) {
+ int node_count = 0;
+ for (const Node* n : g.nodes()) {
+ total_node_cost += opts.node_cost(n);
+ ++node_count;
+ }
+ if (total_node_cost > 0) avg_node_cost = total_node_cost / node_count;
+ }
+
+ for (Node* src : g.nodes()) {
+ if (visible_nodes.count(src) == 0 || node_in_collapsed_cluster(src)) {
+ continue;
+ }
+ string label = src->name();
+    if (ShouldDisplayOpType(src)) {
+      // Append the op type if it is not directly deducible from the node name.
+ strings::StrAppend(&label, "\\n(", src->type_string(), ")");
+ }
+ const char* shape = "box";
+ const char* color = nullptr;
+ if (src->IsSource()) {
+ shape = "oval";
+ } else if (src->IsSink()) {
+ shape = "oval";
+ } else {
+ const string& d = src->assigned_device_name();
+ const int dindex = (!d.empty()) ? device_index[d] : -1;
+ if (dindex >= 0) {
+ color = ColorFor(dindex);
+ }
+
+ shape = "box";
+ }
+
+ if (opts.node_label) {
+ string extra = opts.node_label(src);
+ if (!extra.empty()) {
+ strings::StrAppend(&label, "\\n", extra);
+ }
+ }
+
+ strings::StrAppend(&result, GraphNodeName(opts, src), "[shape=", shape,
+ ", label=\"", label, "\"");
+ if (opts.node_cost && total_node_cost > 0) {
+ // Pick fontsize in range [8..40] so that area is proportional to cost.
+ const double cost = opts.node_cost(src);
+ const double relcost = fabs(cost / avg_node_cost);
+ // Average cost node has font size of 12.
+ const int fs = 8 + static_cast<int>(4.0 * std::min(sqrt(relcost), 8.0));
+ strings::StrAppend(&result, ", width=0, height=0, fontsize=", fs);
+ VLOG(2) << "Node: " << cost << " => " << relcost << " => " << fs;
+ }
+ if (color != nullptr) {
+ strings::StrAppend(&result, ", fillcolor=\"", color,
+ "\", fontcolor=\"white\", style=\"filled\"");
+ }
+ strings::StrAppend(&result, "]\n");
+ }
+
+  for (const auto& c : clusters) {
+    const string& cluster_name = c.first;
+    const std::unordered_set<Node*>& nodes = c.second;
+ std::unordered_map<string, int> node_colors;
+ for (auto n : nodes) {
+ const string& d = n->assigned_device_name();
+ const int dindex = (!d.empty()) ? device_index[d] : -1;
+ if (dindex >= 0) {
+ ++node_colors[ColorFor(dindex)];
+ }
+ }
+
+ string majority_color;
+ if (node_colors.empty()) {
+ majority_color = ColorFor(0);
+ } else {
+ majority_color = std::max_element(node_colors.begin(), node_colors.end(),
+ [](const std::pair<string, int>& x,
+ const std::pair<string, int>& y) {
+ return x.second < y.second;
+ })
+ ->first;
+ }
+
+ if (!RE2::PartialMatch(cluster_name, cluster_name_pattern)) {
+ strings::StrAppend(&result, "subgraph cluster_", cluster_name, "{\n");
+ for (auto n : nodes) {
+ strings::StrAppend(&result, GraphNodeName(opts, n), ";\n");
+ }
+ strings::StrAppend(&result, "}\n");
+ } else {
+ strings::StrAppend(&result, cluster_name, " [shape=oval, fillcolor=\"",
+ majority_color, "\", label=\"", cluster_name,
+ "\", style=\"filled\", fontcolor=\"white\"]\n");
+ }
+ }
+
+ std::unordered_set<string> edge_drawn;
+
+ double max_edge_cost = 0;
+ double total_edge_cost = 0;
+ double avg_edge_cost = 1;
+ if (opts.edge_cost && g.edges().size()) {
+ for (const Edge* e : g.edges()) {
+ auto cost = opts.edge_cost(e);
+ total_edge_cost += cost;
+ max_edge_cost = std::max(max_edge_cost, cost);
+ }
+ avg_edge_cost = total_edge_cost / g.edges().size();
+ }
+ VLOG(2) << "Edge cost tot/max/avg: " << total_edge_cost << "/"
+ << max_edge_cost << "/" << avg_edge_cost;
+
+ for (const Edge* e : g.edges()) {
+ Node* src = e->src();
+ Node* dst = e->dst();
+ // If either endpoint isn't drawn in the graph, don't draw the edge
+ if (visible_nodes.count(src) == 0 || visible_nodes.count(dst) == 0) {
+ continue;
+ }
+
+ const string src_name = node_in_collapsed_cluster(src)
+ ? node_cluster[src]
+ : GraphNodeName(opts, src);
+ const string dst_name = node_in_collapsed_cluster(dst)
+ ? node_cluster[dst]
+ : GraphNodeName(opts, dst);
+ // Don't draw self edges
+ if (src_name == dst_name) {
+ continue;
+ }
+ // And previously drawn edges.
+    const string edge_name = strings::StrCat(src_name, ":", dst_name);
+ if (edge_drawn.count(edge_name) > 0) {
+ continue;
+ }
+ edge_drawn.insert(edge_name);
+
+ strings::StrAppend(&result, src_name, " -> ", dst_name, "[");
+ if (e->IsControlEdge()) {
+ strings::StrAppend(&result, " style=dotted");
+ }
+ if (opts.edge_label) {
+ string label = opts.edge_label(e);
+ if (!label.empty()) {
+ strings::StrAppend(&result, " label=<", label, ">");
+ }
+ }
+ // Make edge widths proportional to amount of data transferred.
+ if (opts.edge_cost && max_edge_cost > 0) {
+ const double cost = opts.edge_cost(e);
+ const double relcost = fabs(cost / avg_edge_cost);
+ // Pick penwidth in range [1..6] so that width is proportional to cost.
+ const int pw = 1 + std::min(5, static_cast<int>(2.0 * relcost));
+ strings::StrAppend(&result, " penwidth=", pw);
+ // Use weight attributes [1..100] to keep heavier edges more vertical.
+ const int weight = 1 + std::min(99, static_cast<int>(100.0 * relcost));
+ strings::StrAppend(&result, " weight=", weight);
+ VLOG(2) << "Edge: " << cost << " => " << relcost << " => " << pw << "/"
+ << weight;
+ }
+
+ strings::StrAppend(&result, "]\n");
+ }
+ // Compute some statistics
+ int op_nodes = 0;
+ for (Node* n : g.nodes()) {
+ if (n->IsOp()) {
+ op_nodes++;
+ }
+ }
+
+ // Emit legend
+ strings::StrAppend(&result,
+ "{ rank = source; Legend [shape=box, margin=0, label=<",
+ "<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\" ",
+ "CELLPADDING=\"4\">", "<TR><TD COLSPAN=\"2\">op_nodes: ",
+ op_nodes, "</TD></TR>\n");
+ for (const auto& e : device_index) {
+ const int dindex = e.second;
+ strings::StrAppend(&result, "<TR><TD BGCOLOR=\"", ColorFor(dindex),
+ "\"><FONT COLOR=\"white\">", dindex, "</FONT></TD><TD>",
+ e.first, "</TD></TR>\n");
+ }
+ strings::StrAppend(&result, "</TABLE>>]}\n");
+
+ strings::StrAppend(&result, "}\n"); // End digraph
+ return result;
+}
+
+} // namespace tensorflow
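
The font-size scaling in DotGraph() is easy to sanity-check in isolation: an average-cost node (relcost = 1) gets fontsize 12, and because sqrt(relcost) is capped at 8 the result always lands in [8..40]. A standalone check of that arithmetic:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  for (double relcost : {0.0, 1.0, 4.0, 100.0}) {
    int fs = 8 + static_cast<int>(4.0 * std::min(std::sqrt(relcost), 8.0));
    std::printf("relcost=%g -> fontsize=%d\n", relcost, fs);
  }
  // Prints 8, 12, 16, and 40 (the cap), respectively.
}
```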
diff --git a/tensorflow/core/graph/dot.h b/tensorflow/core/graph/dot.h
new file mode 100644
index 0000000000..f87f68099c
--- /dev/null
+++ b/tensorflow/core/graph/dot.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_GRAPH_DOT_H_
+#define TENSORFLOW_GRAPH_DOT_H_
+
+#include <functional>
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class Edge;
+class Graph;
+class Node;
+
+struct DotOptions {
+ bool (*include_node_function)(const Node*) = nullptr;
+
+ // By default, all nodes with the same name prefix are collapsed into
+ // a single node in the dot graph. This regexp can be changed so that
+ // only prefixes that match the regexp are collapsed in this fashion.
+ // 'all' collapses all ops with prefixes, 'none' disables all collapsing.
+ string prefix_collapse_regexp = "all";
+
+ // A function that returns a label to embed into the per-node display.
+ std::function<string(const Node*)> node_label;
+
+ // A function that returns a label to attach to an edge.
+ std::function<string(const Edge*)> edge_label;
+
+ // A function that returns the "cost" of the node. The dot display
+  // makes a node's size proportional to its cost.
+ std::function<double(const Node*)> node_cost;
+
+ // A function that returns the "cost" of the edge. The dot display
+  // makes an edge's thickness proportional to its cost.
+ std::function<double(const Edge*)> edge_cost;
+};
+
+// Return a string that contains a graphviz specification of the graph.
+string DotGraph(const Graph& g, const DotOptions& opts);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_DOT_H_
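
A hedged usage sketch of DotOptions, assuming a TensorFlow build environment and a constructed Graph; the lambdas are placeholders for real label/cost functions.

```cpp
#include "tensorflow/core/graph/dot.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

string GraphToDot(const Graph& g) {
  DotOptions opts;
  opts.prefix_collapse_regexp = "none";  // Draw every node individually.
  opts.node_label = [](const Node* n) { return n->type_string(); };
  opts.edge_cost = [](const Edge* e) { return 1.0; };  // Uniform edge widths.
  return DotGraph(g, opts);  // Render with: dot -Tpdf graph.dot -o graph.pdf
}

}  // namespace tensorflow
```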
diff --git a/tensorflow/core/graph/edgeset.cc b/tensorflow/core/graph/edgeset.cc
new file mode 100644
index 0000000000..83293c7b4e
--- /dev/null
+++ b/tensorflow/core/graph/edgeset.cc
@@ -0,0 +1,56 @@
+#include "tensorflow/core/graph/edgeset.h"
+
+namespace tensorflow {
+
+std::pair<EdgeSet::const_iterator, bool> EdgeSet::insert(value_type value) {
+ RegisterMutation();
+ const_iterator ci;
+ ci.Init(this);
+ auto s = get_set();
+ if (!s) {
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i] == value) {
+ ci.array_iter_ = &ptrs_[i];
+ return std::make_pair(ci, false);
+ }
+ }
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i] == nullptr) {
+ ptrs_[i] = value;
+ ci.array_iter_ = &ptrs_[i];
+ return std::make_pair(ci, true);
+ }
+ }
+ // array is full. convert to set.
+ s = new std::set<const Edge*>;
+ for (int i = 0; i < kInline; i++) {
+ s->insert(static_cast<const Edge*>(ptrs_[i]));
+ }
+ ptrs_[0] = this;
+ ptrs_[1] = s;
+ // fall through.
+ }
+ auto p = s->insert(value);
+ ci.tree_iter_ = p.first;
+ return std::make_pair(ci, p.second);
+}
+
+EdgeSet::size_type EdgeSet::erase(key_type key) {
+ RegisterMutation();
+ auto s = get_set();
+ if (!s) {
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i] == key) {
+ size_t n = size();
+ ptrs_[i] = ptrs_[n - 1];
+ ptrs_[n - 1] = nullptr;
+ return 1;
+ }
+ }
+ return 0;
+ } else {
+ return s->erase(key);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/edgeset.h b/tensorflow/core/graph/edgeset.h
new file mode 100644
index 0000000000..df0d78b8fb
--- /dev/null
+++ b/tensorflow/core/graph/edgeset.h
@@ -0,0 +1,216 @@
+#ifndef TENSORFLOW_GRAPH_EDGESET_H_
+#define TENSORFLOW_GRAPH_EDGESET_H_
+
+#include <stddef.h>
+#include <set>
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+class Edge;
+
+// An unordered set of edges. Uses very little memory for small sets.
+// Unlike std::set, EdgeSet does NOT allow mutations during iteration.
+class EdgeSet {
+ public:
+ EdgeSet();
+ ~EdgeSet();
+
+ typedef const Edge* key_type;
+ typedef const Edge* value_type;
+ typedef size_t size_type;
+ typedef ptrdiff_t difference_type;
+
+ class const_iterator;
+ typedef const_iterator iterator;
+
+ bool empty() const;
+ size_type size() const;
+ void clear();
+ std::pair<iterator, bool> insert(value_type value);
+ size_type erase(key_type key);
+
+ // Caller is not allowed to mutate the EdgeSet while iterating.
+ const_iterator begin() const;
+ const_iterator end() const;
+
+ private:
+ // Up to kInline elements are stored directly in ptrs_ (nullptr means none).
+ // If ptrs_[0] == this then ptrs_[1] points to a set<const Edge*>.
+ static const int kInline = 2; // Must be >= 2.
+ const void* ptrs_[kInline];
+
+ std::set<const Edge*>* get_set() const {
+ if (ptrs_[0] == this) {
+ return static_cast<std::set<const Edge*>*>(const_cast<void*>(ptrs_[1]));
+ } else {
+ return nullptr;
+ }
+ }
+
+// To detect mutations while iterating.
+#ifdef NDEBUG
+ void RegisterMutation() {}
+#else
+ uint32 mutations_ = 0;
+ void RegisterMutation() { mutations_++; }
+#endif
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EdgeSet);
+};
+
+class EdgeSet::const_iterator {
+ public:
+ typedef typename EdgeSet::value_type value_type;
+ typedef const typename EdgeSet::value_type& reference;
+ typedef const typename EdgeSet::value_type* pointer;
+ typedef typename EdgeSet::difference_type difference_type;
+ typedef std::forward_iterator_tag iterator_category;
+
+ const_iterator() {}
+
+ const_iterator& operator++();
+ const_iterator operator++(int /*unused*/);
+ const value_type* operator->() const;
+ value_type operator*() const;
+ bool operator==(const const_iterator& other) const;
+ bool operator!=(const const_iterator& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ friend class EdgeSet;
+
+ void const* const* array_iter_ = nullptr;
+ typename std::set<const Edge*>::const_iterator tree_iter_;
+
+#ifdef NDEBUG
+ inline void Init(const EdgeSet* e) {}
+ inline void CheckNoMutations() const {}
+#else
+ inline void Init(const EdgeSet* e) {
+ owner_ = e;
+ init_mutations_ = e->mutations_;
+ }
+ inline void CheckNoMutations() const {
+ CHECK_EQ(init_mutations_, owner_->mutations_);
+ }
+ const EdgeSet* owner_ = nullptr;
+ uint32 init_mutations_ = 0;
+#endif
+};
+
+inline EdgeSet::EdgeSet() {
+ for (int i = 0; i < kInline; i++) {
+ ptrs_[i] = nullptr;
+ }
+}
+
+inline EdgeSet::~EdgeSet() { delete get_set(); }
+
+inline bool EdgeSet::empty() const { return size() == 0; }
+
+inline EdgeSet::size_type EdgeSet::size() const {
+ auto s = get_set();
+ if (s) {
+ return s->size();
+ } else {
+ size_t result = 0;
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i]) result++;
+ }
+ return result;
+ }
+}
+
+inline void EdgeSet::clear() {
+ RegisterMutation();
+ delete get_set();
+ for (int i = 0; i < kInline; i++) {
+ ptrs_[i] = nullptr;
+ }
+}
+
+inline EdgeSet::const_iterator EdgeSet::begin() const {
+ const_iterator ci;
+ ci.Init(this);
+ auto s = get_set();
+ if (s) {
+ ci.tree_iter_ = s->begin();
+ } else {
+ ci.array_iter_ = &ptrs_[0];
+ }
+ return ci;
+}
+
+inline EdgeSet::const_iterator EdgeSet::end() const {
+ const_iterator ci;
+ ci.Init(this);
+ auto s = get_set();
+ if (s) {
+ ci.tree_iter_ = s->end();
+ } else {
+ ci.array_iter_ = &ptrs_[size()];
+ }
+ return ci;
+}
+
+inline EdgeSet::const_iterator& EdgeSet::const_iterator::operator++() {
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ ++array_iter_;
+ } else {
+ ++tree_iter_;
+ }
+ return *this;
+}
+
+inline EdgeSet::const_iterator EdgeSet::const_iterator::operator++(
+ int /*unused*/) {
+ CheckNoMutations();
+ const_iterator tmp = *this;
+ operator++();
+ return tmp;
+}
+
+// gcc's set and multiset always use const_iterator, since a mutable
+// iterator would otherwise allow modification of keys.
+inline const EdgeSet::const_iterator::value_type* EdgeSet::const_iterator::
+operator->() const {
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ return reinterpret_cast<const value_type*>(array_iter_);
+ } else {
+ return tree_iter_.operator->();
+ }
+}
+
+// gcc's set and multiset always use const_iterator, since a mutable
+// iterator would otherwise allow modification of keys.
+inline EdgeSet::const_iterator::value_type EdgeSet::const_iterator::operator*()
+ const {
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ return static_cast<value_type>(*array_iter_);
+ } else {
+ return *tree_iter_;
+ }
+}
+
+inline bool EdgeSet::const_iterator::operator==(
+ const const_iterator& other) const {
+ DCHECK((array_iter_ == nullptr) == (other.array_iter_ == nullptr))
+ << "Iterators being compared must be from same set that has not "
+ << "been modified since the iterator was constructed";
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ return array_iter_ == other.array_iter_;
+ } else {
+ return other.array_iter_ == nullptr && tree_iter_ == other.tree_iter_;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_EDGESET_H_
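
The representation trick above (store up to kInline pointers inline; tag the overflow case by writing `this` into slot 0 and a heap std::set pointer into slot 1) also works outside this class. A standalone illustration of the idea, not the TensorFlow implementation:

```cpp
#include <cstddef>
#include <cstdio>
#include <set>

// Small-set with inline storage: two raw slots hold elements directly; when
// they fill up, slot 0 is overwritten with `this` as a tag and slot 1 points
// to a heap-allocated std::set holding everything.
class SmallPtrSet {
 public:
  SmallPtrSet() { ptrs_[0] = ptrs_[1] = nullptr; }
  ~SmallPtrSet() { delete get_set(); }

  void insert(const int* value) {
    std::set<const int*>* s = get_set();
    if (s == nullptr) {
      for (int i = 0; i < 2; i++) {
        if (ptrs_[i] == value) return;  // Already present inline.
      }
      for (int i = 0; i < 2; i++) {
        if (ptrs_[i] == nullptr) {  // Free inline slot.
          ptrs_[i] = value;
          return;
        }
      }
      // Inline array full: spill to a real set.
      s = new std::set<const int*>;
      s->insert(static_cast<const int*>(ptrs_[0]));
      s->insert(static_cast<const int*>(ptrs_[1]));
      ptrs_[0] = this;  // Tags the "big" representation.
      ptrs_[1] = s;
    }
    s->insert(value);
  }

  size_t size() const {
    const std::set<const int*>* s = get_set();
    if (s) return s->size();
    return (ptrs_[0] ? 1 : 0) + (ptrs_[1] ? 1 : 0);
  }

 private:
  std::set<const int*>* get_set() const {
    if (ptrs_[0] != this) return nullptr;
    return static_cast<std::set<const int*>*>(const_cast<void*>(ptrs_[1]));
  }
  const void* ptrs_[2];  // Inline elements, or {tag, set*} after spilling.
};

int main() {
  SmallPtrSet s;
  int a, b, c;
  s.insert(&a);
  s.insert(&b);  // Still inline.
  s.insert(&c);  // Spills to std::set.
  std::printf("size = %zu\n", s.size());  // size = 3
}
```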
diff --git a/tensorflow/core/graph/edgeset_test.cc b/tensorflow/core/graph/edgeset_test.cc
new file mode 100644
index 0000000000..7909e8ea0a
--- /dev/null
+++ b/tensorflow/core/graph/edgeset_test.cc
@@ -0,0 +1,95 @@
+#include "tensorflow/core/graph/edgeset.h"
+
+#include "tensorflow/core/graph/graph.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+class EdgeSetTest : public ::testing::Test {
+ public:
+ EdgeSetTest() : edges_(nullptr), eset_(nullptr) {}
+
+ ~EdgeSetTest() override {
+ delete eset_;
+ delete[] edges_;
+ }
+
+ void MakeEdgeSet(int n) {
+ delete eset_;
+ delete[] edges_;
+ edges_ = new Edge[n];
+ eset_ = new EdgeSet;
+ model_.clear();
+ for (int i = 0; i < n; i++) {
+ eset_->insert(&edges_[i]);
+ model_.insert(&edges_[i]);
+ }
+ }
+
+ void CheckSame() {
+ EXPECT_EQ(model_.size(), eset_->size());
+ EXPECT_EQ(model_.empty(), eset_->empty());
+ std::vector<const Edge*> modelv(model_.begin(), model_.end());
+ std::vector<const Edge*> esetv(eset_->begin(), eset_->end());
+ std::sort(modelv.begin(), modelv.end());
+ std::sort(esetv.begin(), esetv.end());
+ EXPECT_EQ(modelv.size(), esetv.size());
+ for (size_t i = 0; i < modelv.size(); i++) {
+ EXPECT_EQ(modelv[i], esetv[i]) << i;
+ }
+ }
+
+ Edge nonexistent_;
+ Edge* edges_;
+ EdgeSet* eset_;
+ std::set<const Edge*> model_;
+};
+
+namespace {
+
+TEST_F(EdgeSetTest, Ops) {
+ for (int n : {0, 1, 2, 3, 4, 10}) {
+ MakeEdgeSet(n);
+ CheckSame();
+ EXPECT_EQ((n == 0), eset_->empty());
+ EXPECT_EQ(n, eset_->size());
+
+ eset_->clear();
+ model_.clear();
+ CheckSame();
+
+ eset_->insert(&edges_[0]);
+ model_.insert(&edges_[0]);
+ CheckSame();
+ }
+}
+
+// Try insert/erase of existing elements at different positions.
+TEST_F(EdgeSetTest, Exists) {
+ for (int n : {0, 1, 2, 3, 4, 10}) {
+ MakeEdgeSet(n);
+ for (int pos = 0; pos < n; pos++) {
+ MakeEdgeSet(n);
+ auto p = eset_->insert(&edges_[pos]);
+ EXPECT_FALSE(p.second);
+ EXPECT_EQ(&edges_[pos], *p.first);
+
+ EXPECT_EQ(1, eset_->erase(&edges_[pos]));
+ model_.erase(&edges_[pos]);
+ CheckSame();
+ }
+ }
+}
+
+// Try insert/erase of non-existent element.
+TEST_F(EdgeSetTest, DoesNotExist) {
+ for (int n : {0, 1, 2, 3, 4, 10}) {
+ MakeEdgeSet(n);
+ EXPECT_EQ(0, eset_->erase(&nonexistent_));
+ auto p = eset_->insert(&nonexistent_);
+ EXPECT_TRUE(p.second);
+ EXPECT_EQ(&nonexistent_, *p.first);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/equal_graph_def.cc b/tensorflow/core/graph/equal_graph_def.cc
new file mode 100644
index 0000000000..35f59b5ed0
--- /dev/null
+++ b/tensorflow/core/graph/equal_graph_def.cc
@@ -0,0 +1,176 @@
+#include "tensorflow/core/graph/equal_graph_def.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
+ string* diff) {
+ std::unordered_map<string, const NodeDef*> actual_index;
+ for (const NodeDef& node : actual.node()) {
+ actual_index[node.name()] = &node;
+ }
+
+ for (const NodeDef& expected_node : expected.node()) {
+ auto actual_iter = actual_index.find(expected_node.name());
+ if (actual_iter == actual_index.end()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Did not find expected node '",
+ SummarizeNodeDef(expected_node), "'");
+ }
+ return false;
+ }
+
+ if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false;
+
+ actual_index.erase(actual_iter);
+ }
+
+ if (!actual_index.empty()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Found unexpected node '",
+ SummarizeNodeDef(*actual_index.begin()->second),
+ "' not in expected graph:\n",
+ SummarizeGraphDef(expected));
+ }
+ return false;
+ }
+
+ return true;
+}
+
+namespace {
+
+string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
+ string ret;
+ for (int i = 0; i < f.size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, f.Get(i));
+ }
+ return ret;
+}
+
+} // namespace
+
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
+ string* diff) {
+ if (actual.name() != expected.name()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Actual node name '", actual.name(),
+ "' is not expected '", expected.name(), "'");
+ }
+ return false;
+ }
+
+ if (actual.op() != expected.op()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
+ actual.op(), "' that is not expected '",
+ expected.op(), "'");
+ }
+ return false;
+ }
+
+ if (actual.device() != expected.device()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
+ actual.device(), "' that is not expected '",
+ expected.device(), "'");
+ }
+ return false;
+ }
+
+ if (actual.input_size() != expected.input_size()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
+ JoinStringField(actual.input()),
+ "' that don't match expected '",
+ JoinStringField(expected.input()), "'");
+ }
+ return false;
+ }
+
+ int first_control_input = actual.input_size();
+ for (int i = 0; i < actual.input_size(); ++i) {
+ if (StringPiece(actual.input(i)).starts_with("^")) {
+ first_control_input = i;
+ break;
+ }
+ if (actual.input(i) != expected.input(i)) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
+ i, " '", actual.input(i),
+ "' that doesn't match expected '",
+ expected.input(i), "'");
+ }
+ return false;
+ }
+ }
+
+ std::unordered_set<string> actual_control;
+ std::unordered_set<string> expected_control;
+ for (int i = first_control_input; i < actual.input_size(); ++i) {
+ actual_control.insert(actual.input(i));
+ expected_control.insert(expected.input(i));
+ }
+ for (const auto& e : expected_control) {
+ if (actual_control.erase(e) == 0) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(),
+ "' missing expected control input '", e, "'");
+ }
+ return false;
+ }
+ }
+ if (!actual_control.empty()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(),
+ "' has unexpected control input '",
+ *actual_control.begin(), "'");
+ }
+ return false;
+ }
+
+ std::unordered_set<string> actual_attr;
+ for (const auto& a : actual.attr()) {
+ actual_attr.insert(a.first);
+ }
+ for (const auto& e : expected.attr()) {
+ if (actual_attr.erase(e.first) == 0) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(),
+ "' missing expected attr '", e.first,
+ "' with value: ", SummarizeAttrValue(e.second));
+ }
+ return false;
+ }
+ auto iter = actual.attr().find(e.first);
+ if (!AreAttrValuesEqual(e.second, iter->second)) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat(
+ "Node named '", actual.name(), "' has attr '", e.first,
+ "' with value: ", SummarizeAttrValue(iter->second),
+ " that does not match expected: ", SummarizeAttrValue(e.second));
+ }
+ return false;
+ }
+ }
+ if (!actual_attr.empty()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat(
+ "Node named '", actual.name(), "' has unexpected attr '",
+ *actual_attr.begin(), "' with value: ",
+ SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
+ }
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/equal_graph_def.h b/tensorflow/core/graph/equal_graph_def.h
new file mode 100644
index 0000000000..7dd8aab340
--- /dev/null
+++ b/tensorflow/core/graph/equal_graph_def.h
@@ -0,0 +1,32 @@
+#ifndef TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_
+#define TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Determines if actual and expected are equal, ignoring ordering of
+// nodes, attrs, and control inputs. If the GraphDefs are different
+// and diff != nullptr, *diff is set to an explanation of the
+// difference. Note that we use node names to match up nodes between
+// the graphs, and so the naming of nodes must be consistent.
+bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
+ string* diff);
+
+// Determines if actual and expected are equal, ignoring ordering of
+// attrs and control inputs. If the NodeDefs are different and
+// diff != nullptr, *diff is set to an explanation of the difference.
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff);
+
+#define TF_EXPECT_GRAPH_EQ(expected, actual) \
+ do { \
+ string diff; \
+ EXPECT_TRUE(EqualGraphDef(actual, expected, &diff)) \
+ << diff << "\nActual: " << SummarizeGraphDef(actual); \
+ } while (false)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_
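
A minimal sketch of the macro in use, assuming a gtest environment: node order does not affect equality, since matching is by node name.

```cpp
#include "tensorflow/core/graph/equal_graph_def.h"
#include <gtest/gtest.h>

namespace tensorflow {
namespace {

TEST(EqualGraphDefExample, IgnoresNodeOrder) {
  GraphDef a;
  a.add_node()->set_name("x");
  a.add_node()->set_name("y");

  GraphDef b;
  b.add_node()->set_name("y");
  b.add_node()->set_name("x");

  TF_EXPECT_GRAPH_EQ(a, b);  // Passes: nodes are matched up by name.
}

}  // namespace
}  // namespace tensorflow
```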
diff --git a/tensorflow/core/graph/equal_graph_def_test.cc b/tensorflow/core/graph/equal_graph_def_test.cc
new file mode 100644
index 0000000000..3a38b9e522
--- /dev/null
+++ b/tensorflow/core/graph/equal_graph_def_test.cc
@@ -0,0 +1,279 @@
+#include "tensorflow/core/graph/equal_graph_def.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("Input").Output("o: float");
+REGISTER_OP("Alternate").Output("o: float");
+REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float");
+
+Node* Input(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("Input", opts);
+}
+
+Node* Alternate(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("Alternate", opts);
+}
+
+Node* Cross(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("Cross", a, b, opts);
+}
+
+class EqualGraphDefTest : public ::testing::Test {
+ protected:
+ EqualGraphDefTest()
+ : e_(GraphDefBuilder::kFailImmediately),
+ a_(GraphDefBuilder::kFailImmediately) {
+ RequireDefaultOps();
+ }
+
+ bool Match() {
+ GraphDef expected;
+ e_.ToGraphDef(&expected);
+ GraphDef actual;
+ a_.ToGraphDef(&actual);
+ return EqualGraphDef(actual, expected, &diff_);
+ }
+
+ GraphDefBuilder e_;
+ GraphDefBuilder a_;
+ string diff_;
+};
+
+TEST_F(EqualGraphDefTest, Match) {
+ Input(e_.opts().WithName("A"));
+ Input(a_.opts().WithName("A"));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, NoMatch) {
+ Input(e_.opts().WithName("A"));
+ Input(a_.opts().WithName("B"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Did not find expected node 'A = Input[]()'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, MissingNode) {
+ Input(e_.opts().WithName("A"));
+ Input(e_.opts().WithName("B"));
+ Input(a_.opts().WithName("A"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Did not find expected node 'B = Input[]()'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, ExtraNode) {
+ Input(e_.opts().WithName("A"));
+ Input(a_.opts().WithName("A"));
+ Input(a_.opts().WithName("B"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ(
+ "Found unexpected node 'B = Input[]()' not in expected graph:\n"
+ "A = Input[]();\n",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, NodeOrder) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Cross(a, b, e_.opts().WithName("C"));
+
+ b = Input(a_.opts().WithName("B"));
+ a = Input(a_.opts().WithName("A"));
+ Cross(a, b, a_.opts().WithName("C"));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, NameMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ // Have to call EqualNodeDef() directly here, since EqualGraphDef()
+ // only calls EqualNodeDef() with nodes that have matching names.
+ EXPECT_FALSE(EqualNodeDef(a->def(), b->def(), &diff_));
+ EXPECT_EQ("Actual node name 'A' is not expected 'B'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, OpMismatch) {
+ Input(e_.opts().WithName("A"));
+ Alternate(a_.opts().WithName("A"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'A' has op 'Alternate' that is not expected 'Input'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, DeviceMatch) {
+ Input(e_.opts().WithName("A").WithDevice("/cpu:0"));
+ Input(a_.opts().WithName("A").WithDevice("/cpu:0"));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, DeviceMismatch) {
+ Input(e_.opts().WithName("A").WithDevice("/cpu:0"));
+ Input(a_.opts().WithName("A").WithDevice("/cpu:1"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'A' has device '/cpu:1' that is not expected '/cpu:0'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, InputMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Cross(a, a, e_.opts().WithName("C"));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ Cross(b, b, a_.opts().WithName("C"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, InputOrderMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Cross(a, b, e_.opts().WithName("C"));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ Cross(b, a, a_.opts().WithName("C"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, ControlInputOrder) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Node* d = Input(e_.opts().WithName("D"));
+ Cross(a, a, e_.opts()
+ .WithName("E")
+ .WithControlInput(b)
+ .WithControlInput(c)
+ .WithControlInput(d));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ d = Input(a_.opts().WithName("D"));
+ Cross(a, a, a_.opts()
+ .WithName("E")
+ .WithControlInput(c)
+ .WithControlInput(d)
+ .WithControlInput(b));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, ControlInputMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Node* d = Input(e_.opts().WithName("D"));
+ Cross(a, a, e_.opts().WithName("E").WithControlInput(b).WithControlInput(c));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ d = Input(a_.opts().WithName("D"));
+ Cross(a, a, a_.opts().WithName("E").WithControlInput(b).WithControlInput(d));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'E' missing expected control input '^C'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, ControlInputAdded) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Cross(a, a, e_.opts().WithName("D").WithControlInput(b));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ Cross(a, a, a_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ(
+ "Node named 'D' has inputs 'A, A, ^B, ^C' that don't match "
+ "expected 'A, A, ^B'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, ControlInputRemoved) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Cross(a, a, e_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ Cross(a, a, a_.opts().WithName("D").WithControlInput(b));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ(
+ "Node named 'D' has inputs 'A, A, ^B' that don't match "
+ "expected 'A, A, ^B, ^C'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, Attr) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef same(a->def());
+ AddNodeAttr("foo", "bar", &same);
+ EXPECT_TRUE(EqualNodeDef(same, same, &diff_)) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, AttrAdded) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef actual(a->def());
+ AddNodeAttr("foo", "bar", &actual);
+ EXPECT_FALSE(EqualNodeDef(actual, a->def(), &diff_));
+ EXPECT_EQ("Node named 'A' has unexpected attr 'foo' with value: \"bar\"",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, AttrRemoved) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef expected(a->def());
+ AddNodeAttr("foo", "bar", &expected);
+ EXPECT_FALSE(EqualNodeDef(a->def(), expected, &diff_));
+ EXPECT_EQ("Node named 'A' missing expected attr 'foo' with value: \"bar\"",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, AttrOrder) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef actual(a->def());
+ AddNodeAttr("foo", "bar", &actual);
+ AddNodeAttr("baz", 42, &actual);
+
+ NodeDef expected(a->def());
+ AddNodeAttr("baz", 42, &expected);
+ AddNodeAttr("foo", "bar", &expected);
+
+ EXPECT_TRUE(EqualNodeDef(actual, expected, &diff_)) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, AttrMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef actual(a->def());
+ AddNodeAttr("foo", "bar", &actual);
+ AddNodeAttr("baz", 5, &actual);
+
+ NodeDef expected(a->def());
+ AddNodeAttr("baz", 42, &expected);
+ AddNodeAttr("foo", "bar", &expected);
+
+ EXPECT_FALSE(EqualNodeDef(actual, expected, &diff_));
+ EXPECT_EQ(
+ "Node named 'A' has attr 'baz' with value: 5 that does not match "
+ "expected: 42",
+ diff_);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
new file mode 100644
index 0000000000..0c268a51a9
--- /dev/null
+++ b/tensorflow/core/graph/graph.cc
@@ -0,0 +1,319 @@
+#include "tensorflow/core/graph/graph.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Node
+
+string Node::DebugString() const {
+  // Note: must not be called on a null Node*; a "this == nullptr" check
+  // here would be undefined behavior.
+ string ret = strings::StrCat("{name:'", name(), "' id:", id_);
+ if (IsSource()) {
+ strings::StrAppend(&ret, " source}");
+ } else if (IsSink()) {
+ strings::StrAppend(&ret, " sink}");
+ } else {
+ strings::StrAppend(&ret, " op device:");
+ strings::StrAppend(&ret, "{", assigned_device_name_, "}");
+ strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}");
+ }
+ return ret;
+}
+
+Node::Node()
+ : id_(-1), cost_id_(-1), props_(nullptr), assigned_device_name_() {}
+
+Node::~Node() {
+ if (props_) {
+ props_->Unref();
+ }
+}
+
+void Node::Initialize(int id, int cost_id, Properties* props) {
+ DCHECK_EQ(id_, -1);
+ DCHECK(in_edges_.empty());
+ DCHECK(out_edges_.empty());
+ id_ = id;
+ cost_id_ = cost_id;
+
+ // Unref the old, assign the new properties.
+ if (props_) {
+ props_->Unref();
+ }
+ props_ = props;
+}
+
+void Node::Clear() {
+ in_edges_.clear();
+ out_edges_.clear();
+ id_ = -1;
+ cost_id_ = -1;
+
+ if (props_) {
+ props_->Unref();
+ props_ = nullptr;
+ }
+
+ assigned_device_name_.clear();
+}
+
+gtl::iterator_range<NeighborIter> Node::out_nodes() const {
+ return gtl::make_range(NeighborIter(out_edges_.begin(), false),
+ NeighborIter(out_edges_.end(), false));
+}
+
+gtl::iterator_range<NeighborIter> Node::in_nodes() const {
+ return gtl::make_range(NeighborIter(in_edges_.begin(), true),
+ NeighborIter(in_edges_.end(), true));
+}
+
+// Node::Properties
+
+Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def,
+ const DataTypeSlice inputs,
+ const DataTypeSlice outputs)
+ : op_def_(op_def),
+ node_def_(node_def),
+ input_types_(inputs.begin(), inputs.end()),
+ output_types_(outputs.begin(), outputs.end()) {}
+
+Node::Properties::~Properties() {}
+
+// Graph
+
+Graph::Graph(const OpRegistryInterface* ops)
+ : ops_(ops), arena_(8 << 10 /* 8kB */) {
+ // Source and sink have no endpoints, just control edges.
+ NodeDef def;
+ def.set_name("_SOURCE");
+ def.set_op("NoOp");
+ Status status;
+ Node* source = AddNode(def, &status);
+ TF_CHECK_OK(status);
+ CHECK_EQ(source->id(), kSourceId);
+
+ def.set_name("_SINK");
+ Node* sink = AddNode(def, &status);
+ TF_CHECK_OK(status);
+ CHECK_EQ(sink->id(), kSinkId);
+
+ AddControlEdge(source, sink);
+}
+
+Graph::~Graph() {
+ // Manually call the destructors for all the Nodes we constructed using
+ // placement new.
+ for (Node* node : nodes_) {
+ if (node != nullptr) {
+ node->~Node();
+ }
+ }
+ for (Node* node : free_nodes_) {
+ node->~Node();
+ }
+ // Edges have no destructor, and we arena-allocated them, so no need to
+ // destroy them.
+}
+
+Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
+ const OpDef* op_def = ops_->LookUp(node_def.op(), status);
+ if (op_def == nullptr) return nullptr;
+
+ // TODO(vrv,josh11b): Find a location higher in the stack to add these defaults
+ // to the NodeDef.
+ NodeDef node_def_with_defaults(node_def);
+ AddDefaultsToNodeDef(*op_def, &node_def_with_defaults);
+
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ status->Update(
+ InOutTypesForNode(node_def_with_defaults, *op_def, &inputs, &outputs));
+ if (!status->ok()) {
+ *status = AttachDef(*status, node_def_with_defaults);
+ return nullptr;
+ }
+
+ Node* node = AllocateNode(
+ new Node::Properties(op_def, node_def_with_defaults, inputs, outputs),
+ nullptr);
+ return node;
+}
+
+Node* Graph::CopyNode(Node* node) {
+ DCHECK(!node->IsSource());
+ DCHECK(!node->IsSink());
+ Node::Properties* props = node->properties();
+ props->Ref();
+ Node* copy = AllocateNode(props, node);
+ copy->set_assigned_device_name(node->assigned_device_name());
+ return copy;
+}
+
+void Graph::RemoveNode(Node* node) {
+ DCHECK(IsValidNode(node)) << node->DebugString();
+ DCHECK(!node->IsSource());
+ DCHECK(!node->IsSink());
+
+ // Remove any edges involving this node.
+ while (!node->in_edges_.empty()) {
+ RemoveEdge(*node->in_edges_.begin());
+ }
+ while (!node->out_edges_.empty()) {
+ RemoveEdge(*node->out_edges_.begin());
+ }
+ ReleaseNode(node);
+}
+
+const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
+ DCHECK(IsValidNode(source)) << source->DebugString();
+ DCHECK(IsValidNode(dest)) << dest->DebugString();
+
+ // source/sink must only be linked via control slots, and
+ // control slots must only be linked to control slots.
+ if (source == source_node() || dest == sink_node() || x == kControlSlot ||
+ y == kControlSlot) {
+ DCHECK_EQ(x, kControlSlot) << source->DebugString();
+ DCHECK_EQ(y, kControlSlot) << dest->DebugString();
+ }
+
+ Edge* e = nullptr;
+ if (free_edges_.empty()) {
+ e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new
+ } else {
+ e = free_edges_.back();
+ free_edges_.pop_back();
+ }
+ e->id_ = edges_.size();
+ e->src_ = source;
+ e->dst_ = dest;
+ e->src_output_ = x;
+ e->dst_input_ = y;
+ CHECK(source->out_edges_.insert(e).second);
+ CHECK(dest->in_edges_.insert(e).second);
+ edges_.push_back(e);
+ edge_set_.insert(e);
+ return e;
+}
+
+void Graph::RemoveEdge(const Edge* e) {
+ DCHECK(IsValidNode(e->src_)) << e->src_->DebugString();
+ DCHECK(IsValidNode(e->dst_)) << e->dst_->DebugString();
+ CHECK_EQ(e->src_->out_edges_.erase(e), 1);
+ CHECK_EQ(e->dst_->in_edges_.erase(e), 1);
+ CHECK_EQ(e, edges_[e->id_]);
+
+ CHECK_EQ(edge_set_.erase(e), 1);
+ edges_[e->id_] = nullptr;
+
+ Edge* del = const_cast<Edge*>(e);
+ del->src_ = nullptr;
+ del->dst_ = nullptr;
+ del->id_ = -1;
+ del->src_output_ = kControlSlot - 1;
+ del->dst_input_ = kControlSlot - 1;
+ free_edges_.push_back(del);
+}
+
+namespace {
+
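+// Formats a NodeDef input reference: a control edge becomes "^name",
+// output 0 is plain "name", and output k > 0 becomes "name:k".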
+void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
+ if (src_slot == Graph::kControlSlot) {
+ dst->add_input(strings::StrCat("^", src_name));
+ } else if (src_slot == 0) {
+ dst->add_input(src_name.data(), src_name.size());
+ } else {
+ dst->add_input(strings::StrCat(src_name, ":", src_slot));
+ }
+}
+
+} // namespace
+
+void Graph::ToGraphDef(GraphDef* graph_def) const {
+ graph_def->Clear();
+ std::vector<const Edge*>
+ inputs; // Construct this outside the loop for speed.
+ for (const Node* node : nodes()) {
+ if (!node->IsOp()) continue;
+ NodeDef* node_def = graph_def->add_node();
+ *node_def = node->def();
+
+ // Use the node's assigned device, if any, instead of the device requested
+ // in the NodeDef.
+ if (!node->assigned_device_name().empty()) {
+ node_def->set_device(node->assigned_device_name());
+ }
+
+ // Get the inputs for this Node. We make sure control inputs are
+ // after data inputs, as required by GraphDef.
+ inputs.clear();
+ inputs.resize(node->num_inputs(), nullptr);
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ inputs.push_back(edge);
+ } else {
+ DCHECK(inputs[edge->dst_input()] == nullptr);
+ inputs[edge->dst_input()] = edge;
+ }
+ }
+ node_def->clear_input();
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ const Edge* edge = inputs[i];
+ if (edge == nullptr) {
+ node_def->add_input(node->def().input(i));
+ } else {
+ const Node* src = edge->src();
+ if (!src->IsOp()) continue;
+ AddInput(node_def, src->name(), edge->src_output());
+ }
+ }
+ }
+}
+
+string Graph::NewName(StringPiece prefix) {
+ return strings::StrCat(prefix, "/_", name_counter_++);
+}
+
+gtl::iterator_range<NodeIter> Graph::nodes() const {
+ // Note that NodeId 0 is always valid since we don't let the source
+ // node be removed from the graph.
+ return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
+}
+
+bool Graph::IsValidNode(Node* node) const {
+ if (node == nullptr) return false;
+ const int id = node->id();
+ if (id < 0 || static_cast<size_t>(id) >= nodes_.size()) return false;
+ return nodes_[id] == node;
+}
+
+Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) {
+ Node* node = nullptr;
+ if (free_nodes_.empty()) {
+ node = new (arena_.Alloc(sizeof(Node))) Node; // placement new
+ } else {
+ node = free_nodes_.back();
+ free_nodes_.pop_back();
+ }
+ const int id = nodes_.size();
+ int cost_id = cost_node ? cost_node->cost_id() : id;
+ node->Initialize(id, cost_id, props);
+ nodes_.push_back(node);
+ return node;
+}
+
+void Graph::ReleaseNode(Node* node) {
+ DCHECK(IsValidNode(node)) << node->DebugString();
+ nodes_[node->id()] = nullptr;
+ free_nodes_.push_back(node);
+ node->Clear();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
new file mode 100644
index 0000000000..030e471bf4
--- /dev/null
+++ b/tensorflow/core/graph/graph.h
@@ -0,0 +1,440 @@
+// A Graph describes a set of computations that are to be
+// performed, as well as the dependencies between those
+// computations. The basic model is a DAG (directed acyclic graph) with
+// * internal nodes representing computational operations to be performed;
+// * edges representing dependencies, indicating that the target may only be
+// executed once the source has completed; and
+// * predefined "source" (start) and "sink" (finish) nodes -- the source
+// should be the only node that doesn't depend on anything, and the sink
+// should be the only node that nothing depends on.
+//
+// Note: Node ids are intended to be relatively dense in the
+// 0..max_id range, but there may be gaps since ids won't be reused.
+//
+// Note: Some dependencies between operations are due to one operation
+// consuming the output of another. In fact operations can produce
+// multiple outputs and consume multiple inputs, and some
+// optimizations will care about which specific outputs are connected
+// to which specific inputs. We therefore represent the data dependency
+// between output O of node A and input I of node B using
+// "input index" and "output index" labels per edge.
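+//
+// For example, an edge carrying output 0 of node A into input 1 of node B
+// stores src_output() == 0 and dst_input() == 1, while a control
+// dependency uses the special slot kControlSlot on both endpoints.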
+
+#ifndef TENSORFLOW_GRAPH_GRAPH_H_
+#define TENSORFLOW_GRAPH_GRAPH_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/edgeset.h"
+#include "tensorflow/core/lib/core/arena.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/iterator_range.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Edge;
+class EdgeSetTest;
+class Graph;
+class Node;
+
+class NeighborIter; // Declared below
+class NodeIter; // Declared below
+
+class Node {
+ public:
+ string DebugString() const;
+ int id() const { return id_; }
+ int cost_id() const { return cost_id_; }
+ const string& name() const { return props_->node_def_.name(); }
+ const string& type_string() const { return props_->node_def_.op(); }
+ const NodeDef& def() const { return props_->node_def_; }
+ const OpDef& op_def() const { return *props_->op_def_; }
+
+ // input and output types
+ int num_inputs() const { return props_->input_types_.size(); }
+ DataType input_type(int i) const { return props_->input_types_[i]; }
+ const DataTypeVector& input_types() const { return props_->input_types_; }
+
+ int num_outputs() const { return props_->output_types_.size(); }
+ DataType output_type(int o) const { return props_->output_types_[o]; }
+ const DataTypeVector& output_types() const { return props_->output_types_; }
+
+ // This gives the device the runtime has assigned this node to. If
+ // you want the device the user requested, use def().device() instead.
+ // TODO(josh11b): Validate that the assigned_device, if not empty:
+ // fully specifies a device, and satisfies def().device().
+ // TODO(josh11b): Move device_name outside of Node into a NodeId->DeviceName
+ // map.
+ string assigned_device_name() const { return assigned_device_name_; }
+ void set_assigned_device_name(const string& device_name) {
+ assigned_device_name_ = device_name;
+ }
+
+ // Get the neighboring nodes via edges either in or out of this node.
+ gtl::iterator_range<NeighborIter> in_nodes() const;
+ gtl::iterator_range<NeighborIter> out_nodes() const;
+ const EdgeSet& in_edges() const { return in_edges_; }
+ const EdgeSet& out_edges() const { return out_edges_; }
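+  // E.g. to visit each node immediately downstream of n, over data or
+  // control edges:
+  //   for (Node* m : n->out_nodes()) { ... }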
+
+ // Node type helpers.
+ bool IsSource() const { return id() == 0; }
+ bool IsSink() const { return id() == 1; }
+ // Anything other than the special Source & Sink nodes.
+ bool IsOp() const { return id() > 1; }
+
+ private:
+ friend class Graph;
+ Node();
+ ~Node();
+
+ class Properties : public core::RefCounted {
+ public:
+ Properties(const OpDef* op_def, const NodeDef& node_def,
+ const DataTypeSlice inputs, const DataTypeSlice outputs);
+
+ const OpDef* op_def_; // not owned
+ const NodeDef node_def_;
+ const DataTypeVector input_types_;
+ const DataTypeVector output_types_;
+
+ private:
+ // Destructor invoked when last reference goes away via Unref()
+ virtual ~Properties();
+ TF_DISALLOW_COPY_AND_ASSIGN(Properties);
+ };
+
+ Properties* properties() const { return props_; }
+
+ // Initialize() adopts a reference to props, and so is suitable if props was
+ // just allocated or you call props->Ref() to increment the reference
+ // count for a props being held by another Node.
+ void Initialize(int id, int cost_id, Properties* props);
+ // Releases memory from props_, in addition to restoring *this to its
+ // uninitialized state.
+ void Clear();
+
+ int id_; // -1 until Initialize() is called
+ int cost_id_; // -1 if there is no corresponding cost accounting node
+
+ EdgeSet in_edges_;
+ EdgeSet out_edges_;
+
+ Properties* props_;
+
+ // Name of device assigned to perform this computation.
+ string assigned_device_name_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Node);
+};
+
+class Edge {
+ public:
+ Node* src() const { return src_; }
+ Node* dst() const { return dst_; }
+ int id() const { return id_; }
+
+ // Return the number of the source output that produces the data
+ // carried by this edge. The special value kControlSlot is used
+ // for control dependencies.
+ int src_output() const { return src_output_; }
+
+ // Return the number of the destination input that consumes the data
+ // carried by this edge. The special value kControlSlot is used
+ // for control dependencies.
+ int dst_input() const { return dst_input_; }
+
+ // Return true iff this is an edge that indicates a control-flow
+ // (as opposed to a data-flow) dependency.
+ bool IsControlEdge() const;
+
+ private:
+ Edge() {}
+
+ friend class EdgeSetTest;
+ friend class Graph;
+ Node* src_;
+ Node* dst_;
+ int id_;
+ int src_output_;
+ int dst_input_;
+};
+
+// Thread compatible but not thread safe.
+class Graph {
+ public:
+ // Constructs a graph with a single SOURCE (always id kSourceId) and a
+ // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
+ //
+ // The graph can hold ops found in registry.
+ explicit Graph(const OpRegistryInterface* registry);
+ ~Graph();
+
+ static const int kControlSlot = -1;
+
+ // Adds a new node to this graph, and returns it. Infers the Op and
+ // input/output types for the node. *this owns the returned instance.
+ // Returns nullptr and sets *status on error.
+ Node* AddNode(const NodeDef& node_def, Status* status);
+
+ // Copies *node, which may belong to another graph, to a new node,
+ // which is returned. Does not copy any edges. *this owns the
+ // returned instance.
+ Node* CopyNode(Node* node);
+
+ // Remove a node from this graph, including all edges from or to it.
+ // *node should not be accessed after calling this function.
+ // REQUIRES: node->IsOp()
+ void RemoveNode(Node* node);
+
+ // Add an edge that connects the xth output of "source" to the yth input
+ // of "dest".
+ const Edge* AddEdge(Node* source, int x, Node* dest, int y);
+
+ // Add a control-edge (no data flows along this edge) that
+ // connects "source" to "dest".
+ const Edge* AddControlEdge(Node* source, Node* dest) {
+ return AddEdge(source, kControlSlot, dest, kControlSlot);
+ }
+
+ // Removes edge from the graph.
+ // REQUIRES: The edge must exist.
+ void RemoveEdge(const Edge* edge);
+
+ // Returns one more than the maximum id assigned to any node.
+ int num_node_ids() const { return nodes_.size(); }
+
+ // Serialize to a GraphDef.
+ void ToGraphDef(GraphDef* graph_def) const;
+
+ // Generate new node name with the specified prefix that is unique
+ // across this graph.
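+  // E.g. on a fresh graph, successive calls to NewName("foo") return
+  // "foo/_0", "foo/_1", and so on.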
+ string NewName(StringPiece prefix);
+
+ // Access to the list of all nodes. Example usage:
+ // for (Node* node : graph.nodes()) { ... }
+ gtl::iterator_range<NodeIter> nodes() const;
+
+ // Returns the node associated with an id, or nullptr if no node
+ // with that id (the node with that id was removed and the id has
+ // not yet been re-used). *this owns the returned instance.
+ // REQUIRES: 0 <= id < num_node_ids().
+ Node* FindNodeId(int id) const { return nodes_[id]; }
+
+ // Returns one more than the maximum id assigned to any edge.
+ int num_edge_ids() const { return edges_.size(); }
+
+  // Returns the Edge associated with an id, or nullptr if no edge
+  // with that id (the edge with that id was removed and the id has
+  // not yet been re-used). *this owns the returned instance.
+  // REQUIRES: 0 <= id < num_edge_ids().
+ const Edge* FindEdgeId(int id) const { return edges_[id]; }
+
+ // Access to the set of all edges. Example usage:
+ // for (const Edge* e : graph.edges()) { ... }
+ const EdgeSet& edges() const { return edge_set_; }
+
+ // The pre-defined nodes.
+ enum { kSourceId = 0, kSinkId = 1 };
+ Node* source_node() const { return FindNodeId(kSourceId); }
+ Node* sink_node() const { return FindNodeId(kSinkId); }
+
+ const OpRegistryInterface* op_registry() const { return ops_; }
+
+ // TODO(josh11b): uint64 hash() const;
+
+ private:
+ bool IsValidNode(Node* node) const;
+ // If cost_node is non-null, then cost accounting (in CostModel)
+ // will be associated with that node rather than the new one being
+ // created.
+ Node* AllocateNode(Node::Properties* props, const Node* cost_node);
+ void ReleaseNode(Node* node);
+
+ // Registry of all known ops. Not owned.
+ const OpRegistryInterface* const ops_;
+
+ // Allocator which will give us good locality.
+ core::Arena arena_;
+
+ // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
+ // the node with that id was removed from the graph.
+ std::vector<Node*> nodes_;
+
+ // Map from edge ids to allocated edges. edges_[id] may be nullptr if
+ // the edge with that id was removed from the graph.
+ std::vector<Edge*> edges_;
+
+ // For ease of iteration, we currently just keep a set of all live
+ // edges. May want to optimize by removing this copy.
+ EdgeSet edge_set_;
+
+ // Allocated but free nodes and edges.
+ std::vector<Node*> free_nodes_;
+ std::vector<Edge*> free_edges_;
+
+ // For generating unique names.
+ int name_counter_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Graph);
+};
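+
+// Example usage (a minimal sketch; "my_def" stands in for a NodeDef of an
+// op present in the registry):
+//   Graph g(OpRegistry::Global());
+//   Status s;
+//   Node* n = g.AddNode(my_def, &s);
+//   if (s.ok()) {
+//     g.AddControlEdge(g.source_node(), n);
+//   }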
+
+// TODO(josh11b): We may want to support keeping an index on various
+// node/edge attributes in a graph, particularly node names.
+
+// Helper routines
+
+inline bool IsSwitch(const Node* node) {
+ return node->type_string() == "Switch" || node->type_string() == "RefSwitch";
+}
+
+inline bool IsMerge(const Node* node) { return node->type_string() == "Merge"; }
+
+inline bool IsEnter(const Node* node) {
+ return node->type_string() == "Enter" || node->type_string() == "RefEnter";
+}
+
+inline bool IsExit(const Node* node) { return node->type_string() == "Exit"; }
+
+inline bool IsNextIteration(const Node* node) {
+ return node->type_string() == "NextIteration";
+}
+
+inline bool IsLoopCond(const Node* node) {
+ return node->type_string() == "LoopCond";
+}
+
+inline bool IsControlTrigger(const Node* node) {
+ return node->type_string() == "ControlTrigger";
+}
+
+inline bool IsSend(const Node* node) {
+ return node->type_string() == "_Send" || node->type_string() == "_HostSend";
+}
+
+inline bool IsRecv(const Node* node) {
+ return node->type_string() == "_Recv" || node->type_string() == "_HostRecv";
+}
+
+// True for Nodes that mediate the transfer of values between processes.
+inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
+
+inline bool IsConstant(const Node* node) {
+ return node->type_string() == "Const" || node->type_string() == "HostConst";
+}
+
+inline bool IsVariable(const Node* node) {
+ return node->type_string() == "Variable";
+}
+
+inline bool IsIdentity(const Node* node) {
+ return (node->type_string() == "Identity" ||
+ node->type_string() == "RefIdentity");
+}
+
+// Returns true iff 'n' is a control flow node.
+inline bool IsControlFlow(const Node* n) {
+ return IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n) ||
+ IsNextIteration(n);
+}
+
+inline bool IsHostMemoryPreserving(const Node* node) {
+ return IsIdentity(node) || IsControlFlow(node);
+}
+
+// Iterator for stepping through the nodes of a graph.
+class NodeIter {
+ public:
+ NodeIter(const Graph* graph, int id);
+ bool operator==(const NodeIter& rhs);
+ bool operator!=(const NodeIter& rhs);
+ void operator++();
+ Node* operator*();
+ Node* operator->();
+
+ private:
+  // Invariant:
+  //   id_ == graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr
+ const Graph* graph_;
+ int id_;
+};
+
+// Iterator for stepping through the neighbors of a node.
+class NeighborIter {
+ public:
+ NeighborIter(EdgeSet::const_iterator iter, bool incoming);
+ bool operator==(const NeighborIter& rhs);
+ bool operator!=(const NeighborIter& rhs);
+ void operator++();
+ Node* operator*();
+ Node* operator->();
+
+ private:
+ EdgeSet::const_iterator iter_;
+ bool incoming_;
+};
+
+// IMPLEMENTATION DETAILS, PLEASE IGNORE
+
+inline NodeIter::NodeIter(const Graph* graph, int id)
+ : graph_(graph), id_(id) {}
+
+inline bool NodeIter::operator==(const NodeIter& rhs) {
+ DCHECK(graph_ == rhs.graph_);
+ return id_ == rhs.id_;
+}
+
+inline bool NodeIter::operator!=(const NodeIter& rhs) {
+ return !(*this == rhs);
+}
+
+inline void NodeIter::operator++() {
+  while (true) {
+ DCHECK_LE(id_, graph_->num_node_ids());
+ ++id_;
+ if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
+ return;
+ }
+ }
+}
+
+inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); }
+
+inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); }
+
+inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
+ : iter_(iter), incoming_(incoming) {}
+
+inline bool NeighborIter::operator==(const NeighborIter& rhs) {
+ return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
+}
+
+inline bool NeighborIter::operator!=(const NeighborIter& rhs) {
+ return !(*this == rhs);
+}
+
+inline void NeighborIter::operator++() { ++iter_; }
+
+inline Node* NeighborIter::operator*() {
+ const Edge* e = *iter_;
+ return incoming_ ? e->src() : e->dst();
+}
+
+inline Node* NeighborIter::operator->() {
+ const Edge* e = *iter_;
+ return incoming_ ? e->src() : e->dst();
+}
+
+inline bool Edge::IsControlEdge() const {
+ // Note that if either src_output_ or dst_input_ is kControlSlot,
+ // so is the other one (AddEdge checks this).
+ return src_output_ == Graph::kControlSlot;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_H_
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
new file mode 100644
index 0000000000..3928348f0a
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -0,0 +1,385 @@
+#include "tensorflow/core/graph/graph_constructor.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+namespace {
+inline bool IsMerge(const NodeDef& node_def) {
+ return node_def.op() == "Merge";
+}
+} // namespace
+
+namespace {
+
+class GraphConstructor {
+ public:
+ GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef,
+ Graph* g, Status* status)
+ : opts_(opts), gdef_(gdef), g_(g), status_(status) {
+ BuildNodeIndex();
+ InitFromEdges();
+ Convert();
+ }
+
+ private:
+ void SetError(const string& error);
+ void SetNodeError(const NodeDef& node_def, const StringPiece& message) {
+ SetError(strings::StrCat("Node '", node_def.name(), "': ", message));
+ }
+ void BuildNodeIndex();
+ void InitFromEdges();
+ Node* MakeNode(const NodeDef& node_def);
+ void Convert();
+ // Calls SetError() and returns false if the type of the output of
+ // the source of the edge can't be consumed by destination of the edge.
+ // REQUIRES: edge must be a data edge, not a control edge.
+ bool TypeValidateEdge(const Edge* edge);
+
+ // From constructor
+ const GraphConstructorOptions opts_;
+ const GraphDef* gdef_;
+ Graph* g_;
+ Status* status_;
+
+ // Mapping from node name to the index within gdef_
+ struct NodeInfo {
+ explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
+ // std::unordered_map<> requires that we have a default constructor.
+ NodeInfo() : NodeInfo(-1) {}
+ int gdef_index;
+ Node* node; // nullptr until the NodeDef is converted to a Node.
+ };
+ // TODO(vrv): Profile this data structure to see if we should use an
+ // alternative implementation of std::unordered_map.
+ std::unordered_map<StringPiece, NodeInfo, StringPiece::Hasher> name_index_;
+
+ // Index of NodeDefs in gdef_ with all inputs already converted.
+ std::vector<int> ready_;
+
+ // Mapping between index within gdef_ and the number of inputs that
+ // still need to be converted.
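+  // E.g. a node with inputs "a", "b:1", "^c" starts with a count of 3,
+  // while a Merge node starts with one data input plus all of its
+  // control inputs (see InitFromEdges()).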
+ std::vector<int> pending_count_;
+
+ // Mapping between index within gdef_ and the index within gdef_ of
+ // all nodes it outputs to.
+ std::vector<gtl::InlinedVector<int, 4>> outputs_;
+
+ // Used in the conversion from gdef_ to g_ to represent the ith input
+ // of a node.
+ struct InputInfo {
+ explicit InputInfo(StringPiece node_name, Node* n, int i)
+ : name(node_name), node(n), index(i) {}
+ StringPiece name;
+ Node* node;
+ int index;
+ };
+
+ // Used in the conversion from gdef_ to g_ to represent an edge from
+ // the node named 'name' to node 'n'.
+ struct EdgeInfo {
+ explicit EdgeInfo(StringPiece name, int i1, Node* n, int i2)
+ : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
+ StringPiece src_name;
+ int src_index;
+ Node* dst_node;
+ int dst_index;
+ };
+};
+
+void GraphConstructor::SetError(const string& error) {
+ status_->Update(errors::InvalidArgument(error));
+}
+
+void GraphConstructor::BuildNodeIndex() {
+ // Initialized outside the loop for efficiency
+ const char* pattern;
+ if (opts_.allow_internal_ops) {
+ pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*";
+ } else {
+ pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*";
+ }
+ RE2 node_name_re(pattern);
+
+ // Validate the node names and add them to name_index_.
+ for (int n = 0; n < gdef_->node_size(); ++n) {
+ const NodeDef& node_def(gdef_->node(n));
+ if (!RE2::FullMatch(node_def.name(), node_name_re)) {
+ SetNodeError(node_def, "Node name contains invalid characters");
+ return;
+ }
+ if (!name_index_.insert(std::make_pair(StringPiece(node_def.name()),
+ NodeInfo(n)))
+ .second) {
+ SetNodeError(node_def, "Node name is not unique");
+ return;
+ }
+ // Validate the operation's type.
+ if (node_def.op().empty()) {
+ SetNodeError(node_def, "Does not specify a type");
+ return;
+ }
+ if (opts_.expect_device_spec && node_def.device().empty()) {
+      SetNodeError(node_def, "Missing device specification.");
+ return;
+ }
+ }
+}
+
+void GraphConstructor::InitFromEdges() {
+ const int num_nodes = gdef_->node_size();
+ ready_.reserve(num_nodes);
+ pending_count_.reserve(num_nodes);
+ outputs_.resize(num_nodes);
+
+ // Parse the inputs for each node.
+ for (int n = 0; n < num_nodes; ++n) {
+ const NodeDef& node_def(gdef_->node(n));
+ if (IsMerge(node_def)) {
+      // For Merge nodes, wait for only one data input plus all control inputs.
+ int32 num_control_edges = 0;
+ for (int i = 0; i < node_def.input_size(); ++i) {
+ StringPiece input_name(node_def.input(i));
+        if (input_name.starts_with("^")) {
+ num_control_edges++;
+ }
+ }
+ pending_count_.push_back(num_control_edges + 1);
+ } else {
+ pending_count_.push_back(node_def.input_size());
+ }
+ if (node_def.input_size() == 0) {
+ ready_.push_back(n);
+ continue;
+ }
+ for (int i = 0; i < node_def.input_size(); ++i) {
+ StringPiece input_name = node_def.input(i);
+ if (input_name.starts_with("^")) {
+ // Control dependence
+ input_name.remove_prefix(1);
+ }
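+      // E.g. "input" parses to ("input", 0) and "input:1" to ("input", 1).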
+ TensorId id(ParseTensorName(input_name));
+ auto iter = name_index_.find(id.first);
+ if (iter == name_index_.end()) {
+ SetNodeError(node_def,
+ strings::StrCat("Unknown input node ", node_def.input(i)));
+ return;
+ }
+ outputs_[iter->second.gdef_index].push_back(n);
+ }
+ }
+}
+
+Node* GraphConstructor::MakeNode(const NodeDef& node_def) {
+ // Add the node to the graph.
+ Node* node = g_->AddNode(node_def, status_);
+ if (node == nullptr) return nullptr;
+ if (opts_.expect_device_spec) {
+ node->set_assigned_device_name(node_def.device());
+ }
+ name_index_[node_def.name()].node = node;
+ return node;
+}
+
+// Return the number of nodes in "g"
+static int CountNodes(Graph* g) {
+ int nodes = 0;
+ for (Node* node : g->nodes()) {
+ VLOG(1) << node; // Dummy use to avoid compiler warning
+ nodes++;
+ }
+ return nodes;
+}
+
+void GraphConstructor::Convert() {
+ std::vector<InputInfo> inputs;
+ std::vector<EdgeInfo> back_edges;
+ int processed = 0;
+ // Process the NodeDefs in topological order.
+ while (!ready_.empty()) {
+ int o = ready_.back();
+ ready_.pop_back();
+ ++processed;
+ const NodeDef& node_def(gdef_->node(o));
+ inputs.clear();
+ bool in_control_dependence = false;
+ bool has_data_back_edge = false;
+ for (int i = 0; i < node_def.input_size(); ++i) {
+ StringPiece input_name(node_def.input(i));
+      if (input_name.starts_with("^")) {
+ // A control dependence
+ in_control_dependence = true;
+ input_name.remove_prefix(1);
+ } else {
+ if (in_control_dependence) {
+ SetNodeError(node_def, strings::StrCat(
+ "Control dependencies must come after ",
+ "regular dependencies: input ", input_name,
+ " of source node ", node_def.name()));
+ return;
+ }
+ }
+ TensorId id(ParseTensorName(input_name));
+ auto iter = name_index_.find(id.first);
+ DCHECK(iter != name_index_.end());
+ Node* src_node = iter->second.node;
+ if (in_control_dependence) {
+ inputs.push_back(InputInfo(id.first, src_node, -1));
+ } else {
+ if (src_node == nullptr) {
+ has_data_back_edge = true;
+ inputs.push_back(InputInfo(id.first, src_node, id.second));
+ } else {
+ if (id.second >= src_node->num_outputs()) {
+ SetNodeError(
+ node_def,
+ strings::StrCat("Connecting to invalid output ", id.second,
+ " of source node ", id.first, " which has ",
+ src_node->num_outputs(), " outputs"));
+ return;
+ }
+ inputs.push_back(InputInfo(id.first, src_node, id.second));
+ }
+ }
+ }
+ if (has_data_back_edge && !IsMerge(node_def)) {
+    SetError(strings::StrCat(
+        node_def.name(),
+        " has a back edge, but only Merge nodes can have back edges."));
+ return;
+ }
+
+ Node* node = MakeNode(node_def);
+ if (node == nullptr) return;
+
+ // Add edges from inputs to *node to the graph.
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ if (inputs[i].node == nullptr) {
+ // Record this back edge, which will be added after all nodes
+ // are created.
+ back_edges.push_back(
+ EdgeInfo(inputs[i].name, inputs[i].index, node, i));
+ } else if (inputs[i].index == -1) {
+ g_->AddControlEdge(inputs[i].node, node);
+ } else {
+ const Edge* edge =
+ g_->AddEdge(inputs[i].node, inputs[i].index, node, i);
+ if (!TypeValidateEdge(edge)) return;
+ }
+ }
+
+ // Update pending_count_ for outputs.
+ for (size_t i = 0; i < outputs_[o].size(); ++i) {
+ const int output = outputs_[o][i];
+ pending_count_[output]--;
+ if (pending_count_[output] == 0) {
+ ready_.push_back(output);
+ }
+ }
+ }
+
+ // Add the back edges after all nodes are created.
+ for (auto e : back_edges) {
+ Node* src_node = name_index_[e.src_name].node;
+ if (e.src_index == -1) {
+ g_->AddControlEdge(src_node, e.dst_node);
+ } else {
+ const Edge* edge =
+ g_->AddEdge(src_node, e.src_index, e.dst_node, e.dst_index);
+ if (!TypeValidateEdge(edge)) return;
+ }
+
+ VLOG(2) << "Add back edge: " << src_node->name() << " -> "
+ << e.dst_node->name();
+ }
+
+ if (processed < gdef_->node_size()) {
+ SetError(
+ strings::StrCat(gdef_->node_size() - processed, " nodes in a cycle"));
+ return;
+ }
+
+ if (status_->ok()) {
+ FixupSourceAndSinkEdges(g_);
+
+ if (opts_.optimizer_do_cse) {
+ if (!back_edges.empty()) {
+ LOG(WARNING) << "Not doing CSE. We need to figure out how to handle "
+ << "loops in the CSE phase.";
+ } else {
+ VLOG(1) << "Starting CSE: graph of " << CountNodes(g_) << " nodes";
+ OptimizeCSE(g_, opts_.cse_consider_function);
+ VLOG(1) << "Finished CSE: graph of " << CountNodes(g_) << " nodes";
+ }
+ }
+ }
+}
+
+bool GraphConstructor::TypeValidateEdge(const Edge* edge) {
+ DataType src_out = edge->src()->output_type(edge->src_output());
+ DataType dst_in = edge->dst()->input_type(edge->dst_input());
+ if (!TypesCompatible(dst_in, src_out)) {
+ SetError(strings::StrCat(
+ "Input ", edge->dst_input(), " of node ", edge->dst()->name(),
+ " was passed ", DataTypeString(src_out), " from ", edge->src()->name(),
+ ":", edge->src_output(), " incompatible with expected ",
+ DataTypeString(dst_in), "."));
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+// ----------------------------------------------------------------------------
+// ConvertGraphDefToGraph
+// ----------------------------------------------------------------------------
+
+Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
+ const GraphDef& gdef, Graph* g) {
+ Status status;
+ GraphConstructor constructor(opts, &gdef, g, &status);
+ return status;
+}
+
+// ----------------------------------------------------------------------------
+// CopyGraph
+// ----------------------------------------------------------------------------
+void CopyGraph(const Graph& src, Graph* dest) {
+ for (Node* n : dest->nodes()) {
+ CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
+ }
+
+ // Copy the nodes
+ std::unordered_map<Node*, Node*>
+ node_map; // "Node in src" -> "Node in *dest"
+ node_map[src.source_node()] = dest->source_node();
+ node_map[src.sink_node()] = dest->sink_node();
+ for (Node* n : src.nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n] = dest->CopyNode(n);
+ }
+
+ // Copy the edges
+ for (const Edge* e : src.edges()) {
+ Node* src_copy = node_map[e->src()];
+ Node* dst_copy = node_map[e->dst()];
+ dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
new file mode 100644
index 0000000000..cd1615ef6b
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+#define TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Options for ConvertGraphDefToGraph() (declared below), which constructs
+// a graph *g out of a GraphDef gdef. It returns non-OK on error, in which
+// case *g is left in an incomplete state.
+struct GraphConstructorOptions {
+ // If true, allows internal ops in the GraphDef.
+ bool allow_internal_ops = false;
+
+ // If true, the graph def is expected to have fully specified
+ // devices for all nodes. A node in the resulting graph "g" has the
+ // device name set accordingly.
+ //
+ // TODO(zhifengc): if possible, consider removing this option.
+ bool expect_device_spec = false;
+
+ // If true, perform common subexpression elimination on the graph.
+ // TODO(jeff): Turn this default to true?
+ bool optimizer_do_cse = false;
+
+ // If "optimizer_do_cse" is true and "cse_consider_function" is
+ // not nullptr, then only consider nodes for CSE for which
+ // "cse_consider_function(node)" returns true.
+ std::function<bool(const Node*)> cse_consider_function = nullptr;
+};
+extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
+ const GraphDef& gdef, Graph* g);
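+
+// Example (a minimal sketch; "gdef" stands in for a GraphDef obtained
+// elsewhere):
+//   Graph g(OpRegistry::Global());
+//   GraphConstructorOptions opts;
+//   Status s = ConvertGraphDefToGraph(opts, gdef, &g);
+//   if (!s.ok()) { /* handle error; g is incomplete */ }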
+
+// Make a copy of "src" into "*dest".
+//
+// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
+// other than the implicit Source/Sink nodes.
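+//
+// E.g. (a sketch):
+//   Graph copy(src.op_registry());
+//   CopyGraph(src, &copy);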
+extern void CopyGraph(const Graph& src, Graph* dest);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
new file mode 100644
index 0000000000..61f4427297
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -0,0 +1,190 @@
+#include "tensorflow/core/graph/graph_constructor.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/public/status.h"
+#include <gtest/gtest.h>
+
+// TODO(josh11b): Test InitCostModel().
+// TODO(josh11b): Test setting the "device" field of a NodeDef.
+// TODO(josh11b): Test that feeding won't prune targets.
+
+namespace tensorflow {
+namespace {
+
+class GraphConstructorTest : public ::testing::Test {
+ protected:
+ GraphConstructorTest() : g_(new Graph(OpRegistry::Global())) {
+ RequireDefaultOps();
+ }
+ ~GraphConstructorTest() override {}
+
+ void Convert(const string& gdef_ascii) {
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_));
+ }
+
+ void ExpectError(const string& gdef_ascii, const string& expected_error_re) {
+ Convert(gdef_ascii);
+ GraphConstructorOptions opts;
+ Status status = ConvertGraphDefToGraph(opts, gdef_, g_.get());
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(RE2::PartialMatch(status.error_message(), expected_error_re))
+ << status;
+ }
+
+ void ExpectOK(const string& gdef_ascii) {
+ Convert(gdef_ascii);
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
+ }
+
+ Node* FindNode(const string& name) {
+ for (Node* n : g_->nodes()) {
+ if (n->name() == name) return n;
+ }
+ return nullptr;
+ }
+
+ bool HasNode(const string& name) { return FindNode(name) != nullptr; }
+
+ void ExpectNodes(const string& nodes) {
+ int count = 0;
+ std::vector<string> actual_nodes;
+ for (Node* n : g_->nodes()) {
+ if (n->IsOp()) {
+ count++;
+ actual_nodes.push_back(n->name());
+ }
+ }
+ std::sort(actual_nodes.begin(), actual_nodes.end());
+
+ LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " ");
+
+ std::vector<string> expected_nodes = str_util::Split(nodes, ',');
+ std::sort(expected_nodes.begin(), expected_nodes.end());
+ for (const string& s : expected_nodes) {
+ Node* n = FindNode(s);
+ EXPECT_TRUE(n != nullptr) << s;
+ }
+
+ EXPECT_TRUE(actual_nodes.size() == expected_nodes.size())
+ << "\nActual: " << str_util::Join(actual_nodes, ",")
+ << "\nExpected: " << str_util::Join(expected_nodes, ",");
+ }
+
+ bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) {
+ for (const Edge* e : g_->edges()) {
+ if (e->src()->name() == src && e->src_output() == src_out &&
+          e->dst()->name() == dst && e->dst_input() == dst_in)
+ return true;
+ }
+ return false;
+ }
+ bool HasControlEdge(const string& src, const string& dst) {
+ return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
+ }
+
+ private:
+ GraphDef gdef_;
+ std::unique_ptr<Graph> g_;
+};
+
+REGISTER_OP("ABC");
+REGISTER_OP("TestParams").Output("o: float");
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("TestInt").Input("a: int32");
+
+TEST_F(GraphConstructorTest, InvalidNodeName) {
+ ExpectError("node { name: 'a:b' op: 'ABC' }",
+ "Node 'a:b': Node name contains invalid characters");
+ ExpectError("node { name: '_abc' op: 'ABC' }",
+ // Can't start with '_'
+ "Node '_abc': Node name contains invalid characters");
+ ExpectOK("node { name: 'a-bc_' op: 'ABC' }");
+}
+
+TEST_F(GraphConstructorTest, InvalidSourceNodeName) {
+ ExpectError(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: 'W999' input: 'input' }",
+
+ "Unknown input node.*W999");
+}
+
+TEST_F(GraphConstructorTest, InvalidSourceNodeIndex) {
+ ExpectError(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1:1', 'input:1' ] }",
+
+ "Connecting to invalid output 1 of source node W1");
+}
+
+TEST_F(GraphConstructorTest, GraphWithCycle) {
+ ExpectError(
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }",
+
+ "cycle");
+}
+
+TEST_F(GraphConstructorTest, TypeMismatch) {
+ ExpectError(
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 'int' op: 'TestInt' input: [ 'input' ] }",
+
+ "Input 0 of node int was passed float from input:0 incompatible with "
+ "expected int32.");
+}
+
+TEST_F(GraphConstructorTest, EmptyGraph) { ExpectOK(""); }
+
+TEST_F(GraphConstructorTest, SimpleModel) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }");
+ EXPECT_TRUE(HasNode("W1"));
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasEdge("W1", 0, "t1", 0));
+  EXPECT_TRUE(HasEdge("input", 1, "t1", 1));
+}
+
+TEST_F(GraphConstructorTest, SimpleModelWithControlEdges) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W1', 'input:1', '^t1' ] }");
+ EXPECT_TRUE(HasNode("W1"));
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasNode("t2"));
+ EXPECT_TRUE(HasEdge("W1", 0, "t1", 0));
+  EXPECT_TRUE(HasEdge("input", 1, "t1", 1));
+  EXPECT_TRUE(HasEdge("W1", 0, "t2", 0));
+  EXPECT_TRUE(HasEdge("input", 1, "t2", 1));
+ EXPECT_TRUE(HasControlEdge("W1", "input"));
+ EXPECT_TRUE(HasControlEdge("t1", "t2"));
+}
+
+TEST_F(GraphConstructorTest, Error_ControlEdgeBeforeRealInput) {
+ ExpectError(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W1', '^t1', 'input:1' ] }",
+ "Node 't2': Control dependencies must come after regular dependencies");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc
new file mode 100644
index 0000000000..979604f948
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder.cc
@@ -0,0 +1,121 @@
+#include "tensorflow/core/graph/graph_def_builder.h"
+
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+GraphDefBuilder::Options::Options(Graph* graph, Status* status)
+ : graph_(graph), status_(status) {}
+GraphDefBuilder::Options::~Options() {}
+
+GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
+ StringPiece name) const {
+ return Options(*this).WithNameImpl(name);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
+ StringPiece device) const {
+ return Options(*this).WithDeviceImpl(device);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
+ Node* control_input) const {
+ return Options(*this).WithControlInputImpl(control_input);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
+ gtl::ArraySlice<Node*> control_inputs) const {
+ return Options(*this).WithControlInputsImpl(control_inputs);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
+ StringPiece name) {
+ name_ = name.ToString();
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
+ StringPiece device) {
+ device_ = device.ToString();
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
+ Node* control_input) {
+ control_inputs_.push_back(control_input);
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
+ gtl::ArraySlice<Node*> control_inputs) {
+ control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
+ control_inputs.end());
+ return *this;
+}
+
+Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
+ if (status_.ok()) {
+ graph_.ToGraphDef(graph_def);
+ }
+ return status_;
+}
+
+Status GraphDefBuilder::ToGraph(Graph* graph) const {
+ if (status_.ok()) {
+ GraphDef graph_def;
+ graph_.ToGraphDef(&graph_def);
+ GraphConstructorOptions opts;
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph));
+ }
+ return status_;
+}
+
+string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
+ if (name_.empty()) return graph_->NewName(op);
+ return name_;
+}
+
+Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
+ builder->ControlInputs(control_inputs_);
+ if (!device_.empty()) builder->Device(device_);
+ for (const auto& attr : attrs_) {
+ builder->Attr(attr.first, attr.second);
+ }
+
+ Node* returned_node;
+ UpdateStatus(builder->Finalize(graph_, &returned_node));
+ return returned_node;
+}
+
+void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
+ if (status_ == nullptr) {
+ TF_CHECK_OK(status);
+ } else {
+ status_->Update(status);
+ }
+}
+
+namespace ops {
+
+Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Node* UnaryOp(const string& op_name, NodeOut input,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ node_builder.Input(input);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ node_builder.Input(a).Input(b);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+} // end namespace ops
+} // end namespace tensorflow
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
new file mode 100644
index 0000000000..bb72f9eea6
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -0,0 +1,181 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+
+#include <vector>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// Given a function like:
+// namespace ops {
+// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
+// if (opts.HaveError()) return nullptr;
+// static const string kOpName = "Identity";
+// NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
+// opts.op_registry());
+// node_builder.Input(input);
+// return opts.FinalizeBuilder(&node_builder);
+// }
+// } // namespace ops
+//
+// // Or, alternatively:
+// namespace ops {
+// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
+// static const string kOpName = "Identity";
+// return UnaryOp(kOpName, input, opts);
+// }
+// } // namespace ops
+//
+// You call it like:
+// GraphDefBuilder b;
+// using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+//   Node* a = Const(7, b.opts());
+//   // Note: WithName() returns a copy, opts is unchanged.
+//   Node* c = Const(5, b.opts().WithName("control-input"));
+//   Node* d = Identity(a, b.opts().WithControlInput(c));
+// GraphDef graph_def;
+// Status status = b.ToGraphDef(&graph_def);
+// if (!status.ok()) { /* Handle error */ }
+//
+// In tests you can skip the status handling via:
+// GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+// ...
+// b.ToGraphDef(&graph_def);
+
+class GraphDefBuilder {
+ public:
+ // Options for adding a Node to a Graph.
+ class Options {
+ public:
+ // Sets the Graph (that Nodes will be added to) and the status. The
+ // status may be set to nullptr, in which case errors cause CHECK
+ // failures. The graph and status must outlive *this.
+ Options(Graph* graph, Status* status);
+ ~Options();
+
+ // Methods for setting options. These are const methods: they
+ // return a copy of *this with the option set.
+ Options WithName(StringPiece name) const;
+ Options WithDevice(StringPiece device) const;
+ Options WithControlInput(Node* control_input) const;
+ Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
+
+ // Override the default value for an optional attr.
+ template <class T>
+ Options WithAttr(StringPiece attr_name, T&& value) const {
+ return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
+ }
+ // Note: overload needed to allow {...} expressions for value.
+ template <class T>
+ Options WithAttr(StringPiece attr_name,
+ std::initializer_list<T> value) const {
+ return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
+ }
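+ // For example (attr names here are hypothetical):
+ //   opts.WithAttr("N", 7)              // forwarding overload
+ //   opts.WithAttr("shape", {1, 2, 3})  // initializer_list overload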
+
+ // Methods for using options from a function that creates a Node.
+
+ // Returns true if the status associated with *this has an error.
+ // Use this to skip processing that may depend on prior results.
+ bool HaveError() const { return status_ != nullptr && !status_->ok(); }
+
+ // Given the Op type name, return a name for a node of that type.
+ // Uses the value set in WithName() if that has been called. Otherwise,
+ // returns a name built out of the Op type name.
+ string GetNameForOp(StringPiece op) const;
+
+ // Sets the device, adds control inputs, adds attrs, and calls Finalize().
+ // If Finalize returns an error, it is saved and this function returns
+ // nullptr.
+ Node* FinalizeBuilder(NodeBuilder* builder) const;
+
+ // Updates the associated status, if any, or calls TF_CHECK_OK if none.
+ void UpdateStatus(const Status& status) const;
+
+ // Accessor
+ const OpRegistryInterface* op_registry() const {
+ return graph_->op_registry();
+ }
+
+ private:
+ Options WithNameImpl(StringPiece name);
+ Options WithDeviceImpl(StringPiece device);
+ Options WithControlInputImpl(Node* control_input);
+ Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
+ template <class T>
+ Options WithAttrImpl(StringPiece name, T&& value) {
+ attrs_.emplace_back(name.ToString(), AttrValue());
+ SetAttrValue(std::forward<T>(value), &attrs_.back().second);
+ return *this;
+ }
+
+ Graph* const graph_;
+ Status* const status_;
+ string name_;
+ string device_;
+ std::vector<Node*> control_inputs_;
+ std::vector<std::pair<string, AttrValue>> attrs_;
+ };
+
+ // Start building a new graph.
+ explicit GraphDefBuilder(
+ const OpRegistryInterface* op_registry = OpRegistry::Global())
+ : graph_(op_registry), opts_(&graph_, &status_) {}
+
+ // For use in tests, where you want to fail immediately on error instead
+ // of checking the status at the end.
+ enum TestFailImmediatelyType { kFailImmediately };
+ explicit GraphDefBuilder(
+ TestFailImmediatelyType,
+ const OpRegistryInterface* op_registry = OpRegistry::Global())
+ : graph_(op_registry), opts_(&graph_, nullptr) {}
+
+ // Gets the Options with the associated Graph and Status.
+ const Options& opts() const { return opts_; }
+
+ // Once all the nodes have been added, call this to check whether the
+ // build succeeded and, if so, to fill *graph_def.
+ Status ToGraphDef(GraphDef* graph_def) const;
+
+ // Like ToGraphDef(), but converts to a Graph (using the default
+ // GraphConstructorOptions).
+ // TODO(josh11b): Make this faster; right now it converts
+ // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
+ // edges from the source and to the sink node, resolves back edges
+ // by name), and makes sure the resulting graph is valid.
+ Status ToGraph(Graph* graph) const;
+
+ private:
+ Graph graph_;
+ Status status_;
+ Options opts_;
+};
+
+namespace ops {
+
+// A NodeOut may be either a regular input or a back input. Regular
+// inputs are specified via either a Node* or a Node* and an output
+// index. Back inputs are specified by a node name, output index, and
+// output type.
+typedef NodeBuilder::NodeOut NodeOut;
+
+// For adding an Op with no inputs to a GraphDefBuilder.
+Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
+
+// For adding an Op with one input to a GraphDefBuilder.
+Node* UnaryOp(const string& op_name, NodeOut input,
+ const GraphDefBuilder::Options& opts);
+
+// For adding an Op with two inputs to a GraphDefBuilder.
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
+ const GraphDefBuilder::Options& opts);
+
+} // namespace ops
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
new file mode 100644
index 0000000000..1571790e59
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -0,0 +1,1050 @@
+#include "tensorflow/core/graph/graph_partition.h"
+
+#include <deque>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+struct DupRecvKey {
+ int src_node_id; // Edge's src node id
+ int src_output_slot; // Edge's src node output slot
+ GraphDef* dst_graph; // Edge's dst node is in this subgraph
+ bool recv_output_on_host; // The output of recv is on host
+};
+
+struct DupRecvKeyHash {
+ size_t operator()(const DupRecvKey& k) const {
+ size_t h = Hash64(reinterpret_cast<const char*>(&k.src_node_id),
+ sizeof(k.src_node_id), k.src_output_slot);
+ h = Hash64(reinterpret_cast<const char*>(&k.dst_graph), sizeof(k.dst_graph),
+ h);
+ h = Hash64(reinterpret_cast<const char*>(&k.recv_output_on_host),
+ sizeof(k.recv_output_on_host), h);
+ return h;
+ }
+};
+
+struct DupRecvKeyEq {
+ bool operator()(const DupRecvKey& x, const DupRecvKey& y) const {
+ return (x.src_node_id == y.src_node_id) &&
+ (x.src_output_slot == y.src_output_slot) &&
+ (x.dst_graph == y.dst_graph) &&
+ (x.recv_output_on_host == y.recv_output_on_host);
+ }
+};
+
+// Struct used to store the recvs, so that their start times can be updated.
+struct RecvInfo {
+ NodeDef* recv;
+ NodeDef* real_recv;
+ int64 start_time;
+};
+
+typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
+ DupRecvTable;
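+
+// The table lets Partition() reuse a single send/recv pair when several
+// edges carry the same (src node, src output) into the same destination
+// subgraph. For example (hypothetical names), two consumers of A1:0 placed
+// on one remote device share one _Recv; only its start time is updated.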
+
+// Control flow info for a graph node.
+struct ControlFlowInfo {
+ const Node* frame = nullptr; // frame of a node
+ const Node* parent_frame = nullptr; // parent frame of a node
+ string frame_name; // frame name of a node
+ int iter_level = -1; // level of a node
+};
+
+struct PairIntHash {
+ public:
+ std::size_t operator()(const std::pair<int, int>& x) const {
+ return std::hash<int>()(x.first) ^ std::hash<int>()(x.second);
+ }
+};
+// A map used to store memory types for the inputs/outputs of every node.
+// The key is a pair of ints consisting of a node id and an input/output index.
+typedef std::unordered_map<std::pair<int, int>, MemoryType, PairIntHash>
+ MemoryTypeMap;
+
+// We collect the following information about the graph before performing
+// graph partitioning.
+struct GraphInfo {
+ std::vector<DeviceType> device_types;
+ MemoryTypeMap input_types;
+ MemoryTypeMap output_types;
+ std::vector<ControlFlowInfo> cf_info;
+};
+
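+// Returns the type carried over the wire for edge 'e'. Control edges carry
+// no real tensor; they are materialized as empty float tensors (see
+// AddDummyConst below), so their wire type is DT_FLOAT.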
+DataType EdgeType(const Edge* e) {
+ if (e->IsControlEdge()) {
+ return DT_FLOAT;
+ } else {
+ return e->dst()->input_type(e->dst_input());
+ }
+}
+
+// Return true iff we need to add a same device send/recv for 'edge'.
+bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
+ if (edge->IsControlEdge()) {
+ return false;
+ }
+
+ Node* src = edge->src();
+ Node* dst = edge->dst();
+ if (src->assigned_device_name() == dst->assigned_device_name()) {
+ int src_port = edge->src_output();
+ int dst_port = edge->dst_input();
+ if (info.device_types[src->id()] == DEVICE_GPU) {
+ auto src_it = info.output_types.find({src->id(), src_port});
+ DCHECK(src_it != info.output_types.end());
+ auto dst_it = info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != info.input_types.end());
+ return src_it->second != dst_it->second;
+ }
+ }
+ return false;
+}
+
+// Return true iff (dst, dst_input) is specified on host memory.
+bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
+ Node* dst = edge->dst();
+ int dst_port = edge->dst_input();
+ if (info.device_types[dst->id()] == DEVICE_GPU) {
+ if (edge->IsControlEdge()) return false;
+ auto dst_it = info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != info.input_types.end());
+ return dst_it->second == HOST_MEMORY;
+ }
+ return true;
+}
+
+// Add an input to dst that comes from the "src_slot" output of the
+// node named by "src_name".
+void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
+ if (src_slot == Graph::kControlSlot) {
+ dst->add_input(strings::StrCat("^", src_name));
+ } else if (src_slot == 0) {
+ dst->add_input(src_name.data(), src_name.size());
+ } else {
+ dst->add_input(strings::StrCat(src_name, ":", src_slot));
+ }
+}
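+
+// For example (with a hypothetical source node named "foo"):
+//   AddInput(d, "foo", Graph::kControlSlot) adds the input "^foo"
+//   AddInput(d, "foo", 0)                   adds the input "foo"
+//   AddInput(d, "foo", 2)                   adds the input "foo:2"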
+
+// Add a control edge from each input to each recv.
+void AddReadControl(const std::vector<NodeDef*>& recvs,
+ const std::vector<string>& inputs) {
+ for (NodeDef* recv : recvs) {
+ for (const string& input : inputs) {
+ recv->add_input(strings::StrCat("^", input));
+ }
+ }
+}
+
+void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
+ NodeDefBuilder* builder) {
+ builder->Attr("tensor_name",
+ strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
+ builder->Attr("send_device", edge->src()->assigned_device_name());
+ builder->Attr("send_device_incarnation",
+ static_cast<int64>(
+ opts.get_incarnation(edge->src()->assigned_device_name())));
+ builder->Attr("recv_device", edge->dst()->assigned_device_name());
+ builder->Attr("client_terminated", false);
+}
+
+NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
+ GraphDef* gdef, const Edge* edge,
+ NodeDefBuilder::NodeOut send_from, int64 start_time,
+ Status* status) {
+ const DataType dtype = send_from.data_type;
+ const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
+ const Node* src = edge->src();
+ const int src_port = edge->src_output();
+
+ // host_memory = true iff we need to use HostSend/HostCast.
+ bool host_memory = false;
+ if (!edge->IsControlEdge()) {
+ auto src_it = g_info.output_types.find({src->id(), src_port});
+ DCHECK(src_it != g_info.output_types.end());
+ host_memory = (src_it->second == HOST_MEMORY);
+ }
+
+ // Add a cast node that casts dtype to cast_dtype.
+ // NOTE(yuanbyu): Only cast for cross-device send/recv.
+ if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
+ const string cast_op = (host_memory) ? "_HostCast" : "Cast";
+ NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
+ cast_builder.Device(src->assigned_device_name()).Input(send_from);
+ if (opts.scheduling_for_recvs) {
+ cast_builder.Attr("_start_time", start_time);
+ }
+ cast_builder.Attr("DstT", cast_dtype);
+ NodeDef* cast = gdef->add_node();
+ *status = cast_builder.Finalize(cast);
+ if (!status->ok()) return nullptr;
+
+ // Connect the Send op to the cast.
+ send_from.Reset(cast->name(), 0, cast_dtype);
+ }
+
+ // Add the send node.
+ const string send_op = (host_memory) ? "_HostSend" : "_Send";
+ NodeDefBuilder send_builder(opts.new_name(src->name()), send_op);
+ SetSendRecvAttrs(opts, edge, &send_builder);
+ send_builder.Device(src->assigned_device_name()).Input(send_from);
+ if (opts.scheduling_for_recvs) {
+ send_builder.Attr("_start_time", start_time);
+ }
+ NodeDef* send = gdef->add_node();
+ *status = send_builder.Finalize(send);
+ return send;
+}
+
+NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
+ GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
+ Status* status) {
+ const DataType dtype = EdgeType(edge);
+ const Node* src = edge->src();
+ const Node* dst = edge->dst();
+ const int dst_port = edge->dst_input();
+ DataType cast_dtype = dtype;
+
+ // NOTE(yuanbyu): Only cast for cross-device send/recv.
+ if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
+ cast_dtype = opts.should_cast(edge);
+ }
+
+ // host_memory = true iff we need to use HostRecv/HostCast.
+ bool host_memory = false;
+ if (!edge->IsControlEdge()) {
+ auto dst_it = g_info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != g_info.input_types.end());
+ host_memory = (dst_it->second == HOST_MEMORY);
+ }
+
+ // Add the recv node.
+ const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
+ NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op);
+ SetSendRecvAttrs(opts, edge, &recv_builder);
+ recv_builder.Device(dst->assigned_device_name())
+ .Attr("tensor_type", cast_dtype);
+ NodeDef* recv = gdef->add_node();
+ *status = recv_builder.Finalize(recv);
+ if (!status->ok()) return nullptr;
+ *real_recv = recv;
+
+ // Add the cast node (from cast_dtype to dtype) or an Identity node.
+ if (dtype != cast_dtype) {
+ const string cast_op = (host_memory) ? "_HostCast" : "Cast";
+ NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
+ cast_builder.Attr("DstT", dtype);
+ cast_builder.Device(dst->assigned_device_name())
+ .Input(recv->name(), 0, cast_dtype);
+ NodeDef* cast = gdef->add_node();
+ *status = cast_builder.Finalize(cast);
+ if (!status->ok()) return nullptr;
+ return cast;
+ } else if (edge->IsControlEdge()) {
+ // An Identity is only needed for control edges.
+ NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity");
+ id_builder.Device(dst->assigned_device_name())
+ .Input(recv->name(), 0, cast_dtype);
+ NodeDef* id = gdef->add_node();
+ *status = id_builder.Finalize(id);
+ if (!status->ok()) return nullptr;
+ return id;
+ } else {
+ return recv;
+ }
+}
+
+NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
+ const Edge* edge, Status* status) {
+ const Node* src = edge->src();
+ Tensor tensor(DT_FLOAT, TensorShape({0}));
+ NodeDef* result = gdef->add_node();
+ *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
+ .Device(src->assigned_device_name())
+ .Attr("dtype", DT_FLOAT)
+ .Attr("value", tensor)
+ .Finalize(result);
+ return result;
+}
+
+// A dummy node for scheduling.
+NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
+ const string& assigned_device_name, int64 epoch,
+ int64 starttime, Status* status) {
+ NodeDef* result = gdef->add_node();
+ *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
+ "ControlTrigger")
+ .Device(assigned_device_name)
+ .Attr("_start_time", starttime)
+ .Finalize(result);
+ return result;
+}
+
+// Assign to each node the name of the frame and the level it belongs to.
+// We check the well-formedness of the graph: All inputs to a node must
+// come from the same frame and have the same "static" iteration level.
+// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level
+// 0. This essentially means there can't be multiple serial Nexts in
+// an iteration, which all sane front-ends should satisfy.
+Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
+ info->clear();
+ info->resize(g->num_node_ids());
+
+ Node* src_node = g->source_node();
+ ControlFlowInfo& src_info = (*info)[src_node->id()];
+ src_info.frame = src_node;
+ src_info.parent_frame = src_node;
+ src_info.iter_level = 0;
+
+ string frame_name;
+ std::deque<const Node*> ready;
+ ready.push_back(src_node);
+ while (!ready.empty()) {
+ const Node* curr_node = ready.front();
+ ready.pop_front();
+ const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
+ const Node* frame = curr_info.frame;
+ const Node* parent = curr_info.parent_frame;
+ frame_name = curr_info.frame_name;
+ int iter_level = curr_info.iter_level;
+
+ if (IsExit(curr_node)) {
+ const ControlFlowInfo& parent_info = (*info)[parent->id()];
+ frame = parent_info.frame;
+ parent = parent_info.parent_frame;
+ frame_name = parent_info.frame_name;
+ iter_level = parent_info.iter_level;
+ }
+
+ for (const Edge* out_edge : curr_node->out_edges()) {
+ const Node* out = out_edge->dst();
+ int out_id = out->id();
+ ControlFlowInfo* out_info = &(*info)[out_id];
+ const Node* out_parent = out_info->parent_frame;
+ bool is_visited = (out_info->iter_level != -1);
+
+ // Skip Sink/Source nodes.
+ if (!out->IsOp()) continue;
+
+ // Add to ready queue if not seen.
+ if (!is_visited) {
+ ready.push_back(out);
+ }
+
+ // Process the node 'out'.
+ if (IsEnter(out)) {
+ if (is_visited) {
+ const string& parent_name = (*info)[out_parent->id()].frame_name;
+ if (parent_name != frame_name || iter_level != out_info->iter_level) {
+ return errors::InvalidArgument(
+ "All inputs to Enter must be from the same frame and level.");
+ }
+ } else {
+ out_info->frame = out;
+ out_info->parent_frame = frame;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(out->def(), "frame_name", &out_info->frame_name));
+ if (out_info->frame_name.empty()) {
+ return errors::InvalidArgument(
+ "Enter must have a non-empty frame name.");
+ }
+ out_info->iter_level = 0;
+ }
+ } else if (IsNextIteration(out)) {
+ if (is_visited) {
+ if (out_info->frame_name != frame_name ||
+ out_info->iter_level != (iter_level + 1)) {
+ return errors::InvalidArgument(
+ "All inputs to NextIteration must be from the same frame "
+ "and level.");
+ }
+ } else {
+ out_info->frame = frame;
+ out_info->parent_frame = parent;
+ out_info->frame_name = frame_name;
+ out_info->iter_level = iter_level + 1;
+ }
+ } else {
+ if (is_visited) {
+ if (out_info->frame_name != frame_name) {
+ return errors::InvalidArgument(
+ "All inputs to a node must be from the same frame.");
+ }
+ } else {
+ out_info->frame = frame;
+ out_info->parent_frame = parent;
+ out_info->frame_name = frame_name;
+ out_info->iter_level = iter_level;
+ }
+ }
+ }
+ }
+
+ return Status::OK();
+}
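+
+// For example (a hypothetical single loop): in
+//   Enter("while") -> Merge -> Switch -> NextIteration -> Merge
+// the Enter node starts frame "while" at iter_level 0, the NextIteration
+// node is assigned iter_level 1, and the successors of an Exit node are
+// assigned the parent frame's name and level again.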
+
+string ControlLoopName(const string& name) {
+ return strings::StrCat("_cloop", name);
+}
+
+bool IsControlLoop(const Node* node) {
+ const string& name = node->def().name();
+ return StringPiece(name).starts_with("_cloop");
+}
+
+// An enter node for control flow.
+Node* AddControlEnter(Graph* g, const string& node_name,
+ const string& device_name, const string& frame_name,
+ const int parallel_iterations, Status* status) {
+ NodeBuilder node_builder(node_name, "Enter", g->op_registry());
+ node_builder.Input({"dummy", 0, DT_FLOAT});
+ node_builder.Attr("frame_name", frame_name);
+ node_builder.Attr("parallel_iterations", parallel_iterations);
+ Node* res_node;
+ *status = node_builder.Finalize(g, &res_node);
+ if (!status->ok()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A merge node for control flow.
+Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
+ const string& node_name, const string& device_name,
+ Status* status) {
+ NodeBuilder node_builder(node_name, "Merge", g->op_registry());
+ node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
+ Node* res_node;
+ *status = node_builder.Finalize(g, &res_node);
+ if (!status->ok()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A switch node for control flow.
+Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
+ const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A next_iteration node for control flow.
+Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = ops::UnaryOp("NextIteration", input, bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+Node* EmptyConst(const GraphDefBuilder::Options& options) {
+ if (options.HaveError()) return nullptr;
+ NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
+ options.op_registry());
+ const DataType dt = DataTypeToEnum<float>::v();
+ TensorProto proto;
+ proto.set_dtype(dt);
+ TensorShape empty_shape({0});
+ empty_shape.AsProto(proto.mutable_tensor_shape());
+ node_builder.Attr("dtype", dt).Attr("value", proto);
+ return options.FinalizeBuilder(&node_builder);
+}
+
+// A dummy const node for control flow.
+Node* AddControlConst(const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = EmptyConst(bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A synthetic loop, made up of dummy nodes. It performs control-flow actions
+// on behalf of a leader on a different device.
+struct ControlLoop {
+ Node* enter = nullptr;
+ Node* merge = nullptr;
+ Node* switch_node = nullptr;
+};
+
+// Add the control flow info of a new node added during partitioning.
+// The new node has the same control flow info as the node 'src'.
+void AddControlFlowInfo(const Node* node, const Node* src,
+ std::vector<ControlFlowInfo>* cf_info) {
+ int id = node->id();
+ if (static_cast<size_t>(id) >= cf_info->size()) {
+ cf_info->resize(id + 1);
+ }
+ const ControlFlowInfo& src_info = (*cf_info)[src->id()];
+ ControlFlowInfo* info = &(*cf_info)[id];
+ info->frame = src_info.frame;
+ info->parent_frame = src_info.parent_frame;
+ info->frame_name = src_info.frame_name;
+ info->iter_level = src_info.iter_level;
+}
+
+// Constructs a control loop. Returns a struct containing the newly created
+// enter, merge, and switch nodes. The enter and merge nodes are used in the
+// recursive construction of control loops for nested frames (loops). The
+// switch node will be connected to the LoopCond node. The merge node will
+// be connected to all the recvs of the same frame by control edges when
+// the actual partitioning happens.
+Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
+ const Edge* edge, Node* loop_cond,
+ std::vector<ControlFlowInfo>* cf_info,
+ ControlLoop* loop) {
+ Status status;
+ GraphDefBuilder::Options bopts(g, &status);
+ const ControlFlowInfo& src_info = (*cf_info)[src->id()];
+ const string& device_name = edge->dst()->assigned_device_name();
+ const string& frame_name = src_info.frame_name;
+ int parallel_iterations;
+ status = GetNodeAttr(src_info.frame->def(), "parallel_iterations",
+ &parallel_iterations);
+ if (!status.ok()) return status;
+
+ // The names of the nodes to be added.
+ const string& enter_name =
+ ControlLoopName(opts.new_name(edge->dst()->name()));
+ const string& merge_name =
+ ControlLoopName(opts.new_name(edge->dst()->name()));
+ const string& switch_name =
+ ControlLoopName(opts.new_name(edge->dst()->name()));
+ const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
+
+ // Add the nodes to the graph g.
+ Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
+ parallel_iterations, &status);
+ if (!status.ok()) return status;
+ Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
+ device_name, &status);
+ if (!status.ok()) return status;
+ Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
+ bopts.WithName(switch_name));
+ if (!status.ok()) return status;
+ Node* next =
+ AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
+ if (!status.ok()) return status;
+
+ // Add control flow info for these new nodes:
+ AddControlFlowInfo(enter, src, cf_info);
+ AddControlFlowInfo(merge, src, cf_info);
+ AddControlFlowInfo(switch_node, src, cf_info);
+ AddControlFlowInfo(next, src, cf_info);
+
+ // Add input edges for the newly created merge node:
+ g->AddEdge(enter, 0, merge, 0);
+ g->AddEdge(next, 0, merge, 1);
+
+ loop->enter = enter;
+ loop->merge = merge;
+ loop->switch_node = switch_node;
+ return Status::OK();
+}
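+
+// In edge form, the nodes created by AddControlLoop are wired as:
+//   enter:0     -> merge:in0
+//   next:0      -> merge:in1
+//   merge:0     -> switch:in0  (data)
+//   loop_cond:0 -> switch:in1  (predicate)
+//   switch:1    -> next:in0    (the "alive" branch feeds the next iteration)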
+
+// Build memory and device type info for every node in the graph.
+// TODO(yuanbyu): It might be simpler if we convert MemoryType to
+// DeviceType for the inputs/outputs of each node.
+Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
+ Status status;
+ MemoryTypeVector input_memory_types;
+ MemoryTypeVector output_memory_types;
+
+ info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
+ for (const Node* node : g.nodes()) {
+ if (!node->IsOp()) continue; // Skip Sink/Source nodes.
+
+ DeviceNameUtils::ParsedName parsed;
+ if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
+ &parsed)) {
+ return errors::Internal("Malformed assigned device '",
+ node->assigned_device_name(), "'");
+ }
+
+ input_memory_types.clear();
+ input_memory_types.resize(node->num_inputs());
+ output_memory_types.clear();
+ output_memory_types.resize(node->num_outputs());
+ status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type),
+ node->def(), &input_memory_types,
+ &output_memory_types);
+ if (!status.ok()) return status;
+
+ int node_id = node->id();
+ info->device_types[node_id] = DeviceType(parsed.type);
+ for (size_t i = 0; i < input_memory_types.size(); ++i) {
+ info->input_types[{node_id, i}] = input_memory_types[i];
+ }
+ for (size_t i = 0; i < output_memory_types.size(); ++i) {
+ info->output_types[{node_id, i}] = output_memory_types[i];
+ }
+ }
+ return status;
+}
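+
+// For example (node id 7 is hypothetical): after this runs,
+// info->output_types[{7, 0}] says whether output 0 of node 7 lives in
+// HOST_MEMORY or DEVICE_MEMORY, which drives the choice between the
+// _Host(Send|Recv) ops and plain _Send/_Recv in AddSend/AddRecv above.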
+
+// Each participating device needs to decide a) if there is a next iteration,
+// and b) if the loop terminates. We encode this control-flow logic in the
+// dataflow graph itself. There are at least two possible encodings.
+// In a completely decentralized encoding, the participants communicate peer
+// to peer. The other encoding uses a frame leader (the participant who owns
+// the pivot termination predicate) to broadcast the termination condition to
+// all the participants. For now we take the latter because it is simpler.
+//
+// TODO(yuanbyu): The correctness of this construction is rather subtle. I got
+// it wrong many times so it would be nice to write a proof to be sure.
+Status AddControlFlow(const PartitionOptions& opts, Graph* g,
+ GraphInfo* g_info) {
+ Status status;
+ GraphDefBuilder::Options bopts(g, &status);
+ std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
+
+ // Build the control flow info for every node.
+ status = BuildControlFlowInfo(g, &cf_info);
+ if (!status.ok()) return status;
+
+ // The map from frames to their LoopCond nodes.
+ std::unordered_map<string, Node*> frame_cond_map;
+ int num_node_ids = g->num_node_ids();
+ for (int i = 0; i < num_node_ids; ++i) {
+ Node* node = g->FindNodeId(i);
+ if (node == nullptr) continue;
+
+ if (IsLoopCond(node)) {
+ const string& frame_name = cf_info[node->id()].frame_name;
+ DCHECK(!frame_name.empty());
+ frame_cond_map[frame_name] = node;
+ }
+ }
+
+ // Add all control loops for cross-device frames.
+ // A control loop is added only when there is a cross-device edge in a
+ // non-root frame. Nothing is added if there are no loops. We also don't
+ // add anything for a frame that is completely local to a device. For
+ // nested loops, we stack the control loops together by connecting
+ // the merge of the outer loop to the enter of the inner loop.
+ //
+ // A map from <frame_name, device_name> to ControlLoop.
+ std::unordered_map<string, ControlLoop> control_loops;
+ int num_edge_ids = g->num_edge_ids();
+ for (int i = 0; i < num_edge_ids; ++i) {
+ const Edge* edge = g->FindEdgeId(i);
+ if (edge == nullptr) continue;
+
+ const Node* src = edge->src();
+ const Node* dst = edge->dst();
+ // Skip Sink/Source nodes.
+ if (!src->IsOp() || !dst->IsOp()) continue;
+
+ const string& src_device = src->assigned_device_name();
+ const string& dst_device = dst->assigned_device_name();
+ // Skip local edges.
+ if (src_device == dst_device) continue;
+
+ const string& src_frame = cf_info[src->id()].frame_name;
+ const string& dst_frame = cf_info[dst->id()].frame_name;
+ // Skip if src and dst are not in the same frame.
+ if (src_frame.empty() || src_frame != dst_frame) {
+ continue;
+ }
+
+ // Add the control loop. Start by adding the control loop for the
+ // current frame if needed, and recursively adding the control loop
+ // for its outer frame when nested.
+ ControlLoop child_loop;
+ while (true) {
+ const string& curr_frame = cf_info[src->id()].frame_name;
+ if (curr_frame.empty()) {
+ // We have reached the root frame.
+ if (child_loop.merge != nullptr) {
+ const string& node_name = opts.new_name(edge->dst()->name());
+ const string& device_name = edge->dst()->assigned_device_name();
+ Node* const_node =
+ AddControlConst(device_name, bopts.WithName(node_name));
+ if (!status.ok()) return status;
+ AddControlFlowInfo(const_node, src, &cf_info);
+ g->AddEdge(const_node, 0, child_loop.enter, 0);
+ }
+ break;
+ }
+
+ const string& cl_key = strings::StrCat(curr_frame, "$$", dst_device);
+ auto it = control_loops.find(cl_key);
+ if (it != control_loops.end()) {
+ if (child_loop.enter != nullptr) {
+ g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
+ }
+ break;
+ }
+
+ // Get the frame's LoopCond.
+ auto cond_it = frame_cond_map.find(curr_frame);
+ if (cond_it == frame_cond_map.end()) {
+ return errors::InvalidArgument(
+ "A cross-device loop must have a pivot predicate: ", curr_frame);
+ }
+ Node* loop_cond = cond_it->second;
+
+ // Add the control loop.
+ ControlLoop curr_loop;
+ status =
+ AddControlLoop(opts, g, src, edge, loop_cond, &cf_info, &curr_loop);
+ if (!status.ok()) return status;
+ control_loops[cl_key] = curr_loop;
+
+ if (child_loop.enter != nullptr) {
+ // Connect the merge of the outer loop to the enter of the inner.
+ g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
+ }
+ src = cf_info[src->id()].parent_frame;
+ child_loop = curr_loop;
+ }
+ }
+
+ // For a cross-device edge, on the dst device, add a control edge
+ // from the merge node of the control loop to dst. If a send/recv is
+ // introduced for this edge in future partitioning, we delete this
+ // control edge and add a new control edge from the merge to the recv.
+ num_edge_ids = g->num_edge_ids();
+ for (int i = 0; i < num_edge_ids; ++i) {
+ const Edge* edge = g->FindEdgeId(i);
+ if (edge == nullptr) continue;
+
+ const Node* src = edge->src();
+ Node* dst = edge->dst();
+ // Skip Sink/Source nodes.
+ if (!src->IsOp() || !dst->IsOp()) continue;
+
+ const string& src_device = src->assigned_device_name();
+ const string& dst_device = dst->assigned_device_name();
+ if (src_device != dst_device) {
+ const string& src_frame = cf_info[src->id()].frame_name;
+ const string& dst_frame = cf_info[dst->id()].frame_name;
+ if (!src_frame.empty() && src_frame == dst_frame) {
+ const string& cl_key = strings::StrCat(dst_frame, "$$", dst_device);
+ ControlLoop loop = control_loops[cl_key];
+ DCHECK(loop.enter != nullptr);
+ g->AddControlEdge(loop.merge, dst);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // end namespace
+
+Status AddControlEdges(const PartitionOptions& opts,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Status status;
+ // TODO(yuanbyu): Very naive for now. To be improved.
+ const int num_epochs = 100;
+ const int prefetch = 6;
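+ // For example (hypothetical times): with a makespan of 1000, resolution is
+ // (1000 / 100) + 1 = 11; a recv whose _start_time is 88 falls in epoch
+ // 88 / 11 = 8 and is enabled by the ControlTrigger of epoch 8 - 6 = 2.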
+
+ typedef std::pair<const NodeDef*, int64> NodeStartTime;
+ for (auto& part : *partitions) {
+ GraphDef* gdef = &part.second;
+
+ std::vector<NodeStartTime> start_times;
+ start_times.resize(gdef->node_size());
+ for (int n = 0; n < gdef->node_size(); ++n) {
+ const NodeDef& ndef = gdef->node(n);
+ int64 start_time;
+ status = GetNodeAttr(ndef, "_start_time", &start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ start_times[n] = std::make_pair(&ndef, start_time);
+ }
+
+ // Sort the nodes based on their start times.
+ std::sort(
+ start_times.begin(), start_times.end(),
+ [](NodeStartTime x, NodeStartTime y) { return x.second < y.second; });
+
+ // Add a dummy node for every epoch, and add a control edge from the
+ // "last" node in the preceding epoch to the dummy node.
+ string device_name = gdef->node(0).device();
+ int64 makespan = start_times.back().second;
+ int64 resolution = (makespan / num_epochs) + 1;
+
+ int i = 0;
+ int j = 0;
+ std::vector<NodeDef*> dummys;
+ while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
+ if (i * resolution > start_times[j].second) {
+ j++;
+ } else {
+ NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
+ i * resolution, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ dummys.push_back(dummy);
+ if (j > 0) {
+ string src_name = start_times[j - 1].first->name();
+ AddInput(dummy, src_name, Graph::kControlSlot);
+ }
+ i++;
+ }
+ }
+
+ // Finally, add the control edges to recvs.
+ for (int n = 0; n < gdef->node_size(); ++n) {
+ NodeDef* ndef = gdef->mutable_node(n);
+ if (ndef->op() == "_Recv") {
+ int64 start_time;
+ status = GetNodeAttr(*ndef, "_start_time", &start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ int recv_epoch = start_time / resolution;
+ if (recv_epoch >= prefetch) {
+ NodeDef* dummy = dummys[recv_epoch - prefetch];
+ AddInput(ndef, dummy->name(), Graph::kControlSlot);
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status Partition(const PartitionOptions& opts, Graph* g,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Status status;
+ partitions->clear();
+
+ GraphInfo g_info;
+ if (!opts.control_flow_added) {
+ // Add the "code" for distributed execution of control flow. Code is
+ // added only for the frames that are placed on multiple devices. The
+ // new graph is an equivalent transformation of the original graph and
+ // has the property that it can be subsequently partitioned arbitrarily
+ // (down to the level of individual devices) for distributed execution.
+ status = AddControlFlow(opts, g, &g_info);
+ if (!status.ok()) return status;
+ }
+ // At this point, all the graph mutations have been done. Build memory
+ // and device type info for every node and edge in the graph.
+ status = BuildMemoryDeviceInfo(*g, &g_info);
+ if (!status.ok()) return status;
+
+ string dstp;
+ std::vector<const Edge*> inputs;
+ DupRecvTable dup_recv(3);
+ // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
+ // edge to dst. 'ref_control_inputs' remembers the inputs that reach dst
+ // via a non-ref edge. We will add a control edge for every pair in
+ // (ref_recvs x ref_control_inputs).
+ std::vector<NodeDef*> ref_recvs;
+ std::vector<string> ref_control_inputs;
+
+ int32 num_data = 0;
+ int32 num_control = 0;
+ for (const Node* dst : g->nodes()) {
+ if (!dst->IsOp()) continue; // Skip Sink/Source nodes.
+
+ dstp = opts.node_to_loc(dst);
+ GraphDef* dst_graph = &(*partitions)[dstp];
+ NodeDef* dst_def = dst_graph->add_node();
+ *dst_def = dst->def();
+ dst_def->set_device(dst->assigned_device_name());
+ dst_def->clear_input(); // Inputs are filled below
+ if (opts.need_to_record_start_times) {
+ int64 start_time = opts.start_times[dst->id()].value();
+ AddNodeAttr("_start_time", start_time, dst_def);
+ }
+
+ // Arrange the incoming edges to dst so that inputs[i] holds the
+ // input flowing into slot numbered i. Trailing entries in inputs[]
+ // hold control edges.
+ inputs.clear();
+ inputs.resize(dst->num_inputs(), nullptr);
+ ref_recvs.clear();
+ ref_control_inputs.clear();
+ const Edge* control_flow_edge = nullptr;
+ for (const Edge* edge : dst->in_edges()) {
+ if (edge->IsControlEdge()) {
+ if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
+ // This is one of the control edges added for control flow. There
+ // can be multiple such edges as the dest node may have multiple
+ // remote inputs. We will just take one and ignore the others.
+ control_flow_edge = edge;
+ } else {
+ inputs.push_back(edge);
+ }
+ } else {
+ DCHECK(inputs[edge->dst_input()] == nullptr);
+ inputs[edge->dst_input()] = edge;
+ }
+ }
+
+ // Process in order so that all data edges are added as inputs to
+ // dst in Edge::dst_input() order.
+ bool recv_added = false;
+ for (const Edge* edge : inputs) {
+ const Node* src = edge->src();
+ if (!src->IsOp()) continue; // Skip Sink/Source nodes.
+
+ GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
+ if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
+ // Same partition and compatible memory types:
+ AddInput(dst_def, src->name(), edge->src_output());
+ if (edge->IsControlEdge() ||
+ !IsRefType(src->output_type(edge->src_output()))) {
+ ref_control_inputs.push_back(src->name());
+ }
+ continue;
+ }
+
+ int64 send_start_time = 0;
+ int64 recv_start_time = 0;
+ if (opts.scheduling_for_recvs) {
+ if (opts.need_to_record_start_times) {
+ send_start_time = opts.start_times[src->id()].value();
+ recv_start_time = opts.start_times[dst->id()].value();
+ } else {
+ status = GetNodeAttr(src->def(), "_start_time", &send_start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+ }
+
+ // Check whether there is already a send/recv pair transferring
+ // the same tensor/control from the src to dst partition.
+ const bool on_host = IsDstInputOnHost(edge, g_info);
+ DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
+ auto iter = dup_recv.find(key);
+ if (iter != dup_recv.end()) {
+ // We found one. Reuse the data/control transferred already.
+ const string& recv_node_name = iter->second.recv->name();
+ if (edge->IsControlEdge()) {
+ AddInput(dst_def, recv_node_name, Graph::kControlSlot);
+ } else {
+ AddInput(dst_def, recv_node_name, 0);
+ }
+ // We want the start_time for the recv to be the smallest of the start
+ // times of its consumers. So we update this whenever we use a recv,
+ // and write it out to the attribute at the end of the subroutine.
+ if (iter->second.start_time > recv_start_time) {
+ iter->second.start_time = recv_start_time;
+ }
+ continue;
+ }
+
+ NodeDefBuilder::NodeOut send_from;
+ if (edge->IsControlEdge()) {
+ // Insert a dummy const node that will generate a tiny
+ // data element to be sent from send to recv.
+ VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
+ << src->name() << "] -> " << dst->assigned_device_name() << "["
+ << dst->name() << "]";
+ NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
+ if (!status.ok()) return status;
+ // Set the start time for this dummy node.
+ if (opts.scheduling_for_recvs) {
+ AddNodeAttr("_start_time", send_start_time, dummy);
+ }
+ AddInput(dummy, src->name(), Graph::kControlSlot);
+ send_from.Reset(dummy->name(), 0, DT_FLOAT);
+ } else {
+ send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
+ }
+
+ // Need to split edge by placing matching send/recv nodes on
+ // the src/dst sides of the edge.
+ NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
+ send_start_time, &status);
+ if (!status.ok()) return status;
+
+ NodeDef* real_recv = nullptr;
+ NodeDef* recv =
+ AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
+ if (!status.ok()) return status;
+
+ // Fix up the control flow edge. Redirect it to the recv.
+ // NOTE(yuanbyu): 'real_recv' must be the real recv node.
+ recv_added = true;
+ if (control_flow_edge != nullptr) {
+ AddInput(real_recv, control_flow_edge->src()->name(),
+ Graph::kControlSlot);
+ }
+
+ // For same device send/recv, add a control edge from send to recv.
+ // This prevents the asynchronous recv kernel from being scheduled
+ // immediately.
+ if (src_graph == dst_graph) {
+ AddInput(real_recv, send->name(), Graph::kControlSlot);
+ }
+
+ if (!edge->IsControlEdge() &&
+ IsRefType(src->output_type(edge->src_output()))) {
+ // If src is of ref type and the edge is not a control edge, dst has
+ // read semantics and therefore we must control the recv.
+ ref_recvs.push_back(real_recv);
+ } else {
+ // Memorize the send/recv pair, only if this is not a "ref" edge.
+ // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
+ // for now we don't do it.
+ dup_recv[key] = {recv, real_recv, recv_start_time};
+ ref_control_inputs.push_back(recv->name());
+ }
+
+ if (edge->IsControlEdge()) {
+ ++num_control;
+ AddInput(dst_def, recv->name(), Graph::kControlSlot);
+ } else {
+ ++num_data;
+ AddInput(dst_def, recv->name(), 0);
+ }
+ }
+
+ // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
+ // NOTE(yuanbyu): Adding these control edges should not introduce
+ // deadlocks. 'dst' has implicit "read" nodes that, when we split
+ // across devices, are made explicit; retargeting the dependencies
+ // on 'dst' to those nodes would not introduce cycles if there wasn't
+ // one before the transformation.
+ // NOTE(yuanbyu): This may impact performance because it defers the
+ // execution of recvs until all the other inputs become available.
+ AddReadControl(ref_recvs, ref_control_inputs);
+
+ // Add back this control edge for control flow if not used.
+ if (!recv_added && (control_flow_edge != nullptr)) {
+ AddInput(dst_def, control_flow_edge->src()->name(), Graph::kControlSlot);
+ }
+ }
+
+ // Set the start times for recvs at the very end.
+ if (opts.scheduling_for_recvs) {
+ for (auto& it : dup_recv) {
+ AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
+ if (it.second.real_recv != it.second.recv) {
+ AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
+ }
+ }
+ }
+
+ VLOG(1) << "Added send/recv: controls=" << num_control
+ << ", data=" << num_data;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
new file mode 100644
index 0000000000..eb88ff71b1
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/costmodel.h"
+
+namespace tensorflow {
+
+struct PartitionOptions {
+ // A function that returns a location for the execution of a given
+ // Node.
+ typedef std::function<string(const Node*)> NodeToLocFunc;
+ NodeToLocFunc node_to_loc = nullptr;
+
+ // A function that returns a unique graph node name with the given
+ // prefix.
+ typedef std::function<string(const string&)> NewNameFunc;
+ NewNameFunc new_name = nullptr;
+
+ // A function that returns the incarnation of a device given the
+ // device's full name. If not found, GetIncarnationFunc should return
+ // kIllegalIncarnation.
+ static const uint64 kIllegalIncarnation = 0;
+ typedef std::function<uint64(const string&)> GetIncarnationFunc;
+ GetIncarnationFunc get_incarnation = nullptr;
+
+ // True if all the control flow "code" has already been added. The
+ // control flow code needs to be added when we still have the entire
+ // graph before any partitioning. So this flag should be false for
+ // the first partitioning but true for all subsequent partitioning.
+ //
+ // TODO(yuanbyu): We could also make the addition of the control
+ // flow code incremental based on 'node_to_loc'. This would make the
+ // communication a broadcast tree, which could be more efficient when
+ // the number of participating devices is large.
+ bool control_flow_added;
+
+ // A function that returns the data type into which the tensor
+ // should be cast before it is sent over the wire.
+ typedef std::function<DataType(const Edge*)> ShouldCastFunc;
+ ShouldCastFunc should_cast = nullptr;
+
+ // Schedule the execution of the recvs based on their start times
+ // computed by some scheduling algorithm. The recvs are divided into
+ // epochs based on their start times. A recv is enabled only when
+ // execution reaches its epoch - N for some predefined N.
+ bool scheduling_for_recvs = false;
+ // The start time for each node in the graph computed by some scheduling
+ // algorithm. If 'need_to_record_start_times' is true, we record them
+ // in the graph as a node attribute.
+ bool need_to_record_start_times = false;
+ std::vector<Microseconds> start_times;
+};
+
+// Partition "input" graph into a set of graphs, one per location.
+// The location for node n is derived by calling opts.node_to_loc(n).
+// New nodes added by Partition use "opts.new_name(old_name)" to
+// generate node names.
+//
+// Stores the partitions in *partitions.
+Status Partition(const PartitionOptions& opts, Graph* input,
+ std::unordered_map<string, GraphDef>* partitions);
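+//
+// A minimal sketch of driving a partition (names are hypothetical; this
+// mirrors the unit tests):
+//   Graph g(OpRegistry::Global());
+//   ... build and place g ...
+//   PartitionOptions popts;
+//   popts.node_to_loc = [](const Node* n) { return n->assigned_device_name(); };
+//   popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
+//   popts.get_incarnation = [](const string&) -> uint64 { return 1; };
+//   popts.control_flow_added = false;
+//   std::unordered_map<string, GraphDef> partitions;
+//   Status s = Partition(popts, &g, &partitions);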
+
+// Add control edges to the partitions to control the ordering
+// and timing of the recv nodes based on the start times calculated
+// using some scheduling algorithm.
+Status AddControlEdges(const PartitionOptions& opts,
+ std::unordered_map<string, GraphDef>* partitions);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
new file mode 100644
index 0000000000..d912c94025
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -0,0 +1,316 @@
+#include "tensorflow/core/graph/graph_partition.h"
+
+#include <unordered_map>
+
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/control_flow_ops.h"
+#include "tensorflow/cc/ops/random_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/equal_graph_def.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace {
+
+const char gpu_device[] = "/job:a/replica:0/task:0/gpu:0";
+
+string SplitByDevice(const Node* node) { return node->assigned_device_name(); }
+
+string DeviceName(const Node* node) {
+ char first = node->name()[0];
+ if (first == 'G') {
+ return gpu_device;
+ } else {
+ const string cpu_prefix = "/job:a/replica:0/task:0/cpu:";
+ int index = first - 'A';
+ return strings::StrCat(cpu_prefix, index);
+ }
+}
+
+void Partition(const GraphDef& graph_def,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Graph g(OpRegistry::Global());
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g));
+
+ // Assign a device to each node, using the first letter of the node
+ // name as the device index.
+ for (Node* node : g.nodes()) {
+ node->set_assigned_device_name(DeviceName(node));
+ }
+
+ PartitionOptions popts;
+ popts.node_to_loc = SplitByDevice;
+ popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
+ popts.get_incarnation = [](const string& name) {
+ return (name[0] - 'A') + 100;
+ };
+ popts.control_flow_added = false;
+ Status s = Partition(popts, &g, partitions);
+ CHECK(s.ok()) << s;
+}
+
+void CheckLoopConstruction(const GraphDef& graph_def) {
+ std::unordered_map<string, GraphDef> partitions;
+ Partition(graph_def, &partitions);
+ for (const auto& kv : partitions) {
+ const GraphDef& gdef = kv.second;
+ bool has_control_enter = false;
+ bool has_control_merge = false;
+ bool has_control_switch = false;
+ bool has_control_next = false;
+ for (const NodeDef& ndef : gdef.node()) {
+ // _Recv nodes must have a control input.
+ if (ndef.op() == "_Recv") {
+ bool has_control = false;
+ for (const string& input_name : ndef.input()) {
+ if (StringPiece(input_name).starts_with("^")) {
+ has_control = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_control);
+ }
+ // Must have a control loop
+ if (StringPiece(ndef.name()).starts_with("_cloop")) {
+ if (ndef.op() == "Enter") {
+ has_control_enter = true;
+ }
+ if (ndef.op() == "Merge") {
+ has_control_merge = true;
+ }
+ if (ndef.op() == "Switch") {
+ has_control_switch = true;
+ }
+ if (ndef.op() == "NextIteration") {
+ has_control_next = true;
+ }
+ }
+ }
+ EXPECT_TRUE(has_control_enter);
+ EXPECT_TRUE(has_control_merge);
+ EXPECT_TRUE(has_control_switch);
+ EXPECT_TRUE(has_control_next);
+ }
+}
+
+REGISTER_OP("Input").Output("o: float");
+REGISTER_OP("BoolInput").Output("o: bool");
+REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float");
+
+Node* Input(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("Input", opts);
+}
+
+Node* BoolInput(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("BoolInput", opts);
+}
+
+Node* Cross(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("Cross", a, b, opts);
+}
+
+class GraphPartitionTest : public ::testing::Test {
+ protected:
+ GraphPartitionTest()
+ : in_(GraphDefBuilder::kFailImmediately),
+ builder_a_(GraphDefBuilder::kFailImmediately),
+ builder_b_(GraphDefBuilder::kFailImmediately),
+ a_opts_(builder_a_.opts().WithDevice("/job:a/replica:0/task:0/cpu:0")),
+ b_opts_(builder_b_.opts().WithDevice("/job:a/replica:0/task:0/cpu:1")) {
+ RequireDefaultOps();
+ }
+
+ const GraphDef& ToGraphDef() {
+ in_.ToGraphDef(&in_graph_def_);
+ return in_graph_def_;
+ }
+
+ void ExpectMatchA() {
+ GraphDef graph_def;
+ builder_a_.ToGraphDef(&graph_def);
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]);
+ }
+
+ void ExpectMatchB() {
+ GraphDef graph_def;
+ builder_b_.ToGraphDef(&graph_def);
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]);
+ }
+
+ GraphDefBuilder in_;
+ GraphDef in_graph_def_;
+ GraphDefBuilder builder_a_;
+ GraphDefBuilder builder_b_;
+ GraphDefBuilder::Options a_opts_;
+ GraphDefBuilder::Options b_opts_;
+ std::unordered_map<string, GraphDef> partitions_;
+};
+
+TEST_F(GraphPartitionTest, SingleDevice) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Cross(a1, a1, in_.opts().WithName("A2"));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(1, partitions_.size());
+
+ a1 = Input(a_opts_.WithName("A1"));
+ Cross(a1, a1, a_opts_.WithName("A2"));
+ ExpectMatchA();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceData) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(a1, b1, in_.opts().WithName("B2"));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
+ ExpectMatchA();
+
+ b1 = Input(b_opts_.WithName("B1"));
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
+ Cross(recv, b1, b_opts_.WithName("B2"));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceControl) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
+ _Send(c, "edge_3_A1", a, 82, b, a_opts_.WithName("A1/_1"));
+ ExpectMatchA();
+
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2"));
+ Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(a1, b1, in_.opts().WithName("B2"));
+ Cross(a1, a1, in_.opts().WithName("B3"));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
+ ExpectMatchA();
+
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(recv, b1, b_opts_.WithName("B2"));
+ Cross(recv, recv, b_opts_.WithName("B3"));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
+ Input(in_.opts().WithName("B3").WithControlInput(a1));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
+ _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
+ ExpectMatchA();
+
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
+ Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
+ Input(b_opts_.WithName("B3").WithControlInput(id));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(a1, b1, in_.opts().WithName("B2"));
+ Input(in_.opts().WithName("B3").WithControlInput(a1));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
+ // NOTE: The send A1/_1 -> A1/_2 is not strictly needed. We could
+ // use A1/_0 -> A1/_4 as the control edge, as a minor optimization.
+ _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
+ _Send(a1, "edge_2_A1", a, 82, b, a_opts_.WithName("A1/_4"));
+ ExpectMatchA();
+
+ Node* recv1 =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
+ Node* id1 = Identity(recv1, b_opts_.WithName("A1/_3"));
+ Node* recv2 =
+ _Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(recv2, b1, b_opts_.WithName("B2"));
+ Input(b_opts_.WithName("B3").WithControlInput(id1));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceLoop) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = BoolInput(in_.opts().WithName("A1"));
+ Node* a2 = Enter(a1, "foo", in_.opts().WithName("A2"));
+ Node* a3 = Merge({a2, {"A5", 0, DT_BOOL}}, in_.opts().WithName("A3"));
+ LoopCond(a3, in_.opts().WithName("A4"));
+ Node* b1 = Identity(a3, in_.opts().WithName("B1"));
+ NextIteration(b1, in_.opts().WithName("A5"));
+
+ CheckLoopConstruction(ToGraphDef());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
new file mode 100644
index 0000000000..f7a8ffde89
--- /dev/null
+++ b/tensorflow/core/graph/graph_test.cc
@@ -0,0 +1,252 @@
+#include "tensorflow/core/graph/graph.h"
+
+#include <set>
+#include <gtest/gtest.h>
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class GraphTest : public ::testing::Test {
+ protected:
+ GraphTest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); }
+ ~GraphTest() override {}
+
+ static void VerifyNodes(Node* node, std::vector<Node*> expected_in,
+ std::vector<Node*> expected_out) {
+ std::vector<Node*> in;
+ for (const Edge* e : node->in_edges()) {
+ in.push_back(e->src());
+ }
+ EXPECT_EQ(Stringify(expected_in), Stringify(in));
+
+ std::vector<Node*> out;
+ for (const Edge* e : node->out_edges()) {
+ out.push_back(e->dst());
+ }
+ EXPECT_EQ(Stringify(expected_out), Stringify(out));
+ }
+
+ Node* AddNodeWithName(const string& name) {
+ Node* node;
+ TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node));
+ return node;
+ }
+
+ Graph graph_;
+
+ private:
+ // Convert a list of nodes to a sorted list of strings so failure messages
+ // are readable.
+ static std::vector<string> Stringify(const std::vector<Node*>& nodes) {
+ std::vector<string> result;
+ for (Node* n : nodes) {
+ result.push_back(n->DebugString());
+ }
+ std::sort(result.begin(), result.end());
+ return result;
+ }
+};
+
+TEST_F(GraphTest, Constructor) {
+ Node* source = graph_.source_node();
+ EXPECT_NE(source, nullptr);
+ Node* sink = graph_.sink_node();
+ EXPECT_NE(sink, nullptr);
+ VerifyNodes(source, {}, {sink});
+ VerifyNodes(sink, {source}, {});
+ EXPECT_EQ(2, graph_.num_node_ids());
+}
+
+TEST_F(GraphTest, RemoveThenAdd) {
+ AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ const int b_id = b->id();
+ AddNodeWithName("C");
+ EXPECT_EQ(5, graph_.num_node_ids());
+ graph_.RemoveNode(b);
+ EXPECT_EQ(5, graph_.num_node_ids());
+ Node* d = AddNodeWithName("D");
+ EXPECT_NE(b_id, d->id()); // Ids should not be reused.
+ EXPECT_EQ(6, graph_.num_node_ids());
+}
+
+TEST_F(GraphTest, InNodesAndOutNodes) {
+ Node* a = AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ Node* c = AddNodeWithName("C");
+ graph_.RemoveNode(b);
+ Node* d = AddNodeWithName("D");
+
+ const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a);
+ graph_.AddControlEdge(a, graph_.sink_node());
+ graph_.AddEdge(a, 0, c, 0);
+ graph_.AddControlEdge(c, graph_.sink_node());
+
+ EXPECT_EQ("A", a->name());
+ VerifyNodes(a, {graph_.source_node()}, {c, graph_.sink_node()});
+
+ EXPECT_EQ("C", c->name());
+ VerifyNodes(c, {a}, {graph_.sink_node()});
+
+ EXPECT_EQ("D", d->name());
+ VerifyNodes(d, {}, {});
+
+ VerifyNodes(graph_.source_node(), {}, {a, graph_.sink_node()});
+ VerifyNodes(graph_.sink_node(), {a, c, graph_.source_node()}, {});
+
+ graph_.RemoveEdge(source_to_a);
+ VerifyNodes(a, {}, {c, graph_.sink_node()});
+ VerifyNodes(graph_.source_node(), {}, {graph_.sink_node()}); // no more a
+
+ graph_.RemoveNode(c);
+ VerifyNodes(a, {}, {graph_.sink_node()}); // no more c
+ VerifyNodes(graph_.sink_node(), {a, graph_.source_node()}, {}); // no more c
+ EXPECT_EQ(6, graph_.num_node_ids());
+ EXPECT_EQ(5, graph_.num_edge_ids());
+}
+
+TEST_F(GraphTest, NodeIteration) {
+ // Set up the graph with some holes due to removals.
+ Node* a = AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ Node* c = AddNodeWithName("C");
+ graph_.RemoveNode(b);
+ Node* d = AddNodeWithName("D");
+ const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a);
+ graph_.AddControlEdge(a, graph_.sink_node());
+ graph_.AddEdge(a, 0, c, 0);
+ graph_.AddControlEdge(c, graph_.sink_node());
+ graph_.RemoveEdge(source_to_a);
+ graph_.RemoveNode(c);
+
+ // expected = set of all node DebugStrings we expect in the graph
+ std::set<string> expected;
+ expected.insert(graph_.source_node()->DebugString());
+ expected.insert(a->DebugString());
+ expected.insert(d->DebugString());
+ expected.insert(graph_.sink_node()->DebugString());
+
+ // Verify that iterating through ids gets the same set of nodes.
+ std::set<string> actual;
+ for (int id = 0; id < graph_.num_node_ids(); ++id) {
+ Node* node = graph_.FindNodeId(id);
+ if (node != nullptr) {
+ actual.insert(node->DebugString());
+ }
+ }
+ EXPECT_EQ(expected, actual);
+
+ // Verify that range-based for loop gets the same set of nodes.
+ actual.clear();
+ for (Node* node : graph_.nodes()) {
+ actual.insert(node->DebugString());
+ }
+ EXPECT_EQ(expected, actual);
+}
+
+static void CheckType(Node* node, bool b) {
+ EXPECT_TRUE(b) << node->DebugString();
+ // Make sure none of the other IsFoo() methods return true.
+ int count = 0;
+ if (node->IsSource()) count++;
+ if (node->IsSink()) count++;
+ if (node->IsOp()) count++;
+ EXPECT_EQ(1, count) << node->DebugString();
+}
+
+TEST_F(GraphTest, Type) {
+ Node* op = AddNodeWithName("A");
+ CheckType(graph_.source_node(), graph_.source_node()->IsSource());
+ CheckType(graph_.sink_node(), graph_.sink_node()->IsSink());
+ CheckType(op, op->IsOp());
+}
+
+// Convert edge iteration results into a sorted string.
+static string EdgeIter(const Graph& g) {
+ std::vector<std::pair<int, int> > edges;
+ for (const Edge* e : g.edges()) {
+ edges.push_back(std::make_pair(e->src()->id(), e->dst()->id()));
+ }
+ std::sort(edges.begin(), edges.end());
+ string result;
+ for (auto& p : edges) {
+ strings::StrAppend(&result, p.first, "->", p.second, ";");
+ }
+ return result;
+}
+
+TEST_F(GraphTest, EdgeIteration) {
+ EXPECT_EQ("0->1;", EdgeIter(graph_));
+
+ Node* a = AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ EXPECT_EQ("0->1;", EdgeIter(graph_)); // Since a,b are currently disconnected
+
+ graph_.AddEdge(a, 0, b, 0);
+ EXPECT_EQ("0->1;2->3;", EdgeIter(graph_));
+
+ graph_.AddControlEdge(graph_.source_node(), a);
+ graph_.AddControlEdge(b, graph_.sink_node());
+ EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_));
+
+ graph_.AddEdge(a, 1, a, 0);
+ EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_));
+}
+
+TEST_F(GraphTest, NewName) {
+ string a1 = graph_.NewName("A");
+ string a2 = graph_.NewName("A");
+ string b1 = graph_.NewName("B");
+ EXPECT_NE(a1, a2);
+ EXPECT_NE(a1, b1);
+ EXPECT_NE(a2, b1);
+ EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1;
+}
+
+REGISTER_OP("Input").Output("o: float");
+REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("o: float");
+
+static void BM_InEdgeIteration(int iters, int num_nodes) {
+ testing::StopTiming();
+ string s;
+ for (int in = 0; in < 10; in++) {
+ s += strings::Printf("node { name: 'in%04d' op: 'Input' }", in);
+ }
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int op = 0; op < num_nodes; op++) {
+ s += strings::Printf(
+ "node { name: 'op%04d' op: 'In2Out1' input: ['in%04d', 'in%04d' ] }",
+ op, rnd.Uniform(10), rnd.Uniform(10));
+ }
+
+ Graph graph(OpRegistry::Global());
+ GraphDef graph_def;
+ CHECK(protobuf::TextFormat::ParseFromString(s, &graph_def));
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
+
+ int64 sum = 0;
+ testing::StartTiming();
+ for (int i = 0; i < iters; i += graph.num_node_ids()) {
+ for (const Node* node : graph.nodes()) {
+ for (auto e : node->in_edges()) {
+ sum += e->id();
+ }
+ }
+ }
+ VLOG(1) << sum;
+}
+BENCHMARK(BM_InEdgeIteration)->Range(10, 100000);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
new file mode 100644
index 0000000000..8c34323dbe
--- /dev/null
+++ b/tensorflow/core/graph/node_builder.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/core/graph/node_builder.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+NodeBuilder::NodeBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry)
+ : def_builder_(name, op_name, op_registry) {}
+
+NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def)
+ : def_builder_(name, op_def) {}
+
+NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
+ inputs_.emplace_back(src_node, src_index);
+ DataType dt;
+ if (GetOutputType(src_node, src_index, &dt)) {
+ def_builder_.Input(src_node->name(), src_index, dt);
+ }
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::Input(NodeOut src) {
+ if (src.error) {
+ AddIndexError(src.node, src.index);
+ } else {
+ inputs_.emplace_back(src.node, src.index);
+ def_builder_.Input(src.name, src.index, src.dt);
+ }
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(src_list.size());
+ for (const auto& node_out : src_list) {
+ if (node_out.error) {
+ AddIndexError(node_out.node, node_out.index);
+ } else {
+ srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
+ inputs_.emplace_back(node_out.node, node_out.index);
+ }
+ }
+ def_builder_.Input(srcs);
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
+ control_inputs_.emplace_back(src_node);
+ def_builder_.ControlInput(src_node->name());
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
+ control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
+ src_nodes.end());
+ for (Node* src_node : src_nodes) {
+ def_builder_.ControlInput(src_node->name());
+ }
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::Device(const string& device_spec) {
+ def_builder_.Device(device_spec);
+ return *this;
+}
+
+Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
+ // In case of error, set *created_node to nullptr.
+ if (created_node != nullptr) *created_node = nullptr;
+ if (!errors_.empty()) {
+ return errors::InvalidArgument(str_util::Join(errors_, "\n"));
+ }
+
+ NodeDef node_def;
+ TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def));
+ TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
+ Status status;
+ Node* node = graph->AddNode(node_def, &status);
+ if (!status.ok()) return status;
+
+ for (size_t i = 0; i < inputs_.size(); ++i) {
+ if (inputs_[i].node != nullptr) { // Skip back edges.
+ graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
+ }
+ }
+ for (Node* control_input : control_inputs_) {
+ graph->AddControlEdge(control_input, node);
+ }
+ if (created_node != nullptr) *created_node = node;
+ return Status::OK();
+}
+
+void NodeBuilder::AddIndexError(Node* node, int i) {
+ if (node == nullptr) {
+ errors_.emplace_back(
+ strings::StrCat("Attempt to add nullptr Node to node with type",
+ def_builder_.op_def().name()));
+ } else {
+ errors_.emplace_back(
+ strings::StrCat("Attempt to add output ", i, " of ", node->name(),
+ " not in range [0, ", node->num_outputs(),
+ ") to node with type ", def_builder_.op_def().name()));
+ }
+}
+
+bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) {
+ bool error;
+ *dt = SafeGetOutput(node, i, &error);
+ if (error) AddIndexError(node, i);
+ return !error;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
new file mode 100644
index 0000000000..dd34b97f23
--- /dev/null
+++ b/tensorflow/core/graph/node_builder.h
@@ -0,0 +1,146 @@
+#ifndef TENSORFLOW_GRAPH_NODE_BUILDER_H_
+#define TENSORFLOW_GRAPH_NODE_BUILDER_H_
+
+#include <vector>
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// This is a helper for creating a Node and adding it to a Graph.
+// Internally, it uses a NodeDefBuilder to automatically set attrs
+// that can be inferred from the inputs, and to use default values
+// (where they exist) for unspecified attrs. Example usage:
+//
+// Node* node;
+// Status status = NodeBuilder(node_name, op_name)
+// .Input(...)
+// .Attr(...)
+// .Finalize(&graph, &node);
+// if (!status.ok()) return status;
+// // Use node here.
+class NodeBuilder {
+ public:
+ // For specifying the output of a Node to provide to one of the Input()
+ // functions below. It supports both regular inputs (where you are
+ // connecting to an existing Node*), and inputs from outside the graph
+ // (or haven't been added to the graph yet, like back edges, where
+ // you don't have a Node*). Both types can be mixed, e.g. in an
+ // ArraySlice.
+ struct NodeOut {
+ // For referencing an existing Node.
+ NodeOut(Node* n, int i = 0) // NOLINT(runtime/explicit)
+ : node(n),
+ error(false),
+ name(node != nullptr ? node->name() : (error = true, "")),
+ index(i),
+ dt(SafeGetOutput(node, i, &error)) {}
+
+ // For referencing Nodes not in the graph being built. It is
+ // useful when preparing a graph for ExtendSession or creating a
+ // back edge to a node that hasn't been added to the graph yet,
+ // but will be.
+ NodeOut(const string& name, int i, DataType t)
+ : node(nullptr), error(false), name(name), index(i), dt(t) {}
+
+ // Default constructor for std::vector<NodeOut>.
+ NodeOut() {}
+
+ Node* node = nullptr;
+ // error is set to true if:
+ // * the NodeOut was default constructed and never overwritten,
+ // * a nullptr Node* was passed to the NodeOut constructor, or
+ // * an out-of-range index was passed to the NodeOut constructor.
+ bool error = true;
+ string name;
+ int index = 0;
+ DataType dt = DT_FLOAT;
+ };
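+
+  // A short sketch mixing both forms (names and the "Add" op here are
+  // illustrative only). Finalize() adds no graph edge for the string
+  // form; the caller wires up the back edge once the node exists:
+  //
+  //   NodeBuilder("add", "Add")
+  //       .Input(NodeOut(existing_node))
+  //       .Input(NodeOut("future_node", 0, DT_FLOAT))
+  //       .Finalize(&graph, &node);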
+
+ // Specify the name and the Op (either via an OpDef or the name of
+ // the Op plus a registry) for the Node. Other fields are
+ // specified by calling the methods below.
+ // REQUIRES: The OpDef must satisfy ValidateOpDef().
+ NodeBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry = OpRegistry::Global());
+ NodeBuilder(const string& name, const OpDef* op_def);
+
+ // You must call one Input() function per input_arg in the Op,
+ // *and in the same order as the input_args appear in the OpDef.*
+
+ // For inputs that take a single tensor.
+ NodeBuilder& Input(Node* src_node, int src_index = 0);
+ NodeBuilder& Input(NodeOut src);
+
+ // For inputs that take a list of tensors.
+ NodeBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
+
+ // Require that this node run after src_node(s).
+ NodeBuilder& ControlInput(Node* src_node);
+ NodeBuilder& ControlInputs(gtl::ArraySlice<Node*> src_nodes);
+
+ // Sets the "requested device spec" in the NodeDef (not the
+ // "assigned device" in the Node).
+ NodeBuilder& Device(const string& device_spec);
+
+ // Set the value of an attr. attr_name must match the name of one of
+ // attrs defined by the Op, and value must have the corresponding type
+ // (see SetAttrValue() in ../framework/attr_value_util.h for legal
+ // types for value). Note that attrs will be set automatically if
+ // they can be determined by the inputs.
+ template <class T>
+ NodeBuilder& Attr(const string& attr_name, T&& value);
+ template <class T>
+ NodeBuilder& Attr(const string& attr_name, std::initializer_list<T> value);
+
+ // Validates the described node and adds it to *graph, adding edges
+ // for all (non-back) inputs. If created_node is not nullptr,
+ // *created_node will be set to the new node (or nullptr on error).
+ Status Finalize(Graph* graph, Node** created_node) const;
+
+ private:
+ static DataType SafeGetOutput(Node* node, int i, bool* error) {
+ if (node != nullptr && i >= 0 && i < node->num_outputs()) {
+ *error = false;
+ return node->output_type(i);
+ } else {
+ *error = true;
+ return DT_FLOAT;
+ }
+ }
+
+ // If SafeGetOutput indicates a range error, add it to errors_.
+ void AddIndexError(Node* node, int i);
+
+  // Sets *dt and returns true if i is in range. Combines
+ // SafeGetOutput() and AddIndexError().
+ bool GetOutputType(Node* node, int i, DataType* dt);
+
+ NodeDefBuilder def_builder_;
+ std::vector<NodeOut> inputs_;
+ std::vector<Node*> control_inputs_;
+ std::vector<string> errors_;
+};
+
+// IMPLEMENTATION -------------------------------------------------------------
+
+template <class T>
+inline NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) {
+ def_builder_.Attr(attr_name, std::forward<T>(value));
+ return *this;
+}
+
+template <class T>
+NodeBuilder& NodeBuilder::Attr(const string& attr_name,
+ std::initializer_list<T> value) {
+ def_builder_.Attr(attr_name, value);
+ return *this;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_NODE_BUILDER_H_
diff --git a/tensorflow/core/graph/node_builder_test.cc b/tensorflow/core/graph/node_builder_test.cc
new file mode 100644
index 0000000000..9f667d00e4
--- /dev/null
+++ b/tensorflow/core/graph/node_builder_test.cc
@@ -0,0 +1,59 @@
+#include "tensorflow/core/graph/node_builder.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("Source").Output("o: out_types").Attr("out_types: list(type)");
+REGISTER_OP("Sink").Input("i: T").Attr("T: type");
+
+TEST(NodeBuilderTest, Simple) {
+ RequireDefaultOps();
+ Graph graph(OpRegistry::Global());
+ Node* source_node;
+ EXPECT_OK(NodeBuilder("source_op", "Source")
+ .Attr("out_types", {DT_INT32, DT_STRING})
+ .Finalize(&graph, &source_node));
+ ASSERT_TRUE(source_node != nullptr);
+
+ // Try connecting to each of source_node's outputs.
+ EXPECT_OK(NodeBuilder("sink1", "Sink")
+ .Input(source_node)
+ .Finalize(&graph, nullptr));
+ EXPECT_OK(NodeBuilder("sink2", "Sink")
+ .Input(source_node, 1)
+ .Finalize(&graph, nullptr));
+
+ // Generate an error if the index is out of range.
+ EXPECT_FALSE(NodeBuilder("sink3", "Sink")
+ .Input(source_node, 2)
+ .Finalize(&graph, nullptr)
+ .ok());
+ EXPECT_FALSE(NodeBuilder("sink4", "Sink")
+ .Input(source_node, -1)
+ .Finalize(&graph, nullptr)
+ .ok());
+ EXPECT_FALSE(NodeBuilder("sink5", "Sink")
+ .Input({source_node, -1})
+ .Finalize(&graph, nullptr)
+ .ok());
+
+ // Generate an error if the node is nullptr. This can happen when using
+ // GraphDefBuilder if there was an error creating the input node.
+ EXPECT_FALSE(NodeBuilder("sink6", "Sink")
+ .Input(nullptr)
+ .Finalize(&graph, nullptr)
+ .ok());
+ EXPECT_FALSE(NodeBuilder("sink7", "Sink")
+ .Input(NodeBuilder::NodeOut(nullptr, 0))
+ .Finalize(&graph, nullptr)
+ .ok());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc
new file mode 100644
index 0000000000..2fa6f075c0
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse.cc
@@ -0,0 +1,220 @@
+// This module implements a common subexpression elimination pass. We
+// process the nodes in the graph in reverse postorder
+// (i.e. producers before their consumers). The rough algorithm is
+// as follows:
+//
+// std::unordered_map<size_t, Node*> available
+// for each node n in forward topological order:
+// h = NodeHash(n)
+//   if available[h] exists and Equivalent(available[h], n)
+// redirect downstream uses of outputs of n to available[h]
+// remove n from graph
+// else
+// if available[h] does not exist
+// available[h] = n
+//
+// This is similar to the global value numbering algorithm described in this
+// paper:
+//
+// "Global code motion/global value numbering", Cliff Click, PLDI '95
+// Proceedings of the ACM SIGPLAN 1995 conference on Programming
+// language design and implementation, Pages 246-257
+// http://dl.acm.org/citation.cfm?id=207154
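+//
+// As a concrete sketch (node names are illustrative only): given
+// C = Mul(A, B) and D = Mul(A, B) with identical attrs, both nodes hash
+// to the same bucket, Equivalent(C, D) holds, so every outgoing edge of
+// D is redirected to come from C and D is removed from the graph.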
+
+#include "tensorflow/core/graph/optimizer_cse.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class OptimizerCSE {
+ public:
+ explicit OptimizerCSE(Graph* g) : g_(g) {}
+
+ void Optimize(std::function<bool(const Node*)> consider_fn);
+
+ private:
+ struct Scratch;
+
+ static size_t NodeHash(const Node* n);
+ static bool Equivalent(const Node* a, const Node* b, Scratch* s);
+ static bool EqualAttrs(const Node* a, const Node* b, Scratch* s);
+
+ Graph* g_;
+};
+
+static void FillInputs(const Node* n,
+ gtl::InlinedVector<Node*, 4>* control_edges,
+ gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
+ DCHECK_EQ(in->size(), n->num_inputs());
+ control_edges->clear();
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ control_edges->push_back(e->src());
+ } else {
+ (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
+ }
+ }
+ std::sort(control_edges->begin(), control_edges->end());
+ if (n->op_def().is_commutative()) {
+ // For commutative inputs, we sort the input by the input Node*
+ // to get a canonical ordering (so that add(a,b) and add(b, a) will
+ // hash to the same value if is_commutative is true for 'add').
+ std::sort(in->begin(), in->end());
+ }
+}
+
+static size_t kIllegalNodeHash = 0;
+
+size_t OptimizerCSE::NodeHash(const Node* n) {
+ const DataTypeVector& out = n->output_types();
+ string str_to_hash = strings::StrCat(n->type_string(), out.size());
+ for (DataType dt : out) {
+ strings::StrAppend(&str_to_hash, dt);
+ }
+
+ const int N_in = n->num_inputs();
+ strings::StrAppend(&str_to_hash, N_in);
+ gtl::InlinedVector<Node*, 4> control_edges;
+ gtl::InlinedVector<std::pair<Node*, int>, 4> in(N_in);
+ FillInputs(n, &control_edges, &in);
+ for (const auto& edge : in) {
+ strings::StrAppend(&str_to_hash, edge.first->id(), edge.second);
+ }
+
+ size_t h = Hash64(str_to_hash);
+
+#if !defined(__ANDROID__) && !defined(ANDROID)
+ // Hash the attrs. For example, this makes sure different constants
+ // end up in different hash buckets.
+ string tmp;
+ for (const auto& attr : n->def().attr()) {
+ tmp = attr.first;
+ attr.second.AppendToString(&tmp);
+ // Add hashes of attrs, so the order of attrs doesn't matter.
+ h += Hash32(tmp.data(), tmp.size(), 0x87341245);
+ }
+#endif
+
+ if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1;
+ return h;
+}
+
+struct OptimizerCSE::Scratch {
+ // For EqualAttrs():
+ string a;
+ string b;
+};
+
+bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) {
+ if (a->def().attr_size() != b->def().attr_size()) return false;
+
+ for (const auto& attr : b->def().attr()) {
+ auto iter = a->def().attr().find(attr.first);
+ if (iter == a->def().attr().end()) return false;
+ // Note: it should be safe to compare proto serializations of the attr
+ // values since at most one field should be set in each (indeed, it
+ // should be the same field).
+ iter->second.SerializeToString(&scratch->a);
+ attr.second.SerializeToString(&scratch->b);
+ if (scratch->a != scratch->b) return false;
+ }
+ return true;
+}
+
+static bool HasRefInput(const Node* n) {
+ for (auto dt : n->input_types()) {
+ if (IsRefType(dt)) return true;
+ }
+ return false;
+}
+
+bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
+  // Nodes with different op names are never equivalent.
+ if (a->type_string() != b->type_string()) return false;
+
+ // Never consider stateful nodes (such as non-const inputs) equivalent.
+ if (a->op_def().is_stateful()) return false;
+
+ // For now, we consider any node that takes a ref input to not be
+ // equivalent to any other node.
+ if (HasRefInput(a) || HasRefInput(b)) return false;
+
+ // Compare attrs. Note that equal attrs implies equal input and
+ // output types.
+ if (!EqualAttrs(a, b, scratch)) return false;
+
+ // Compare input sources
+ if (a->num_inputs() != b->num_inputs()) return false;
+ const int N_in = a->num_inputs();
+ gtl::InlinedVector<Node*, 4> a_control_edges;
+ gtl::InlinedVector<Node*, 4> b_control_edges;
+ gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
+ gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
+ FillInputs(a, &a_control_edges, &a_in);
+ FillInputs(b, &b_control_edges, &b_in);
+ if (a_in != b_in) return false;
+ if (a_control_edges != b_control_edges) return false;
+
+ return true;
+}
+
+void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
+ // This very simple implementation works if the whole graph is one
+ // giant basic block (because we just traverse nodes in a
+ // topological order). We'll need to do something more
+ // sophisticated when we have control flow/loops/etc.
+
+ // TODO(jeff): We need to handle Update nodes specially, but dealing
+ // with more general control flow will also solve this issue, and for
+ // now, our updates are almost always the most downstream nodes in
+ // the graph.
+ std::vector<Node*> order;
+ GetReversePostOrder(*g_, &order);
+
+ // Our value is just a single Node*, meaning we keep just a single
+ // candidate for a given node hash value. This may cause us to
+ // (rarely) lose some optimization opportunities if there are
+ // hash collisions, but it allows us to avoid having the value
+ // be a set<Node*> (or equivalent).
+ std::unordered_map<size_t, Node*> available;
+
+ // Scratch space for Equivalent calls. Allocated here and passed in to
+ // Equivalent to avoid allocation inside the loop below.
+ Scratch scratch;
+ for (Node* n : order) {
+ if (!n->IsOp()) continue;
+
+ // See if we should consider this node at all
+ if (consider_fn != nullptr && !consider_fn(n)) continue;
+
+ size_t h = NodeHash(n);
+ Node** candidate = &available[h];
+ if (*candidate == nullptr) {
+ // No existing match: insert "n" into the hash table under "h"
+ *candidate = n;
+ } else if (Equivalent(*candidate, n, &scratch)) {
+ VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and "
+ << n->name();
+ // *candidate and n are equivalent. Therefore, we can replace
+ // n with *candidate by fixing up outgoing edges from "n" to instead
+ // come from "*candidate", and then delete n from the graph
+ for (const Edge* e : n->out_edges()) {
+ g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
+ }
+ g_->RemoveNode(n);
+ }
+ }
+}
+
+void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
+ OptimizerCSE opt(g);
+ opt.Optimize(consider_fn);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h
new file mode 100644
index 0000000000..430c97a449
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse.h
@@ -0,0 +1,19 @@
+// An optimization pass that performs common subexpression elimination.
+
+#ifndef TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+#define TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+
+#include <functional>
+#include <sys/types.h>
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Perform common-subexpression elimination on the graph "*g". If
+// "consider_fn" is not nullptr, then only nodes for which
+// consider_fn(node) returns true will be considered for combining
+// during the common subexpression elimination.
+extern void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn);
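+//
+// A minimal usage sketch (hypothetical; assumes "g" points to a valid
+// Graph):
+//
+//   OptimizeCSE(g, nullptr);  // consider every node for elimination
+//
+//   // Or restrict the pass, e.g. to leave Const nodes untouched:
+//   OptimizeCSE(g, [](const Node* n) { return n->type_string() != "Const"; });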
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc
new file mode 100644
index 0000000000..ebbb948fdc
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse_test.cc
@@ -0,0 +1,365 @@
+#include "tensorflow/core/graph/optimizer_cse.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace {
+
+static void InitGraph(const string& s, Graph* graph) {
+ GraphDef graph_def;
+
+ auto parser = protobuf::TextFormat::Parser();
+ // parser.AllowRelaxedWhitespace(true);
+ CHECK(parser.MergeFromString(s, &graph_def)) << s;
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
+}
+
+class OptimizerCSETest : public ::testing::Test {
+ public:
+ OptimizerCSETest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); }
+
+ void InitGraph(const string& s) {
+ ::tensorflow::InitGraph(s, &graph_);
+ original_ = CanonicalGraphString(&graph_);
+ }
+
+ static bool IncludeNode(const Node* n) { return n->IsOp(); }
+
+ static string EdgeId(const Node* n, int index) {
+ if (index == 0) {
+ return n->name();
+ } else if (index == Graph::kControlSlot) {
+ return strings::StrCat(n->name(), ":control");
+ } else {
+ return strings::StrCat(n->name(), ":", index);
+ }
+ }
+
+ string CanonicalGraphString(Graph* g) {
+ std::vector<string> nodes;
+ std::vector<string> edges;
+ for (const Node* n : g->nodes()) {
+ if (IncludeNode(n)) {
+ nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
+ }
+ }
+ for (const Edge* e : g->edges()) {
+ if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
+ edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
+ EdgeId(e->dst(), e->dst_input())));
+ }
+ }
+ // Canonicalize
+ std::sort(nodes.begin(), nodes.end());
+ std::sort(edges.begin(), edges.end());
+ return strings::StrCat(str_util::Join(nodes, ";"), "|",
+ str_util::Join(edges, ";"));
+ }
+
+ string DoCSE(std::function<bool(const Node*)> consider_fn = nullptr) {
+ string before = CanonicalGraphString(&graph_);
+ LOG(ERROR) << "Before rewrites: " << before;
+
+ OptimizeCSE(&graph_, consider_fn);
+
+ string result = CanonicalGraphString(&graph_);
+ LOG(ERROR) << "After rewrites: " << result;
+ return result;
+ }
+
+ const string& OriginalGraph() const { return original_; }
+
+ Graph graph_;
+ string original_;
+};
+
+REGISTER_OP("Input").Output("o: float").SetIsStateful();
+
+// Note that the "rules" in these tests are not meant to be logically correct
+TEST_F(OptimizerCSETest, Simple) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);E(Mul)|"
+ "A->E;B->E:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_WithFixups) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul);E(Mul)|"
+ "A->D;B->D:1;D->E;D->E:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_Commutative) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'A'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D:1;B->D");
+}
+
+static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; }
+
+// Like Simple_Commutative, but with a filter that excludes Mul nodes,
+// so no CSE is performed.
+TEST_F(OptimizerCSETest, Simple_Filtered) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'A'] }");
+ EXPECT_EQ(DoCSE(IsNotMultiply), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, Simple_NotCommutative) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'A'] }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, NotEquivalent_Ops) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) {
+ // Should still do CSE for ops with attrs if they match.
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] attr { key: 'shape'"
+ " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] attr { key: 'shape'"
+ " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) {
+ // Should still do CSE for ops with attrs if they match, even if they
+ // are not in the same order.
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 'a' value { i: 3 } }"
+ " attr { key: 't' value { type: DT_INT32 } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 't' value { type: DT_INT32 } }"
+ " attr { key: 'a' value { i: 3 } } }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, SameConstants) {
+ // Should still do CSE for ops with constants if the values are identical
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "B(Const);D(Mul)|"
+ "B->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, DifferentConstants) {
+  // Should not do CSE for constants whose values differ.
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 100000 } } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Const);B(Const);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, SameOps_DifferentAttrs1) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 'a' value { i: 3 } }"
+ " attr { key: 't' value { type: DT_INT32 } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 't' value { type: DT_INT32 } }"
+ " attr { key: 'a' value { i: 4 } } }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, SameOps_DifferentAttrs2) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 'a' value { i: 3 } }"
+ " attr { key: 't' value { type: DT_FLOAT } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 't' value { type: DT_INT32 } }"
+ " attr { key: 'a' value { i: 3 } } }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, NotEquivalent_Inputs) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'C'] }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, Constant_Dedup) {
+ Tensor a(DT_FLOAT, TensorShape({1}));
+ a.flat<float>()(0) = 1.0;
+ Tensor b(DT_DOUBLE, TensorShape({1})); // Different type
+ b.flat<double>()(0) = 1.0;
+ Tensor c(DT_FLOAT, TensorShape({1, 1})); // Different shape
+ c.flat<float>()(0) = 1.0;
+ Tensor d(DT_FLOAT, TensorShape({1})); // Different value
+ d.flat<float>()(0) = 2.0;
+
+ // A graph contains a bunch of constants.
+ Graph g(OpRegistry::Global());
+ for (auto val : {a, b, c, d, d, c, b, a}) {
+ test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ...
+ }
+ GraphDef gdef;
+ test::graph::ToGraphDef(&g, &gdef);
+ InitGraph(gdef.DebugString());
+
+ EXPECT_EQ(OriginalGraph(),
+ "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);"
+ "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
+  // In theory, there are 2^4 possible correct outputs of CSE. In this
+  // test, it happens to eliminate the first 4 nodes.
+ EXPECT_EQ(DoCSE(), "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
+}
+
+static void BM_CSE(int iters, int op_nodes) {
+ testing::StopTiming();
+ string s;
+ for (int in = 0; in < 10; in++) {
+ s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
+ }
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int op = 0; op < op_nodes; op++) {
+ s += strings::Printf(
+ "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
+ "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
+ op, rnd.Uniform(10), rnd.Uniform(10));
+ }
+
+ bool first = true;
+ while (iters > 0) {
+ Graph* graph = new Graph(OpRegistry::Global());
+ InitGraph(s, graph);
+ int N = graph->num_node_ids();
+ if (first) {
+ testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
+ first = false;
+ }
+ {
+ testing::StartTiming();
+ OptimizeCSE(graph, nullptr);
+ testing::StopTiming();
+ }
+ iters -= N; // Our benchmark units are individual graph nodes,
+ // not whole graphs
+ delete graph;
+ }
+}
+BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
new file mode 100644
index 0000000000..7910511dfb
--- /dev/null
+++ b/tensorflow/core/graph/subgraph.cc
@@ -0,0 +1,258 @@
+#include "tensorflow/core/graph/subgraph.h"
+
+#include <algorithm>
+#include <deque>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// ----------------------------------------------------------------------------
+// Subgraph construction-related routines
+// ----------------------------------------------------------------------------
+// TODO(vrv): Profile the unordered_set and unordered_map use in this file to
+// see if we should use an alternative implementation.
+
+namespace {
+
+typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
+
+// Rewrites the graph by replacing each output tensor specified in
+// "fed_outputs" with a special "_Recv" feed node, redirecting all
+// consumers of that tensor to read from the feed node instead.
+// "*name_index" is updated to include the newly created feed nodes.
+//
+// Returns OK on success. On error, returns an appropriate error status
+// (and *g is left in an indeterminate state).
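+//
+// For example (names follow the pattern used below): feeding "input:1"
+// creates a node "_recv_input_1" and rewires every edge that consumed
+// input:1 to read from that recv node instead; the original producer
+// can then be pruned later if nothing else depends on it.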
+static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
+ const gtl::ArraySlice<string>& fed_outputs,
+ NameIndex* name_index) {
+ for (const string& t : fed_outputs) {
+ TensorId id(ParseTensorName(t));
+
+ auto iter = name_index->find(id.first);
+ if (iter == name_index->end()) {
+ return errors::NotFound("FeedInputs: unable to find feed output ", t);
+ }
+ const Node* n = iter->second;
+ DCHECK_EQ(n->name(), id.first);
+ if (id.second >= n->num_outputs()) {
+ return errors::InvalidArgument(
+ "FeedInputs: ", t, " should have output index < ", n->num_outputs());
+ }
+
+ Node* recv_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
+ "_Recv")
+ .Attr("tensor_type", BaseType(n->output_type(id.second)))
+ .Attr("tensor_name", t)
+ .Attr("send_device", device_info.name())
+ .Attr("recv_device", device_info.name())
+ .Attr("send_device_incarnation",
+ static_cast<int64>(device_info.incarnation()))
+ .Attr("client_terminated", true)
+ .Finalize(g, &recv_node));
+ recv_node->set_assigned_device_name(device_info.name());
+
+ // Update name_index
+ (*name_index)[recv_node->name()] = recv_node;
+ g->AddControlEdge(g->source_node(), recv_node);
+
+ // Look through edges coming out of "n" for edges whose src_output() index
+ // matches "output_index". If found, replace the edges with a connection
+ // from the special feed node.
+ std::vector<const Edge*> to_remove;
+ for (const Edge* e : n->out_edges()) {
+ if (e->src_output() == id.second) {
+ to_remove.emplace_back(e);
+ } else if (e->src_output() == Graph::kControlSlot &&
+ n->def().op() == "Placeholder") {
+ // When feeding a Placeholder node, any outgoing control edges
+ // will be replaced with a control edge from the replacement
+ // recv_node.
+ // TODO(josh11b,mrry): Come up with a more elegant way of addressing
+ // the general version of this problem.
+ to_remove.emplace_back(e);
+ }
+ }
+
+ for (const Edge* e : to_remove) {
+ if (e->src_output() == id.second) {
+ g->AddEdge(recv_node, 0, e->dst(), e->dst_input());
+ } else {
+ CHECK_EQ(Graph::kControlSlot, e->src_output());
+ g->AddControlEdge(recv_node, e->dst());
+ }
+ g->RemoveEdge(e);
+ }
+ }
+ return Status::OK();
+}
+
+// Augment "*g" by adding special "fetch" nodes that connect to the
+// tensor outputs specified in "fetch_outputs" to retrieve the output
+// of the tensors. The new nodes added are set up to execute on
+// "client_device_name", and are returned in "*fetch_nodes".
+//
+// Return true on success. On error, return false and sets *error to
+// an appropriate error message (and *g is left in an indeterminate
+// state).
+static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ NameIndex* name_index,
+ std::vector<Node*>* fetch_nodes) {
+ fetch_nodes->clear();
+ for (const string& t : fetch_outputs) {
+ // Parse t into node_name and output_index.
+ TensorId id(ParseTensorName(t));
+
+ // Find node in graph with that name.
+ auto iter = name_index->find(id.first);
+ if (iter == name_index->end()) {
+ return errors::NotFound("FetchOutputs node ", t, ": not found");
+ }
+ Node* n = iter->second;
+ DCHECK_EQ(n->name(), id.first);
+ VLOG(2) << "Found fetch node for " << t;
+
+ // Validate output_index
+ if (id.second >= n->num_outputs()) {
+ return errors::InvalidArgument("FetchOutputs ", t,
+ ": output index too large, must be < ",
+ n->num_outputs());
+ }
+
+ // Create the fetch Node and connect it up
+ Node* send_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
+ "_Send")
+ .Input(n, id.second)
+ .Attr("tensor_name", t)
+ .Attr("send_device", device_info.name())
+ .Attr("recv_device", device_info.name())
+ .Attr("send_device_incarnation",
+ static_cast<int64>(device_info.incarnation()))
+ .Attr("client_terminated", true)
+ .Finalize(g, &send_node));
+ send_node->set_assigned_device_name(device_info.name());
+ VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def());
+
+ // Update the index.
+ (*name_index)[send_node->name()] = send_node;
+
+ g->AddControlEdge(send_node, g->sink_node());
+ fetch_nodes->push_back(send_node);
+ }
+
+ return Status::OK();
+}
+
+static bool AddNodeToTargets(const string& node_or_tensor_name,
+ const NameIndex& name_index,
+ std::unordered_set<const Node*>* targets) {
+ TensorId id = ParseTensorName(node_or_tensor_name);
+ auto iter = name_index.find(id.first);
+ if (iter == name_index.end()) {
+ return false;
+ }
+ const Node* n = iter->second;
+ if (n->name() != node_or_tensor_name) {
+ return false;
+ }
+
+ targets->insert(n);
+ return true;
+}
+
+static Status PruneForTargets(Graph* g, const NameIndex& name_index,
+ const std::vector<Node*>& fetch_nodes,
+ const gtl::ArraySlice<string>& target_nodes) {
+ string not_found;
+ std::unordered_set<const Node*> targets;
+ for (Node* n : fetch_nodes) {
+ if (!AddNodeToTargets(n->name(), name_index, &targets)) {
+ strings::StrAppend(&not_found, n->name(), " ");
+ }
+ }
+ for (const string& s : target_nodes) {
+ if (!AddNodeToTargets(s, name_index, &targets)) {
+ strings::StrAppend(&not_found, s, " ");
+ }
+ }
+ if (!not_found.empty()) {
+ return errors::NotFound("PruneForTargets: Some target nodes not found: ",
+ not_found);
+ }
+ PruneForReverseReachability(g, targets);
+
+ return Status::OK();
+}
+
+} // namespace
+
+namespace subgraph {
+
+Status RewriteGraphForExecution(
+ Graph* g, const gtl::ArraySlice<string>& fed_outputs,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ const gtl::ArraySlice<string>& target_node_names,
+ const DeviceAttributes& device_info) {
+ std::unordered_set<string> endpoints(fed_outputs.begin(), fed_outputs.end());
+ for (const auto& fetch : fetch_outputs) {
+ if (endpoints.count(fetch) > 0) {
+ return errors::InvalidArgument(fetch, " is both fed and fetched.");
+ }
+ }
+
+ // A separate index mapping name to Node*, for use by FeedInputs,
+ // FetchOutputs, and PruneForTargets
+ NameIndex name_index;
+ for (Node* n : g->nodes()) {
+ name_index[n->name()] = n;
+ }
+
+ // Add the feeds. This may replace nodes in the graph, including the nodes
+ // currently listed in "fetch_nodes". We pass "name_index" so the index is
+ // kept up to date.
+ if (!fed_outputs.empty()) {
+ TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index));
+ }
+
+ // Add the fetch nodes, also updating "name_index".
+ std::vector<Node*> fetch_nodes;
+ if (!fetch_outputs.empty()) {
+ TF_RETURN_IF_ERROR(
+ FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes));
+ }
+
+ // Prune the graph to only compute what is needed for the fetch nodes and the
+  // target nodes.
+ if (!fetch_nodes.empty() || !target_node_names.empty()) {
+ TF_RETURN_IF_ERROR(
+ PruneForTargets(g, name_index, fetch_nodes, target_node_names));
+ }
+
+ return Status::OK();
+}
+
+} // namespace subgraph
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h
new file mode 100644
index 0000000000..d2e138e8ae
--- /dev/null
+++ b/tensorflow/core/graph/subgraph.h
@@ -0,0 +1,49 @@
+#ifndef TENSORFLOW_GRAPH_SUBGRAPH_H_
+#define TENSORFLOW_GRAPH_SUBGRAPH_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace subgraph {
+
+// Rewrite the graph structure of "*g" to deal with feeding node
+// outputs, fetching node outputs, and only running a subset of the
+// graph. "fed_outputs" and "fetch_outputs" are both lists of
+// output tensor identifiers in the form of
+// "<name>[:<optional_output_index>]", and "target_nodes_str" is a
+// lists of of target node names in "*g" "g".
+//
+// In the resulting graph "*g", output edges in "fed_outputs" have
+// been redirected to special "_recv" nodes introduced into the graph.
+// If these fed nodes are not needed in order to compute the effects
+// of the nodes in "targets_nodes" and "fetch_outputs", then these may
+// be omitted from the graph.
+//
+// In the resulting graph "*g", additional "_send" nodes are connected
+// to every output in "fetch_outputs". These "_send" nodes are set up
+// to execute on the device described by device_info.
+//
+// On success, returns OK, and rewrites "*g" in place so that it
+// contains only the portions of the graph necessary for producing
+// the output of all nodes listed in "target_node_names" and fetching the
+// specific node outputs specified in "fetch_outputs".
+//
+// On failure, returns the error status. Possible errors include:
+// - fed output "node:output_index" does not exist in "*g"
+// - fetch output "node:output_index" does not exist in "*g"
+// - target node "node" does not exist in "*g"
+Status RewriteGraphForExecution(
+ Graph* g, const gtl::ArraySlice<string>& fed_outputs,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ const gtl::ArraySlice<string>& target_node_names,
+ const DeviceAttributes& device_info);
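+
+// A minimal usage sketch (hypothetical; assumes "g" was built elsewhere
+// and "device_info" describes the client device):
+//
+//   Status s = subgraph::RewriteGraphForExecution(
+//       &g, {"input:0"} /* fed */, {"out:0"} /* fetched */,
+//       {"train_op"} /* targets */, device_info);
+//   if (!s.ok()) { /* handle rewrite failure */ }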
+
+} // namespace subgraph
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_SUBGRAPH_H_
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
new file mode 100644
index 0000000000..ffb3e6e403
--- /dev/null
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -0,0 +1,305 @@
+#include "tensorflow/core/graph/subgraph.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/status.h"
+#include <gtest/gtest.h>
+
+// TODO(josh11b): Test setting the "device" field of a NodeDef.
+// TODO(josh11b): Test that feeding won't prune targets.
+
+namespace tensorflow {
+namespace {
+
+class SubgraphTest : public ::testing::Test {
+ protected:
+ SubgraphTest() : g_(new Graph(OpRegistry::Global())) {
+ RequireDefaultOps();
+ device_info_.set_name("/job:a/replica:0/task:0/cpu:0");
+ device_info_.set_device_type(DeviceType(DEVICE_CPU).type());
+ device_info_.set_incarnation(0);
+ }
+
+ ~SubgraphTest() override {}
+
+ void ExpectOK(const string& gdef_ascii) {
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_));
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
+ }
+
+ Node* FindNode(const string& name) {
+ for (Node* n : g_->nodes()) {
+ if (n->name() == name) return n;
+ }
+ return nullptr;
+ }
+
+ bool HasNode(const string& name) { return FindNode(name) != nullptr; }
+
+ void ExpectNodes(const string& nodes) {
+ int count = 0;
+ std::vector<string> actual_nodes;
+ for (Node* n : g_->nodes()) {
+ if (n->IsOp()) {
+ count++;
+ actual_nodes.push_back(n->name());
+ }
+ }
+ std::sort(actual_nodes.begin(), actual_nodes.end());
+
+ LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " ");
+
+ std::vector<string> expected_nodes = str_util::Split(nodes, ',');
+ std::sort(expected_nodes.begin(), expected_nodes.end());
+ for (const string& s : expected_nodes) {
+ Node* n = FindNode(s);
+ EXPECT_TRUE(n != nullptr) << s;
+ if (n->def().op() == "_Send" || n->def().op() == "_Recv") {
+ EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s;
+ }
+ }
+
+ EXPECT_TRUE(actual_nodes.size() == expected_nodes.size())
+ << "\nActual: " << str_util::Join(actual_nodes, ",")
+ << "\nExpected: " << str_util::Join(expected_nodes, ",");
+ }
+
+ bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) {
+ for (const Edge* e : g_->edges()) {
+ if (e->src()->name() == src && e->src_output() == src_out &&
+ e->dst()->name() == dst && e->dst_input() == dst_in)
+ return true;
+ }
+ return false;
+ }
+ bool HasControlEdge(const string& src, const string& dst) {
+ return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
+ }
+
+ string Subgraph(const string& fed_str, const string& fetch_str,
+ const string& targets_str) {
+ Graph* subgraph = new Graph(OpRegistry::Global());
+ CopyGraph(*g_, subgraph);
+ std::vector<string> fed =
+ str_util::Split(fed_str, ',', str_util::SkipEmpty());
+ std::vector<string> fetch =
+ str_util::Split(fetch_str, ',', str_util::SkipEmpty());
+ std::vector<string> targets =
+ str_util::Split(targets_str, ',', str_util::SkipEmpty());
+
+ Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
+ targets, device_info_);
+ if (!s.ok()) {
+ delete subgraph;
+ return s.ToString();
+ }
+
+    // Replace the graph with the subgraph for the rest of the test.
+ g_.reset(subgraph);
+ return "OK";
+ }
+
+ Graph* graph() { return g_.get(); }
+
+ private:
+ GraphDef gdef_;
+ std::unique_ptr<Graph> g_;
+ DeviceAttributes device_info_;
+};
+
+REGISTER_OP("TestParams").Output("o: float");
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
+REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+
+TEST_F(SubgraphTest, Targets1) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "", "t1"));
+ ExpectNodes("W1,input,t1");
+}
+
+TEST_F(SubgraphTest, Targets2) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }"
+ "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a"));
+ ExpectNodes("W1,W2,input,t1,t2,t3_a");
+}
+
+TEST_F(SubgraphTest, FedOutputs1) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("input:1", "", "t2"));
+ ExpectNodes("W1,W2,_recv_input_1,t1,t2");
+}
+
+TEST_F(SubgraphTest, FedRefNode) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
+ EXPECT_EQ("OK", Subgraph("W1:0", "", "t1"));
+ ExpectNodes("_recv_W1_0,W2,t1");
+ Node* n = FindNode("_recv_W1_0");
+ EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
+}
+
+TEST_F(SubgraphTest, FedOutputs2) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ // We feed input:1, but nothing connects to it, so the _recv(input:1)
+ // node also disappears.
+ EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2"));
+ ExpectNodes("_recv_t1_0,_recv_W2_0,t2");
+}
+
+TEST_F(SubgraphTest, FetchOutputs1) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2"));
+ ExpectNodes(
+ "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
+}
+
+TEST_F(SubgraphTest, FetchOutputs2) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "t3_a", "t2"));
+ ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
+}
+
+TEST_F(SubgraphTest, ChainOfFools) {
+ ExpectOK(
+ "node { name: 'a' op: 'TestParams' }"
+ "node { name: 'b' op: 'TestRelu' input: 'a'}"
+ "node { name: 'c' op: 'TestRelu' input: 'b'}"
+ "node { name: 'd' op: 'TestRelu' input: 'c'}"
+ "node { name: 'e' op: 'TestRelu' input: 'd'}"
+ "node { name: 'f' op: 'TestRelu' input: 'e'}");
+ EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", ""));
+ ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0");
+ EXPECT_TRUE(HasEdge("a", 0, "b", 0));
+ EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0));
+ EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0));
+ EXPECT_TRUE(HasEdge("d", 0, "e", 0));
+ EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0));
+}
+
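+// Returns true iff "base" contains "substr"; also reports an EXPECT failure
+// with both strings, so callers wrap it in EXPECT_TRUE to additionally get
+// the failing call site.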
+static bool HasSubstr(const string& base, const string& substr) {
+ bool ok = StringPiece(base).contains(substr);
+ EXPECT_TRUE(ok) << base << ", expected substring " << substr;
+ return ok;
+}
+
+TEST_F(SubgraphTest, Errors) {
+ ExpectOK(
+ "node { name: 'a' op: 'TestParams' }"
+ "node { name: 'b' op: 'TestRelu' input: 'a'}"
+ "node { name: 'c' op: 'TestRelu' input: 'b'}"
+ "node { name: 'd' op: 'TestRelu' input: 'c'}"
+ "node { name: 'e' op: 'TestRelu' input: 'd'}"
+ "node { name: 'f' op: 'TestRelu' input: 'e'}");
+ // Duplicated feed and fetch
+ EXPECT_TRUE(
+ HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched"));
+ // Feed not found.
+ EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "", ""), "unable to find"));
+ // Fetch not found.
+ EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found"));
+ // Target not found.
+ EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found"));
+}
+
+REGISTER_OP("In").Output("o: float");
+REGISTER_OP("Op").Input("i: float").Output("o: float");
+
+static void BM_Subgraph(int iters, int num_nodes) {
+ DeviceAttributes device_info;
+ device_info.set_name("/job:a/replica:0/task:0/cpu:0");
+ device_info.set_device_type(DeviceType(DEVICE_CPU).type());
+ device_info.set_incarnation(0);
+
+ testing::StopTiming();
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* last_node = nullptr;
+ for (int i = 0; i < num_nodes; i++) {
+ string name = strings::StrCat("N", i);
+ if (i > 0) {
+ last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name));
+ } else {
+ last_node = ops::SourceOp("In", b.opts().WithName(name));
+ }
+ }
+ TF_CHECK_OK(b.ToGraph(&g));
+ }
+
+ std::vector<string> fed;
+ if (num_nodes > 1000) {
+ fed.push_back(strings::StrCat("N", num_nodes - 1000));
+ }
+ std::vector<string> fetch;
+ std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)};
+ testing::StartTiming();
+ while (--iters > 0) {
+ Graph* subgraph = new Graph(OpRegistry::Global());
+ CopyGraph(g, subgraph);
+ TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
+ targets, device_info));
+ delete subgraph;
+ }
+}
+BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
new file mode 100644
index 0000000000..f789110ff3
--- /dev/null
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -0,0 +1,41 @@
+#include "tensorflow/core/graph/tensor_id.h"
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+TensorId ParseTensorName(const string& name) {
+ return ParseTensorName(StringPiece(name.data(), name.size()));
+}
+
+TensorId ParseTensorName(StringPiece name) {
+ // Parse either a name, or a name:digits. To do so, we go backwards
+ // from the end of the string, skipping over a run of digits. If
+ // we hit a ':' character, then we know we are in the 'name:digits'
+ // regime. Otherwise, the output index is implicitly 0, and the whole
+ // name string forms the first part of the tensor name.
+ //
+  // Equivalent to matching with this regexp: ([^:]+):(\d+)
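+  //
+  // For example (illustrative):
+  //   "W1"    -> ("W1", 0)
+  //   "W1:17" -> ("W1", 17)
+  //   "W1:"   -> ("W1:", 0)   // no digits after ':', so no index is parsed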
+ const char* base = name.data();
+ const char* p = base + name.size() - 1;
+ int index = 0;
+ int mul = 1;
+ while (p > base && (*p >= '0' && *p <= '9')) {
+ index += ((*p - '0') * mul);
+ mul *= 10;
+ p--;
+ }
+ TensorId id;
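+  // "mul > 1" below ensures at least one digit was actually consumed, so a
+  // trailing ':' with no digits does not parse as an output index.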
+ if (p > base && *p == ':' && mul > 1) {
+ id.first = StringPiece(base, p - base);
+ id.second = index;
+ } else {
+ id.first = name;
+ id.second = 0;
+ }
+ return id;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h
new file mode 100644
index 0000000000..f1f3846875
--- /dev/null
+++ b/tensorflow/core/graph/tensor_id.h
@@ -0,0 +1,28 @@
+#ifndef TENSORFLOW_GRAPH_TENSOR_ID_H_
+#define TENSORFLOW_GRAPH_TENSOR_ID_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+// Identifier for a tensor within a step.
+// first == operation_name, second == output_index
+// Note: does not own backing storage for name.
+struct TensorId : public std::pair<StringPiece, int> {
+ typedef std::pair<StringPiece, int> Base;
+
+ // Inherit the set of constructors.
+ using Base::pair;
+
+ string ToString() const { return strings::StrCat(first, ":", second); }
+};
+
+TensorId ParseTensorName(const string& name);
+TensorId ParseTensorName(StringPiece name);
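+
+// A minimal usage sketch (illustrative):
+//   TensorId id = ParseTensorName("add:1");
+//   // id.first == "add", id.second == 1, id.ToString() == "add:1".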
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_TENSOR_ID_H_
diff --git a/tensorflow/core/graph/tensor_id_test.cc b/tensorflow/core/graph/tensor_id_test.cc
new file mode 100644
index 0000000000..b945774cc3
--- /dev/null
+++ b/tensorflow/core/graph/tensor_id_test.cc
@@ -0,0 +1,77 @@
+#include "tensorflow/core/graph/tensor_id.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+static string ParseHelper(const string& n) {
+ TensorId id = ParseTensorName(n);
+ return strings::StrCat(id.first, ":", id.second);
+}
+
+TEST(TensorIdTest, ParseTensorName) {
+ EXPECT_EQ(ParseHelper("W1"), "W1:0");
+ EXPECT_EQ(ParseHelper("weights:0"), "weights:0");
+ EXPECT_EQ(ParseHelper("W1:1"), "W1:1");
+ EXPECT_EQ(ParseHelper("W1:17"), "W1:17");
+ EXPECT_EQ(ParseHelper("xyz1_17"), "xyz1_17:0");
+}
+
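+// Returns a random value skewed toward small numbers: first picks a range
+// [0, 2^k) with k uniform in [0, max_log], then returns a uniform value in
+// that range.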
+static uint32 Skewed(random::SimplePhilox* rnd, int max_log) {
+ const uint32 space = 1 << (rnd->Rand32() % (max_log + 1));
+ return rnd->Rand32() % space;
+}
+
+static void BM_ParseTensorName(int iters, int arg) {
+ testing::StopTiming();
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ std::vector<string> names;
+ for (int i = 0; i < 100; i++) {
+ string name;
+ switch (arg) {
+ case 0: { // Generate random names
+ size_t len = Skewed(&rnd, 4);
+ while (name.size() < len) {
+ name += rnd.OneIn(4) ? '0' : 'a';
+ }
+ if (rnd.OneIn(3)) {
+ strings::StrAppend(&name, ":", rnd.Uniform(12));
+ }
+ break;
+ }
+ case 1:
+ name = "W1";
+ break;
+ case 2:
+ name = "t0003";
+ break;
+ case 3:
+ name = "weights";
+ break;
+ case 4:
+ name = "weights:17";
+ break;
+ default:
+ LOG(FATAL) << "Unexpected arg";
+ break;
+ }
+ names.push_back(name);
+ }
+ testing::StartTiming();
+ TensorId id;
+ int index = 0;
+ int sum = 0;
+ while (--iters > 0) {
+ id = ParseTensorName(names[index++ % names.size()]);
+ sum += id.second;
+ }
+ VLOG(2) << sum; // Prevent compiler from eliminating loop body
+}
+BENCHMARK(BM_ParseTensorName)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
new file mode 100644
index 0000000000..e49d5e819a
--- /dev/null
+++ b/tensorflow/core/graph/testlib.cc
@@ -0,0 +1,299 @@
+#include "tensorflow/core/graph/testlib.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace test {
+namespace graph {
+
+Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
+ const uint64 sender_incarnation, const string& receiver) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
+ .Input(input, 0)
+ .Attr("tensor_name", tensor)
+ .Attr("send_device", sender)
+ .Attr("send_device_incarnation",
+ static_cast<int64>(sender_incarnation))
+ .Attr("recv_device", receiver)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Recv(Graph* g, const string& tensor, const string& type,
+ const string& sender, const uint64 sender_incarnation,
+ const string& receiver) {
+ Node* ret;
+ DataType dtype;
+ CHECK(DataTypeFromString(type, &dtype));
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
+ .Attr("tensor_type", dtype)
+ .Attr("tensor_name", tensor)
+ .Attr("send_device", sender)
+ .Attr("send_device_incarnation",
+ static_cast<int64>(sender_incarnation))
+ .Attr("recv_device", receiver)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Constant(Graph* g, const Tensor& tensor) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
+ .Attr("dtype", tensor.dtype())
+ .Attr("value", tensor)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(name, "Const")
+ .Attr("dtype", tensor.dtype())
+ .Attr("value", tensor)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
+ .Attr("dtype", dtype)
+ .Attr("shape", shape)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Assign(Graph* g, Node* var, Node* val) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
+ .Input(var)
+ .Input(val)
+ .Attr("use_locking", true)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
+ bool keep_dims) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce)
+ .Input(data)
+ .Input(axes)
+ .Attr("keep_dims", keep_dims)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* QuantizeToUINT8(Graph* g, Node* data) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
+ .Input(data)
+ .Attr("T", DT_QUINT8)
+ .Attr("max_range", 1.0f)
+ .Attr("min_range", -1.0f)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
+ bool transpose_b) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
+ .Input(in0)
+ .Input(in1)
+ .Attr("transpose_a", transpose_a)
+ .Attr("transpose_b", transpose_b)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
+ DataType dtype) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), op)
+ .Input(input)
+ .Attr("dtype", dtype)
+ .Attr("seed", 0)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomUniform", g, input, dtype);
+}
+
+Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
+}
+
+Node* RandomParameters(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomParameters", g, input, dtype);
+}
+
+Node* Unary(Graph* g, const string& func, Node* input, int index) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), func)
+ .Input(in0)
+ .Input(in1)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
+ Node* ret;
+ auto b = NodeBuilder(g->NewName("n"), func);
+ for (Node* n : ins) b = b.Input(n);
+ TF_CHECK_OK(b.Finalize(g, &ret));
+ return ret;
+}
+
+Node* Identity(Graph* g, Node* input, int index) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
+ .Input(input, index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
+
+Node* Error(Graph* g, Node* input, const string& errmsg) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
+ .Input(input)
+ .Attr("message", errmsg)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
+  DCHECK_NE(out_type, invalid_type);
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
+ .Attr("TIn", out_type)
+ .Attr("TOut", invalid_type)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
+ .Input(input)
+ .Attr("micros", delay_micros.value())
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
+ .ControlInputs(control_inputs)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Switch(Graph* g, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
+ .Input(in0)
+ .Input(in1)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Enter(Graph* g, Node* input, const string& frame_name) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
+ .Input(input)
+ .Attr("frame_name", frame_name)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Exit(Graph* g, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Merge(Graph* g, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
+ .Input({in0, in1})
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
+ std::vector<NodeBuilder::NodeOut> inputs;
+ inputs.reserve(remaining_in.size() + 1);
+ inputs.emplace_back(in0);
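+  // The remaining, name-only inputs are assumed to carry the same dtype as
+  // "in0" (inputs[0].dt below).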
+ for (const string& in_name : remaining_in) {
+ inputs.emplace_back(in_name, 0, inputs[0].dt);
+ }
+
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Next(Graph* g, const string& name, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* LoopCond(Graph* g, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Less(Graph* g, Node* in0, Node* in1) {
+ return Binary(g, "Less", in0, in1);
+}
+
+Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
+ .Input(c)
+ .Input(inx)
+ .Input(iny)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Cast(Graph* g, Node* in, DataType dst) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
+ .Input(in)
+ .Attr("DstT", dst)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
+
+} // end namespace graph
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
new file mode 100644
index 0000000000..11905bbf6a
--- /dev/null
+++ b/tensorflow/core/graph/testlib.h
@@ -0,0 +1,141 @@
+// DEPRECATED: Use GraphDefBuilder instead.
+
+#ifndef TENSORFLOW_GRAPH_TESTLIB_H_
+#define TENSORFLOW_GRAPH_TESTLIB_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+namespace test {
+namespace graph {
+
+// Converts "g" into its corresponding GraphDef "def".
+// DEPRECATED: call g->ToGraphDef(def) instead.
+void ToGraphDef(Graph* g, GraphDef* def);
+
+// A few helpers to construct a graph.
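+//
+// A minimal sketch of how they compose (illustrative; assumes only the
+// helpers declared below):
+//
+//   Graph g(OpRegistry::Global());
+//   Tensor t(DT_FLOAT, TensorShape({}));
+//   t.scalar<float>()() = 2.0f;
+//   Node* c = Constant(&g, t);
+//   Node* sum = Add(&g, c, c);   // sum evaluates to 4.0f when run.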
+
+// Adds a node in "g" producing a constant "tensor".
+Node* Constant(Graph* g, const Tensor& tensor);
+Node* Constant(Graph* g, const Tensor& tensor, const string& name);
+
+// Adds a variable in "g" of the given "shape" and "dtype".
+Node* Var(Graph* g, const DataType dtype, const TensorShape& shape);
+
+// Adds an assign node in "g" which assigns "val" into "var".
+Node* Assign(Graph* g, Node* var, Node* val);
+
+// Adds a send node in "g" sending "input" as a named "tensor" from
+// "sender" to "receiver".
+Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
+ const uint64 sender_incarnation, const string& receiver);
+
+// Adds a recv node in "g" receiving a named "tensor" sent from "sender"
+// to "receiver".
+Node* Recv(Graph* g, const string& tensor, const string& type,
+ const string& sender, const uint64 sender_incarnation,
+ const string& receiver);
+
+// Adds a reduction "node" in "g" doing "reduce"(data, axes), where "reduce"
+// names a reduction op, e.g., Sum, Max, Min, Mean.
+Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
+ bool keep_dims = false);
+
+// Adds a Matmul node in "g" doing in0.contract(in1).
+Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
+ bool transpose_b);
+
+// Adds a Quantize node into "g" that quantizes floats into QUINT8. The range
+// of the input float tensor is assumed to be [-1, 1].
+Node* QuantizeToUINT8(Graph* g, Node* data);
+
+// Adds a unary function "func" "node" in "g" taking "input".
+Node* Unary(Graph* g, const string& func, Node* input, int index = 0);
+
+// Adds an identity node in "g" taking "input" and producing an
+// identity copy.
+Node* Identity(Graph* g, Node* input, int index = 0);
+
+// Adds a binary function "func" node in "g" taking "in0" and "in1".
+// Requires that "func" name an attr-style Op.
+Node* Binary(Graph* g, const string& func, Node* in0, Node* in1);
+
+// Adds a function "func" node in "g" taking inputs "ins".
+// Requires that "func" name an attr-style Op.
+Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
+
+// Adds a binary add node in "g" doing in0 + in1.
+Node* Add(Graph* g, Node* in0, Node* in1);
+
+// Generates random samples from a unit uniform distribution, with the
+// shape given by "input".
+Node* RandomUniform(Graph* g, Node* input, DataType dtype);
+
+// Generates random samples from a unit normal distribution, with the
+// shape given by "input".
+Node* RandomGaussian(Graph* g, Node* input, DataType dtype);
+
+// Generates random parameters from a truncated standard normal distribution,
+// with the shape given by "input".
+Node* RandomParameters(Graph* g, Node* input, DataType dtype);
+
+// Adds an error node in "g". The node's computation always
+// generates an error with the given error message "errmsg".
+Node* Error(Graph* g, Node* input, const string& errmsg);
+
+// Adds a node that generates an invalid ref output.
+Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type);
+
+// Adds a node in "g". Its Compute() sleeps a while and outputs the
+// input (i.e., same as identity).
+Node* Delay(Graph* g, Node* input, Microseconds delay_micros);
+
+// Adds a no-op "node" in "g", with control inputs from all nodes in the
+// "control_inputs" vector.
+Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs);
+
+// Adds a Switch node in "g". If "in1" is true, it forwards "in0" to
+// output 1. Otherwise, it forwards "in0" to output 0.
+Node* Switch(Graph* g, Node* in0, Node* in1);
+
+// Adds an Enter node in "g", which enters a new frame.
+Node* Enter(Graph* g, Node* input, const string& frame_name);
+
+// Adds an Exit node in "g", which exits a frame.
+Node* Exit(Graph* g, Node* input);
+
+// Adds a Merge node in "g" with two inputs "in0" and "in1".
+Node* Merge(Graph* g, Node* in0, Node* in1);
+
+// Adds a Merge node in "g". The first input is "in0", the remaining
+// inputs are only given by their names in remaining_in.
+Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in);
+
+// Adds a NextIteration node in "g", which makes its input available
+// to the next iteration.
+Node* Next(Graph* g, const string& name, Node* input);
+
+// Adds a LoopCond node in "g", representing the "pivot" termination
+// condition of a loop.
+Node* LoopCond(Graph* g, Node* input);
+
+// Adds a Less node in "g", which returns true iff "in0" < "in1".
+Node* Less(Graph* g, Node* in0, Node* in1);
+
+// Adds a Select node in "g", which outputs either "inx" or "iny"
+// depending on the boolean value of "c".
+Node* Select(Graph* g, Node* c, Node* inx, Node* iny);
+
+// Casts "in" into data type "dst".
+Node* Cast(Graph* g, Node* in, DataType dst);
+
+} // end namespace graph
+} // end namespace test
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_TESTLIB_H_
diff --git a/tensorflow/core/graph/types.h b/tensorflow/core/graph/types.h
new file mode 100644
index 0000000000..41400611a9
--- /dev/null
+++ b/tensorflow/core/graph/types.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_GRAPH_TYPES_H_
+#define TENSORFLOW_GRAPH_TYPES_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+
+namespace tensorflow {
+
+// We model running time in microseconds.
+TF_LIB_GTL_DEFINE_INT_TYPE(Microseconds, int64);
+
+// We model size in bytes.
+TF_LIB_GTL_DEFINE_INT_TYPE(Bytes, int64);
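+
+// Illustrative note: these are strong int types, so arithmetic across the
+// two types will not compile, and the raw value needs an explicit accessor:
+//   Microseconds t(10);
+//   int64 raw = t.value();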
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_TYPES_H_