diff options
Diffstat (limited to 'tensorflow/core/graph')
44 files changed, 7485 insertions, 0 deletions
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc new file mode 100644 index 0000000000..fd79ead0b1 --- /dev/null +++ b/tensorflow/core/graph/algorithm.cc @@ -0,0 +1,107 @@ +#include "tensorflow/core/graph/algorithm.h" + +#include <algorithm> +#include <deque> +#include <vector> + +namespace tensorflow { + +void DFS(const Graph& g, std::function<void(Node*)> enter, + std::function<void(Node*)> leave) { + // Stack of work to do. + struct Work { + Node* node; + bool leave; // Are we entering or leaving n? + }; + std::vector<Work> stack; + stack.push_back(Work{g.source_node(), false}); + + std::vector<bool> visited(g.num_node_ids(), false); + while (!stack.empty()) { + Work w = stack.back(); + stack.pop_back(); + + Node* n = w.node; + if (w.leave) { + leave(n); + continue; + } + + if (visited[n->id()]) continue; + visited[n->id()] = true; + if (enter) enter(n); + + // Arrange to call leave(n) when all done with descendants. + if (leave) stack.push_back(Work{n, true}); + + // Arrange to work on descendants. + for (Node* out : n->out_nodes()) { + if (!visited[out->id()]) { + // Note; we must not mark as visited until we actually process it. + stack.push_back(Work{out, false}); + } + } + } +} + +void GetPostOrder(const Graph& g, std::vector<Node*>* order) { + order->clear(); + DFS(g, nullptr, [order](Node* n) { order->push_back(n); }); +} + +void GetReversePostOrder(const Graph& g, std::vector<Node*>* order) { + GetPostOrder(g, order); + std::reverse(order->begin(), order->end()); +} + +void PruneForReverseReachability(Graph* g, + const std::unordered_set<const Node*>& nodes) { + std::unordered_set<const Node*> visited; + + // Compute set of nodes that we need to traverse in order to reach + // the nodes in "nodes" by performing a breadth-first search from those + // nodes, and accumulating the visited nodes. 
+ std::deque<const Node*> queue; + for (const Node* n : nodes) { + queue.push_back(n); + } + while (!queue.empty()) { + const Node* n = queue.front(); + queue.pop_front(); + if (visited.insert(n).second) { + for (const Node* in : n->in_nodes()) { + queue.push_back(in); + } + } + } + + // Make a pass over the graph to remove nodes not in "visited" + std::vector<Node*> all_nodes; + for (Node* n : g->nodes()) { + all_nodes.push_back(n); + } + + for (Node* n : all_nodes) { + if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) { + g->RemoveNode(n); + } + } + + // Reconnect nodes with no outgoing edges to the sink node + FixupSourceAndSinkEdges(g); +} + +void FixupSourceAndSinkEdges(Graph* g) { + // Connect all nodes with no incoming edges to source. + // Connect all nodes with no outgoing edges to sink. + for (Node* n : g->nodes()) { + if (!n->IsSource() && n->in_edges().empty()) { + g->AddControlEdge(g->source_node(), n); + } + if (!n->IsSink() && n->out_edges().empty()) { + g->AddControlEdge(n, g->sink_node()); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h new file mode 100644 index 0000000000..58b74a0ace --- /dev/null +++ b/tensorflow/core/graph/algorithm.h @@ -0,0 +1,40 @@ +#ifndef TENSORFLOW_GRAPH_ALGORITHM_H_ +#define TENSORFLOW_GRAPH_ALGORITHM_H_ + +#include <functional> +#include <unordered_set> + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Perform a depth-first-search on g starting at the source node. +// If enter is not empty, calls enter(n) before visiting any children of n. +// If leave is not empty, calls leave(n) after visiting all children of n. +extern void DFS(const Graph& g, std::function<void(Node*)> enter, + std::function<void(Node*)> leave); + +// Stores in *order the post-order numbering of all nodes +// in graph found via a depth first search starting at the source node. 
+// +// Note that this is equivalent to topological sorting when the +// graph does not have cycles. +// +// REQUIRES: order is not NULL. +void GetPostOrder(const Graph& g, std::vector<Node*>* order); + +// Stores in *order the reverse post-order numbering of all nodes +void GetReversePostOrder(const Graph& g, std::vector<Node*>* order); + +// Prune nodes in "g" that are not in some path from the source node +// to any node in 'nodes'. +void PruneForReverseReachability(Graph* g, + const std::unordered_set<const Node*>& nodes); + +// Connect all nodes with no incoming edges to source. +// Connect all nodes with no outgoing edges to sink. +void FixupSourceAndSinkEdges(Graph* g); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_ALGORITHM_H_ diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc new file mode 100644 index 0000000000..48f2e1ebd7 --- /dev/null +++ b/tensorflow/core/graph/algorithm_test.cc @@ -0,0 +1,103 @@ +#include "tensorflow/core/graph/algorithm.h" + +#include <string> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/kernels/ops_util.h" +#include <gtest/gtest.h> +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +// TODO(josh11b): Test setting the "device" field of a NodeDef. +// TODO(josh11b): Test that feeding won't prune targets. + +namespace tensorflow { +namespace { + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); + +// Compares that the order of nodes in 'inputs' respects the +// pair orders described in 'ordered_pairs'. 
+bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs, + const std::vector<Node*>& inputs, string* error) { + for (const std::pair<string, string>& pair : ordered_pairs) { + const string& before_node = pair.first; + const string& after_node = pair.second; + bool seen_before = false; + bool seen_both = false; + for (const Node* node : inputs) { + if (!seen_before && after_node == node->name()) { + *error = strings::StrCat("Saw ", after_node, " before ", before_node); + return false; + } + + if (before_node == node->name()) { + seen_before = true; + } else if (after_node == node->name()) { + seen_both = seen_before; + break; + } + } + if (!seen_both) { + *error = strings::StrCat("didn't see either ", before_node, " or ", + after_node); + return false; + } + } + + return true; +} + +TEST(AlgorithmTest, ReversePostOrder) { + RequireDefaultOps(); + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* w1 = SourceOp("TestParams", b.opts().WithName("W1")); + Node* w2 = SourceOp("TestParams", b.opts().WithName("W2")); + Node* input = + SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1)); + Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1")); + BinaryOp("TestMul", w1, {input, 1}, + b.opts().WithName("t2").WithControlInput(t1)); + BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3")); + + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + std::vector<Node*> order; + + // Test reverse post order: + GetReversePostOrder(g, &order); + + // Check that the order respects the dependencies correctly. + std::vector<std::pair<string, string>> reverse_orders = { + {"W1", "input"}, {"W1", "t1"}, {"W1", "t2"}, {"W1", "t3"}, + {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}}; + string error; + EXPECT_TRUE(ExpectBefore(reverse_orders, order, &error)) << error; + + // A false ordering should fail the check. 
+ reverse_orders = {{"input", "W1"}}; + EXPECT_FALSE(ExpectBefore(reverse_orders, order, &error)); + + // Test post order: + GetPostOrder(g, &order); + + // Check that the order respects the dependencies correctly. + std::vector<std::pair<string, string>> orders = { + {"input", "W1"}, {"t1", "W1"}, {"t2", "W1"}, {"t3", "W1"}, + {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}}; + EXPECT_TRUE(ExpectBefore(orders, order, &error)) << error; + + // A false ordering should fail the check. + orders = {{"W1", "t3"}}; + EXPECT_FALSE(ExpectBefore(orders, order, &error)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/colors.cc b/tensorflow/core/graph/colors.cc new file mode 100644 index 0000000000..0eb2fc3740 --- /dev/null +++ b/tensorflow/core/graph/colors.cc @@ -0,0 +1,25 @@ +#include "tensorflow/core/graph/colors.h" + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Color palette +// http://www.mulinblog.com/a-color-palette-optimized-for-data-visualization/ +static const char* kColors[] = { + "#F15854", // red + "#5DA5DA", // blue + "#FAA43A", // orange + "#60BD68", // green + "#F17CB0", // pink + "#B2912F", // brown + "#B276B2", // purple + "#DECF3F", // yellow + "#4D4D4D", // gray +}; + +const char* ColorFor(int dindex) { + return kColors[dindex % TF_ARRAYSIZE(kColors)]; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/colors.h b/tensorflow/core/graph/colors.h new file mode 100644 index 0000000000..150c8dc025 --- /dev/null +++ b/tensorflow/core/graph/colors.h @@ -0,0 +1,14 @@ +#ifndef TENSORFLOW_GRAPH_COLORS_H_ +#define TENSORFLOW_GRAPH_COLORS_H_ + +namespace tensorflow { + +// Return a color drawn from a palette to represent an entity +// identified by "i". The return value has the form "#RRGGBB" Note +// that the palette has a limited set of colors and therefore colors +// will be reused eventually. 
+const char* ColorFor(int dindex); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_COLORS_H_ diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc new file mode 100644 index 0000000000..89bc41acfd --- /dev/null +++ b/tensorflow/core/graph/costmodel.cc @@ -0,0 +1,308 @@ +#include "tensorflow/core/graph/costmodel.h" + +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace { +const Microseconds kDefaultTimeEstimate(1); +const Microseconds kMinTimeEstimate(1); +} // namespace + +void CostModel::SuppressInfrequent() { + // Find the median of the non-zero counts, and use half of its value + // as the cutoff for a "normal" execution mode node. + if (count_.empty()) return; + std::vector<int32> non_zero; + for (auto v : count_) { + if (v > 0) non_zero.push_back(v); + } + const size_t sz = non_zero.size(); + if (sz > 0) { + std::nth_element(non_zero.begin(), non_zero.begin() + sz / 2, + non_zero.end()); + int32 median_value = non_zero[sz / 2]; + min_count_ = median_value / 2; + VLOG(1) << "num non_zero vals: " << non_zero.size() << " median_value " + << median_value; + } else { + min_count_ = 1; + } +} + +void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) { + CHECK(is_global_); + CHECK(!cm.is_global()); + for (const Node* n : g.nodes()) { + const int local_id = cm.Id(n); + const int global_id = Id(n); + if (local_id < 0 || global_id < 0) continue; + Ensure(global_id); + count_[global_id] += cm.count_[local_id]; + time_[global_id] += cm.time_[local_id]; + int num_slots = cm.slot_bytes_[local_id].size(); + if (num_slots > 0) { + if (slot_bytes_[global_id].size() == 0) { + slot_bytes_[global_id].resize(num_slots); + } else { + CHECK_EQ(num_slots, slot_bytes_[global_id].size()); + } + for (int s = 0; s < num_slots; ++s) { + slot_bytes_[global_id][s] += cm.slot_bytes_[local_id][s]; + } + } + } +} + 
+void CostModel::MergeFromGlobal(const CostModel& cm) { + CHECK(is_global_); + CHECK_EQ(true, cm.is_global()); + const int num_nodes = cm.count_.size(); + Ensure(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + count_[i] += cm.count_[i]; + time_[i] += cm.time_[i]; + int num_slots = cm.slot_bytes_[i].size(); + if (num_slots > 0) { + if (slot_bytes_[i].size() == 0) { + slot_bytes_[i].resize(num_slots); + } else { + CHECK_EQ(num_slots, slot_bytes_[i].size()); + } + for (int s = 0; s < num_slots; ++s) { + slot_bytes_[i][s] += cm.slot_bytes_[i][s]; + } + } + } +} + +void CostModel::MergeFromStats(const NodeNameToCostIdMap& map, + const StepStats& ss) { + CHECK(is_global_); + for (auto& ds : ss.dev_stats()) { + for (auto& ns : ds.node_stats()) { + NodeNameToCostIdMap::const_iterator iter = map.find(ns.node_name()); + // We don't keep stats for nodes not in the global graph, i.e. + // copy/send/recv nodes, feed/fetch, etc. + if (iter == map.end()) continue; + int32 global_id = iter->second; + Ensure(global_id); + int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros(); + count_[global_id]++; + time_[global_id] += elapsed_micros; + for (auto& no : ns.output()) { + int si = no.slot(); + if (static_cast<size_t>(si) >= slot_bytes_[global_id].size()) { + slot_bytes_[global_id].resize(1 + si); + } + slot_bytes_[global_id][si] += + no.tensor_description().allocation_description().requested_bytes(); + } + } + } +} + +void CostModel::Ensure(int id) { + if (slot_bytes_.size() <= static_cast<size_t>(id)) { + slot_bytes_.resize(id + 1); + count_.resize(id + 1); + time_.resize(id + 1); + } +} + +void CostModel::SetNumOutputs(const Node* node, int num_outputs) { + const int id = Id(node); + if (id < 0) return; + Ensure(id); + auto perslot = &slot_bytes_[id]; + if (perslot->size() > 0) { + CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node=" + << node->name(); + } else { + perslot->resize(num_outputs, Bytes(-1)); + } +} + +void 
CostModel::RecordCount(const Node* node, int count) { + const int id = Id(node); + if (id < 0) return; + CHECK_LT(id, slot_bytes_.size()); + count_[id] += count; +} + +int32 CostModel::TotalCount(const Node* node) const { + const int id = Id(node); + if (id < 0) return 0; + return (static_cast<size_t>(id) < slot_bytes_.size()) ? count_[id] : 0; +} + +void CostModel::RecordSize(const Node* node, int slot, Bytes bytes) { + const int id = Id(node); + if (id < 0) return; + CHECK_LT(id, slot_bytes_.size()); + auto perslot = &slot_bytes_[id]; + CHECK_LT(slot, perslot->size()); + auto v = &(*perslot)[slot]; + if (*v >= 0) { + *v += bytes; + } else { + *v = bytes; + } +} + +Bytes CostModel::TotalBytes(const Node* node, int slot) const { + const int id = Id(node); + if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() || + slot_bytes_[id].size() <= static_cast<size_t>(slot)) { + return Bytes(0); + } + return slot_bytes_[id][slot]; +} + +Bytes CostModel::SizeEstimate(const Node* node, int slot) const { + int32 count = TotalCount(node); + if (count < min_count_) return Bytes(0); + return TotalBytes(node, slot) / std::max(1, TotalCount(node)); +} + +void CostModel::RecordTime(const Node* node, Microseconds time) { + const int id = Id(node); + if (id < 0) return; + DCHECK(node->IsOp()) << node->DebugString(); + Ensure(id); + time_[id] += time; +} + +Microseconds CostModel::TotalTime(const Node* node) const { + DCHECK(node->IsOp()) << node->DebugString(); + const int id = Id(node); + if (id < 0 || static_cast<size_t>(id) >= time_.size() || + time_[id] < Microseconds(0)) { + return Microseconds(0); + } + return time_[id]; +} + +Microseconds CostModel::TimeEstimate(const Node* node) const { + int32 count = TotalCount(node); + if (count <= min_count_) return kMinTimeEstimate; + return std::max(kMinTimeEstimate, TotalTime(node) / std::max(1, count)); +} + +void CostModel::CheckInitialized(const Graph& graph) const { + for (const Node* n : graph.nodes()) { + if (n->IsOp()) { + 
CHECK(static_cast<size_t>(n->id()) < time_.size() && + time_[n->id()] >= Microseconds(0)) + << ": no time estimate for " << n->DebugString(); + + CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size()) + << ": no size estimate for " << n->DebugString(); + const auto& perslot = slot_bytes_[n->id()]; + for (size_t i = 0; i < perslot.size(); i++) { + CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i + << " of " << n->DebugString(); + } + } + } +} + +Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis, + double estimated_gbps) { + // TODO(jeff,sanjay): estimate cost based on bandwidth along the + // communication path and the type of transport we are using between + // devices. + // + // We assume the copy time follows a linear model: + // copy_time = copy_bytes / rate + min_time + int64 copy_bytes = b.value(); + const double bytes_per_usec = estimated_gbps * 1000.0 / 8; + const double min_micros = network_latency_millis * 1000.0; + return Microseconds( + static_cast<int64>(copy_bytes / bytes_per_usec + min_micros)); +} + +Microseconds CostModel::ComputationTimeEstimate(int64 math_ops) { + // TODO(jeff,sanjay): Eventually we should pass in the type of device + // (GPU vs. CPU) and use that to affect the estimate. + + // We estimate the microseconds using that value. We divide + // by 1000 to convert the madd number into microseconds (assuming + // roughly 1000 madds per microsecond (~1 GHz for one core)). 
+ return Microseconds(math_ops / 1000); +} + +// ---------------------------------------------------------------------------- +// InitCostModel +// ---------------------------------------------------------------------------- + +namespace { + +static void AddNodesToCostModel(const Graph& g, CostModel* cost_model) { + for (Node* n : g.nodes()) { + const int num_outputs = n->num_outputs(); + cost_model->SetNumOutputs(n, num_outputs); + for (int output = 0; output < num_outputs; output++) { + // Set up an initial bogus estimate for the node's outputs + cost_model->RecordSize(n, output, Bytes(1)); + } + } +} + +static void AssignSizes(const Graph& g, CostModel* cost_model) { + for (const Edge* e : g.edges()) { + // Skip if it is a control edge. + if (e->IsControlEdge()) { + continue; + } + Node* src = e->src(); + + // TODO(josh11b): Get an estimate from the Op + Bytes size(1); + cost_model->RecordSize(src, e->src_output(), size); + } +} + +// This generates an extremely simple initial guess for the +// computation cost of each node. For ordinary Ops, its value should quickly +// be wiped out by the real runtime measurements. For other Ops we don't +// actually generate measurements, so suppression of infrequent Ops ends up +// giving them 0 costs. So, this is not of much consequence except perhaps +// in tests. 
+static Microseconds TimeEstimateForNode(CostModel* cost_model, Node* n) { + CHECK(n->IsOp()); + VLOG(2) << "Node " << n->id() << ": " << n->name() + << " type_string: " << n->type_string(); + if (IsConstant(n) || IsVariable(n)) { + return Microseconds(0); + } + return kDefaultTimeEstimate; +} + +static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) { + for (Node* n : g.nodes()) { + if (!n->IsOp()) continue; + cost_model->RecordTime(n, TimeEstimateForNode(cost_model, n)); + } +} + +} // namespace + +void CostModel::InitFromGraph(const Graph& g) { + AddNodesToCostModel(g, this); + AssignSizes(g, this); + EstimateComputationCosts(g, this); + CheckInitialized(g); +} + +void CostModel::WriteToLog() { + LOG(INFO) << " min_count_=" << min_count_; + for (size_t i = 0; i < count_.size(); ++i) { + LOG(INFO) << "Node " << i << " count " << count_[i] << " total time " + << time_[i] << " avg time " + << (time_[i] / (std::max(1, count_[i]))); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h new file mode 100644 index 0000000000..4d7dd65f5a --- /dev/null +++ b/tensorflow/core/graph/costmodel.h @@ -0,0 +1,123 @@ +#ifndef TENSORFLOW_GRAPH_COSTMODEL_H_ +#define TENSORFLOW_GRAPH_COSTMODEL_H_ + +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { +typedef std::unordered_map<string, int32> NodeNameToCostIdMap; + +class StepStats; + +// CostModel keeps track of the following runtime statistics for nodes +// of a single Graph: +// * The total number of times a node has executed. +// * The accumulated execution time (in microseconds) of a node. +// * The accumulated size (in bytes) of each node's output. +// +// This class is NOT thread-safe. 
+class CostModel { + public: + // If "global" is true, maintains costs based on Node::cost_id, otherwise + // maintains costs based on Node::id. + explicit CostModel(bool is_global) : is_global_(is_global) {} + + // Assigns min_count_ as a function of the median count for a Node. + // This value is then used for suppressing the time/size costs of + // infrequent operations. + // NOTE(tucker): Maybe this should move to a subclass of CostModel. + void SuppressInfrequent(); + + bool is_global() const { return is_global_; } + + // Initializes cost model for 'g'. + void InitFromGraph(const Graph& g); + + // Merges costs from cm. + // REQUIRES: is_global_ is true for this and for "cm" + void MergeFromGlobal(const CostModel& cm); + + // Merges costs from "cm", which has been computed relative to "g". + // REQUIRES: is_global_ is true for this, and false for "cm". + void MergeFromLocal(const Graph& g, const CostModel& cm); + + void MergeFromStats(const NodeNameToCostIdMap& map, const StepStats& ss); + + // Sets the number of outputs of "node". + void SetNumOutputs(const Node* node, int num_outputs); + + // Records that "node" has executed "num_count" more times. + void RecordCount(const Node* node, int num_count); + + // Returns how many times "node" has been executed. + int32 TotalCount(const Node* node) const; + + // Records that "output_slot" of "node" has produced tensors of + // aggregated "bytes". + void RecordSize(const Node* node, int output_slot, Bytes bytes); + + // Returns total bytes of tensors produced by "node"s output slot. + Bytes TotalBytes(const Node* node, int output_slot) const; + + // Returns a prediction for the size of the tensor at the + // output_slot produced by one execution of "node". + Bytes SizeEstimate(const Node* node, int output_slot) const; + + // Records that Executions of "node" have taken "time" microseconds. + void RecordTime(const Node* node, Microseconds time); + + // Returns the total execution time for "node". 
+ Microseconds TotalTime(const Node* node) const; + + // Returns a prediction for one execution of "node". + Microseconds TimeEstimate(const Node* node) const; + + // Check that an estimate is available for every OP node in graph. + void CheckInitialized(const Graph& graph) const; + + // Helper routines to encapsulate static estimatation heuristics + + // Compute an estimate of the time to copy "b" bytes over the network, + // given a fixed cost of "network_latency_millis" milliseconds and + // an estimated bandwidth of "estimated_gbps" gigabits per second (note that + // this value is in gigabits, not gigabytes). + static Microseconds CopyTimeEstimate(Bytes b, double network_latency_millis, + double estimated_gbps); + static Microseconds ComputationTimeEstimate(int64 mathops); + + // Write the contents of the CostModel to the INFO log. + void WriteToLog(); + + private: + const bool is_global_; + inline int Id(const Node* n) const { + if (is_global_) { + return n->cost_id(); + } else { + return n->id(); + } + } + // Resizes vectors so that they are large enough for "id". + void Ensure(int id); + + // Nodes and Edges whose count is < this value + // get type/byte estimates of 0. + int32 min_count_ = 0; + + // Number of times each Node has been executed. + std::vector<int32> count_; + // Cumulative execution time. + std::vector<Microseconds> time_; + // Cumulative Bytes output on each channel. 
+ std::vector<gtl::InlinedVector<Bytes, 2> > slot_bytes_; + + TF_DISALLOW_COPY_AND_ASSIGN(CostModel); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_COSTMODEL_H_ diff --git a/tensorflow/core/graph/costutil.cc b/tensorflow/core/graph/costutil.cc new file mode 100644 index 0000000000..f8e2d9fe68 --- /dev/null +++ b/tensorflow/core/graph/costutil.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/graph/costutil.h" + +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/costmodel.h" + +namespace tensorflow { + +std::vector<int64> LongestOutgoingPathCost(const Graph& graph, + const CostModel& cm) { + std::vector<int64> result(graph.num_node_ids()); + DFS(graph, nullptr, [&result, &cm](Node* n) { + int64 max_child = 0; + for (const Node* out : n->out_nodes()) { + max_child = std::max(max_child, result[out->id()]); + } + result[n->id()] = max_child + (n->IsOp() ? cm.TimeEstimate(n).value() : 0); + }); + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/costutil.h b/tensorflow/core/graph/costutil.h new file mode 100644 index 0000000000..46e5215132 --- /dev/null +++ b/tensorflow/core/graph/costutil.h @@ -0,0 +1,19 @@ +#ifndef TENSORFLOW_GRAPH_COSTUTIL_H_ +#define TENSORFLOW_GRAPH_COSTUTIL_H_ + +#include <vector> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class CostModel; +class Graph; + +// result[i] is an estimate of the longest execution path from +// the node with id i to the sink node. 
+std::vector<int64> LongestOutgoingPathCost(const Graph& graph, + const CostModel& cm); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_COSTUTIL_H_ diff --git a/tensorflow/core/graph/default_device.h b/tensorflow/core/graph/default_device.h new file mode 100644 index 0000000000..30cd4e8a57 --- /dev/null +++ b/tensorflow/core/graph/default_device.h @@ -0,0 +1,25 @@ +#ifndef TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_ +#define TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_ + +#include <string> + +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { +namespace graph { + +// Sets the default device for all nodes in graph_def to "device", +// only if not already set. +inline void SetDefaultDevice(const string& device, GraphDef* graph_def) { + for (int i = 0; i < graph_def->node_size(); ++i) { + auto node = graph_def->mutable_node(i); + if (node->device().empty()) { + node->set_device(device); + } + } +} + +} // namespace graph +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_ diff --git a/tensorflow/core/graph/dot.cc b/tensorflow/core/graph/dot.cc new file mode 100644 index 0000000000..6d6e46ce61 --- /dev/null +++ b/tensorflow/core/graph/dot.cc @@ -0,0 +1,289 @@ +#include "tensorflow/core/graph/dot.h" + +#include <map> +#include <unordered_map> +#include <unordered_set> + +#include "tensorflow/core/graph/colors.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +static string GraphNodeName(const DotOptions& opts, const Node* n) { + return strings::StrCat("N", n->id()); +} + +bool ShoulDisplayOpType(const Node* n) { + if (n->type_string() == "NoOp") { + return false; + } + const string& op_name = n->def().name(); + if (op_name.find(n->type_string() + "_") == 0) { + return false; + } + return true; +} + +string DotGraph(const Graph& g, 
const DotOptions& opts) { + RegexpStringPiece flag(opts.prefix_collapse_regexp); + if (flag == "all") { + flag = "."; + } else if (flag == "none") { + flag = "^$"; + } + RE2 cluster_name_pattern(flag); + string result; + strings::StrAppend(&result, "digraph G {\n"); + strings::StrAppend(&result, "rankdir=\"BT\"\n"); + + std::map<string, int> device_index; // Map from device name to index. + std::unordered_set<Node*> visible_nodes; // Nodes to display. + // Cluster name => set of nodes. + std::unordered_map<string, std::unordered_set<Node*> > clusters; + // Node* => Cluster + std::unordered_map<Node*, string> node_cluster; + for (Node* src : g.nodes()) { + if (opts.include_node_function != nullptr && + !opts.include_node_function(src)) { + continue; + } + // Do not display source and sink nodes + if (src->IsSource() || src->IsSink()) { + continue; + } + visible_nodes.insert(src); + const string name_prefix = NodeNamePrefix(src->def().name()).ToString(); + if (!name_prefix.empty()) { + clusters[name_prefix].insert(src); + node_cluster[src] = name_prefix; + } + // Record device if present. + if (src->IsOp()) { + const string& d = src->assigned_device_name(); + if (!d.empty()) { + device_index[d] = -1; // Assigned later + } + } + } + + // Add nodes whose name is exactly a cluster name to the cluster itself. + for (Node* src : g.nodes()) { + if (node_cluster.count(src) == 0) { + const string name = src->def().name(); + auto it = clusters.find(name); + if (it != clusters.end()) { + it->second.insert(src); + node_cluster[src] = name; + } + } + } + + auto node_in_collapsed_cluster = [&node_cluster, + &cluster_name_pattern](Node* n) { + return node_cluster.count(n) > 0 && + RE2::PartialMatch(node_cluster[n], cluster_name_pattern); + }; + + // Assign device indices in sorted order. 
+ int num = 0; + for (auto& e : device_index) { + e.second = num++; + } + + double total_node_cost = 0; + double avg_node_cost = 1; + if (opts.node_cost) { + int node_count = 0; + for (const Node* n : g.nodes()) { + total_node_cost += opts.node_cost(n); + ++node_count; + } + if (total_node_cost > 0) avg_node_cost = total_node_cost / node_count; + } + + for (Node* src : g.nodes()) { + if (visible_nodes.count(src) == 0 || node_in_collapsed_cluster(src)) { + continue; + } + string label = src->name(); + if (ShoulDisplayOpType(src)) { + // Append the op type if it is not directly deducible from the op name. + strings::StrAppend(&label, "\\n(", src->type_string(), ")"); + } + const char* shape = "box"; + const char* color = nullptr; + if (src->IsSource()) { + shape = "oval"; + } else if (src->IsSink()) { + shape = "oval"; + } else { + const string& d = src->assigned_device_name(); + const int dindex = (!d.empty()) ? device_index[d] : -1; + if (dindex >= 0) { + color = ColorFor(dindex); + } + + shape = "box"; + } + + if (opts.node_label) { + string extra = opts.node_label(src); + if (!extra.empty()) { + strings::StrAppend(&label, "\\n", extra); + } + } + + strings::StrAppend(&result, GraphNodeName(opts, src), "[shape=", shape, + ", label=\"", label, "\""); + if (opts.node_cost && total_node_cost > 0) { + // Pick fontsize in range [8..40] so that area is proportional to cost. + const double cost = opts.node_cost(src); + const double relcost = fabs(cost / avg_node_cost); + // Average cost node has font size of 12. 
+ const int fs = 8 + static_cast<int>(4.0 * std::min(sqrt(relcost), 8.0)); + strings::StrAppend(&result, ", width=0, height=0, fontsize=", fs); + VLOG(2) << "Node: " << cost << " => " << relcost << " => " << fs; + } + if (color != nullptr) { + strings::StrAppend(&result, ", fillcolor=\"", color, + "\", fontcolor=\"white\", style=\"filled\""); + } + strings::StrAppend(&result, "]\n"); + } + + for (auto c : clusters) { + const string& cluster_name = c.first; + const std::unordered_set<Node*> nodes = c.second; + std::unordered_map<string, int> node_colors; + for (auto n : nodes) { + const string& d = n->assigned_device_name(); + const int dindex = (!d.empty()) ? device_index[d] : -1; + if (dindex >= 0) { + ++node_colors[ColorFor(dindex)]; + } + } + + string majority_color; + if (node_colors.empty()) { + majority_color = ColorFor(0); + } else { + majority_color = std::max_element(node_colors.begin(), node_colors.end(), + [](const std::pair<string, int>& x, + const std::pair<string, int>& y) { + return x.second < y.second; + }) + ->first; + } + + if (!RE2::PartialMatch(cluster_name, cluster_name_pattern)) { + strings::StrAppend(&result, "subgraph cluster_", cluster_name, "{\n"); + for (auto n : nodes) { + strings::StrAppend(&result, GraphNodeName(opts, n), ";\n"); + } + strings::StrAppend(&result, "}\n"); + } else { + strings::StrAppend(&result, cluster_name, " [shape=oval, fillcolor=\"", + majority_color, "\", label=\"", cluster_name, + "\", style=\"filled\", fontcolor=\"white\"]\n"); + } + } + + std::unordered_set<string> edge_drawn; + + double max_edge_cost = 0; + double total_edge_cost = 0; + double avg_edge_cost = 1; + if (opts.edge_cost && g.edges().size()) { + for (const Edge* e : g.edges()) { + auto cost = opts.edge_cost(e); + total_edge_cost += cost; + max_edge_cost = std::max(max_edge_cost, cost); + } + avg_edge_cost = total_edge_cost / g.edges().size(); + } + VLOG(2) << "Edge cost tot/max/avg: " << total_edge_cost << "/" + << max_edge_cost << "/" << 
avg_edge_cost;
+
+  // Emit edges between visible endpoints, redirecting endpoints that live in
+  // a collapsed cluster to the cluster's stand-in node.
+  for (const Edge* e : g.edges()) {
+    Node* src = e->src();
+    Node* dst = e->dst();
+    // If either endpoint isn't drawn in the graph, don't draw the edge
+    if (visible_nodes.count(src) == 0 || visible_nodes.count(dst) == 0) {
+      continue;
+    }
+
+    const string src_name = node_in_collapsed_cluster(src)
+                                ? node_cluster[src]
+                                : GraphNodeName(opts, src);
+    const string dst_name = node_in_collapsed_cluster(dst)
+                                ? node_cluster[dst]
+                                : GraphNodeName(opts, dst);
+    // Don't draw self edges
+    if (src_name == dst_name) {
+      continue;
+    }
+    // And previously drawn edges.
+    const string& edge_name = strings::StrCat(src_name, ":", dst_name);
+    if (edge_drawn.count(edge_name) > 0) {
+      continue;
+    }
+    edge_drawn.insert(edge_name);
+
+    strings::StrAppend(&result, src_name, " -> ", dst_name, "[");
+    if (e->IsControlEdge()) {
+      strings::StrAppend(&result, " style=dotted");
+    }
+    if (opts.edge_label) {
+      // NOTE(review): removed the dead outer "string label;" that this local
+      // shadowed; the label text was only ever read inside this block.
+      string label = opts.edge_label(e);
+      if (!label.empty()) {
+        strings::StrAppend(&result, " label=<", label, ">");
+      }
+    }
+    // Make edge widths proportional to amount of data transferred.
+    if (opts.edge_cost && max_edge_cost > 0) {
+      const double cost = opts.edge_cost(e);
+      const double relcost = fabs(cost / avg_edge_cost);
+      // Pick penwidth in range [1..6] so that width is proportional to cost.
+      const int pw = 1 + std::min(5, static_cast<int>(2.0 * relcost));
+      strings::StrAppend(&result, " penwidth=", pw);
+      // Use weight attributes [1..100] to keep heavier edges more vertical.
+ const int weight = 1 + std::min(99, static_cast<int>(100.0 * relcost)); + strings::StrAppend(&result, " weight=", weight); + VLOG(2) << "Edge: " << cost << " => " << relcost << " => " << pw << "/" + << weight; + } + + strings::StrAppend(&result, "]\n"); + } + // Compute some statistics + int op_nodes = 0; + for (Node* n : g.nodes()) { + if (n->IsOp()) { + op_nodes++; + } + } + + // Emit legend + strings::StrAppend(&result, + "{ rank = source; Legend [shape=box, margin=0, label=<", + "<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\" ", + "CELLPADDING=\"4\">", "<TR><TD COLSPAN=\"2\">op_nodes: ", + op_nodes, "</TD></TR>\n"); + for (const auto& e : device_index) { + const int dindex = e.second; + strings::StrAppend(&result, "<TR><TD BGCOLOR=\"", ColorFor(dindex), + "\"><FONT COLOR=\"white\">", dindex, "</FONT></TD><TD>", + e.first, "</TD></TR>\n"); + } + strings::StrAppend(&result, "</TABLE>>]}\n"); + + strings::StrAppend(&result, "}\n"); // End digraph + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/dot.h b/tensorflow/core/graph/dot.h new file mode 100644 index 0000000000..f87f68099c --- /dev/null +++ b/tensorflow/core/graph/dot.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_GRAPH_DOT_H_ +#define TENSORFLOW_GRAPH_DOT_H_ + +#include <functional> +#include <string> +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +class Edge; +class Graph; +class Node; + +struct DotOptions { + bool (*include_node_function)(const Node*) = nullptr; + + // By default, all nodes with the same name prefix are collapsed into + // a single node in the dot graph. This regexp can be changed so that + // only prefixes that match the regexp are collapsed in this fashion. + // 'all' collapses all ops with prefixes, 'none' disables all collapsing. + string prefix_collapse_regexp = "all"; + + // A function that returns a label to embed into the per-node display. 
+ std::function<string(const Node*)> node_label; + + // A function that returns a label to attach to an edge. + std::function<string(const Edge*)> edge_label; + + // A function that returns the "cost" of the node. The dot display + // makes a node size proportional to its cost. + std::function<double(const Node*)> node_cost; + + // A function that returns the "cost" of the edge. The dot display + // makes a edge thickness proportional to its cost. + std::function<double(const Edge*)> edge_cost; +}; + +// Return a string that contains a graphviz specification of the graph. +string DotGraph(const Graph& g, const DotOptions& opts); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_DOT_H_ diff --git a/tensorflow/core/graph/edgeset.cc b/tensorflow/core/graph/edgeset.cc new file mode 100644 index 0000000000..83293c7b4e --- /dev/null +++ b/tensorflow/core/graph/edgeset.cc @@ -0,0 +1,56 @@ +#include "tensorflow/core/graph/edgeset.h" + +namespace tensorflow { + +std::pair<EdgeSet::const_iterator, bool> EdgeSet::insert(value_type value) { + RegisterMutation(); + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (!s) { + for (int i = 0; i < kInline; i++) { + if (ptrs_[i] == value) { + ci.array_iter_ = &ptrs_[i]; + return std::make_pair(ci, false); + } + } + for (int i = 0; i < kInline; i++) { + if (ptrs_[i] == nullptr) { + ptrs_[i] = value; + ci.array_iter_ = &ptrs_[i]; + return std::make_pair(ci, true); + } + } + // array is full. convert to set. + s = new std::set<const Edge*>; + for (int i = 0; i < kInline; i++) { + s->insert(static_cast<const Edge*>(ptrs_[i])); + } + ptrs_[0] = this; + ptrs_[1] = s; + // fall through. 
+  }
+  // Large representation: delegate to the underlying std::set.
+  auto p = s->insert(value);
+  ci.tree_iter_ = p.first;
+  return std::make_pair(ci, p.second);
+}
+
+// Removes "key" if present.  Returns the number of elements erased (0 or 1),
+// mirroring std::set::erase.
+EdgeSet::size_type EdgeSet::erase(key_type key) {
+  RegisterMutation();
+  auto s = get_set();
+  if (!s) {
+    // Inline representation: occupied slots form a contiguous prefix of
+    // ptrs_.  Fill the erased slot with the last occupied slot so that the
+    // prefix stays contiguous (order is irrelevant; the set is unordered).
+    for (int i = 0; i < kInline; i++) {
+      if (ptrs_[i] == key) {
+        size_t n = size();
+        ptrs_[i] = ptrs_[n - 1];
+        ptrs_[n - 1] = nullptr;
+        return 1;
+      }
+    }
+    return 0;
+  } else {
+    return s->erase(key);
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/graph/edgeset.h b/tensorflow/core/graph/edgeset.h
new file mode 100644
index 0000000000..df0d78b8fb
--- /dev/null
+++ b/tensorflow/core/graph/edgeset.h
@@ -0,0 +1,216 @@
+#ifndef TENSORFLOW_GRAPH_EDGESET_H_
+#define TENSORFLOW_GRAPH_EDGESET_H_
+
+#include <stddef.h>
+#include <set>
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+class Edge;
+
+// An unordered set of edges.  Uses very little memory for small sets.
+// Unlike std::set, EdgeSet does NOT allow mutations during iteration.
+class EdgeSet {
+ public:
+  EdgeSet();
+  ~EdgeSet();
+
+  typedef const Edge* key_type;
+  typedef const Edge* value_type;
+  typedef size_t size_type;
+  typedef ptrdiff_t difference_type;
+
+  class const_iterator;
+  typedef const_iterator iterator;
+
+  bool empty() const;
+  size_type size() const;
+  void clear();
+  std::pair<iterator, bool> insert(value_type value);
+  size_type erase(key_type key);
+
+  // Caller is not allowed to mutate the EdgeSet while iterating.
+  const_iterator begin() const;
+  const_iterator end() const;
+
+ private:
+  // Up to kInline elements are stored directly in ptrs_ (nullptr means none).
+  // If ptrs_[0] == this then ptrs_[1] points to a set<const Edge*>.
+  static const int kInline = 2;  // Must be >= 2.
+ const void* ptrs_[kInline]; + + std::set<const Edge*>* get_set() const { + if (ptrs_[0] == this) { + return static_cast<std::set<const Edge*>*>(const_cast<void*>(ptrs_[1])); + } else { + return nullptr; + } + } + +// To detect mutations while iterating. +#ifdef NDEBUG + void RegisterMutation() {} +#else + uint32 mutations_ = 0; + void RegisterMutation() { mutations_++; } +#endif + + TF_DISALLOW_COPY_AND_ASSIGN(EdgeSet); +}; + +class EdgeSet::const_iterator { + public: + typedef typename EdgeSet::value_type value_type; + typedef const typename EdgeSet::value_type& reference; + typedef const typename EdgeSet::value_type* pointer; + typedef typename EdgeSet::difference_type difference_type; + typedef std::forward_iterator_tag iterator_category; + + const_iterator() {} + + const_iterator& operator++(); + const_iterator operator++(int /*unused*/); + const value_type* operator->() const; + value_type operator*() const; + bool operator==(const const_iterator& other) const; + bool operator!=(const const_iterator& other) const { + return !(*this == other); + } + + private: + friend class EdgeSet; + + void const* const* array_iter_ = nullptr; + typename std::set<const Edge*>::const_iterator tree_iter_; + +#ifdef NDEBUG + inline void Init(const EdgeSet* e) {} + inline void CheckNoMutations() const {} +#else + inline void Init(const EdgeSet* e) { + owner_ = e; + init_mutations_ = e->mutations_; + } + inline void CheckNoMutations() const { + CHECK_EQ(init_mutations_, owner_->mutations_); + } + const EdgeSet* owner_ = nullptr; + uint32 init_mutations_ = 0; +#endif +}; + +inline EdgeSet::EdgeSet() { + for (int i = 0; i < kInline; i++) { + ptrs_[i] = nullptr; + } +} + +inline EdgeSet::~EdgeSet() { delete get_set(); } + +inline bool EdgeSet::empty() const { return size() == 0; } + +inline EdgeSet::size_type EdgeSet::size() const { + auto s = get_set(); + if (s) { + return s->size(); + } else { + size_t result = 0; + for (int i = 0; i < kInline; i++) { + if (ptrs_[i]) result++; + 
} + return result; + } +} + +inline void EdgeSet::clear() { + RegisterMutation(); + delete get_set(); + for (int i = 0; i < kInline; i++) { + ptrs_[i] = nullptr; + } +} + +inline EdgeSet::const_iterator EdgeSet::begin() const { + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (s) { + ci.tree_iter_ = s->begin(); + } else { + ci.array_iter_ = &ptrs_[0]; + } + return ci; +} + +inline EdgeSet::const_iterator EdgeSet::end() const { + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (s) { + ci.tree_iter_ = s->end(); + } else { + ci.array_iter_ = &ptrs_[size()]; + } + return ci; +} + +inline EdgeSet::const_iterator& EdgeSet::const_iterator::operator++() { + CheckNoMutations(); + if (array_iter_ != nullptr) { + ++array_iter_; + } else { + ++tree_iter_; + } + return *this; +} + +inline EdgeSet::const_iterator EdgeSet::const_iterator::operator++( + int /*unused*/) { + CheckNoMutations(); + const_iterator tmp = *this; + operator++(); + return tmp; +} + +// gcc's set and multiset always use const_iterator since it will otherwise +// allow modification of keys. +inline const EdgeSet::const_iterator::value_type* EdgeSet::const_iterator:: +operator->() const { + CheckNoMutations(); + if (array_iter_ != nullptr) { + return reinterpret_cast<const value_type*>(array_iter_); + } else { + return tree_iter_.operator->(); + } +} + +// gcc's set and multiset always use const_iterator since it will otherwise +// allow modification of keys. 
+inline EdgeSet::const_iterator::value_type EdgeSet::const_iterator::operator*() + const { + CheckNoMutations(); + if (array_iter_ != nullptr) { + return static_cast<value_type>(*array_iter_); + } else { + return *tree_iter_; + } +} + +inline bool EdgeSet::const_iterator::operator==( + const const_iterator& other) const { + DCHECK((array_iter_ == nullptr) == (other.array_iter_ == nullptr)) + << "Iterators being compared must be from same set that has not " + << "been modified since the iterator was constructed"; + CheckNoMutations(); + if (array_iter_ != nullptr) { + return array_iter_ == other.array_iter_; + } else { + return other.array_iter_ == nullptr && tree_iter_ == other.tree_iter_; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_EDGESET_H_ diff --git a/tensorflow/core/graph/edgeset_test.cc b/tensorflow/core/graph/edgeset_test.cc new file mode 100644 index 0000000000..7909e8ea0a --- /dev/null +++ b/tensorflow/core/graph/edgeset_test.cc @@ -0,0 +1,95 @@ +#include "tensorflow/core/graph/edgeset.h" + +#include "tensorflow/core/graph/graph.h" +#include <gtest/gtest.h> + +namespace tensorflow { +class EdgeSetTest : public ::testing::Test { + public: + EdgeSetTest() : edges_(nullptr), eset_(nullptr) {} + + ~EdgeSetTest() override { + delete eset_; + delete[] edges_; + } + + void MakeEdgeSet(int n) { + delete eset_; + delete[] edges_; + edges_ = new Edge[n]; + eset_ = new EdgeSet; + model_.clear(); + for (int i = 0; i < n; i++) { + eset_->insert(&edges_[i]); + model_.insert(&edges_[i]); + } + } + + void CheckSame() { + EXPECT_EQ(model_.size(), eset_->size()); + EXPECT_EQ(model_.empty(), eset_->empty()); + std::vector<const Edge*> modelv(model_.begin(), model_.end()); + std::vector<const Edge*> esetv(eset_->begin(), eset_->end()); + std::sort(modelv.begin(), modelv.end()); + std::sort(esetv.begin(), esetv.end()); + EXPECT_EQ(modelv.size(), esetv.size()); + for (size_t i = 0; i < modelv.size(); i++) { + EXPECT_EQ(modelv[i], esetv[i]) << i; + } + } + 
+ Edge nonexistent_; + Edge* edges_; + EdgeSet* eset_; + std::set<const Edge*> model_; +}; + +namespace { + +TEST_F(EdgeSetTest, Ops) { + for (int n : {0, 1, 2, 3, 4, 10}) { + MakeEdgeSet(n); + CheckSame(); + EXPECT_EQ((n == 0), eset_->empty()); + EXPECT_EQ(n, eset_->size()); + + eset_->clear(); + model_.clear(); + CheckSame(); + + eset_->insert(&edges_[0]); + model_.insert(&edges_[0]); + CheckSame(); + } +} + +// Try insert/erase of existing elements at different positions. +TEST_F(EdgeSetTest, Exists) { + for (int n : {0, 1, 2, 3, 4, 10}) { + MakeEdgeSet(n); + for (int pos = 0; pos < n; pos++) { + MakeEdgeSet(n); + auto p = eset_->insert(&edges_[pos]); + EXPECT_FALSE(p.second); + EXPECT_EQ(&edges_[pos], *p.first); + + EXPECT_EQ(1, eset_->erase(&edges_[pos])); + model_.erase(&edges_[pos]); + CheckSame(); + } + } +} + +// Try insert/erase of non-existent element. +TEST_F(EdgeSetTest, DoesNotExist) { + for (int n : {0, 1, 2, 3, 4, 10}) { + MakeEdgeSet(n); + EXPECT_EQ(0, eset_->erase(&nonexistent_)); + auto p = eset_->insert(&nonexistent_); + EXPECT_TRUE(p.second); + EXPECT_EQ(&nonexistent_, *p.first); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/equal_graph_def.cc b/tensorflow/core/graph/equal_graph_def.cc new file mode 100644 index 0000000000..35f59b5ed0 --- /dev/null +++ b/tensorflow/core/graph/equal_graph_def.cc @@ -0,0 +1,176 @@ +#include "tensorflow/core/graph/equal_graph_def.h" + +#include <unordered_map> +#include <unordered_set> +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, + string* diff) { + std::unordered_map<string, const NodeDef*> actual_index; + for (const NodeDef& node : actual.node()) { + actual_index[node.name()] = &node; + } + + for (const NodeDef& 
expected_node : expected.node()) { + auto actual_iter = actual_index.find(expected_node.name()); + if (actual_iter == actual_index.end()) { + if (diff != nullptr) { + *diff = strings::StrCat("Did not find expected node '", + SummarizeNodeDef(expected_node), "'"); + } + return false; + } + + if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false; + + actual_index.erase(actual_iter); + } + + if (!actual_index.empty()) { + if (diff != nullptr) { + *diff = strings::StrCat("Found unexpected node '", + SummarizeNodeDef(*actual_index.begin()->second), + "' not in expected graph:\n", + SummarizeGraphDef(expected)); + } + return false; + } + + return true; +} + +namespace { + +string JoinStringField(const protobuf::RepeatedPtrField<string>& f) { + string ret; + for (int i = 0; i < f.size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, f.Get(i)); + } + return ret; +} + +} // namespace + +bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, + string* diff) { + if (actual.name() != expected.name()) { + if (diff != nullptr) { + *diff = strings::StrCat("Actual node name '", actual.name(), + "' is not expected '", expected.name(), "'"); + } + return false; + } + + if (actual.op() != expected.op()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has op '", + actual.op(), "' that is not expected '", + expected.op(), "'"); + } + return false; + } + + if (actual.device() != expected.device()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has device '", + actual.device(), "' that is not expected '", + expected.device(), "'"); + } + return false; + } + + if (actual.input_size() != expected.input_size()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '", + JoinStringField(actual.input()), + "' that don't match expected '", + JoinStringField(expected.input()), "'"); + } + return false; + } + + 
int first_control_input = actual.input_size(); + for (int i = 0; i < actual.input_size(); ++i) { + if (StringPiece(actual.input(i)).starts_with("^")) { + first_control_input = i; + break; + } + if (actual.input(i) != expected.input(i)) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), "' has input ", + i, " '", actual.input(i), + "' that doesn't match expected '", + expected.input(i), "'"); + } + return false; + } + } + + std::unordered_set<string> actual_control; + std::unordered_set<string> expected_control; + for (int i = first_control_input; i < actual.input_size(); ++i) { + actual_control.insert(actual.input(i)); + expected_control.insert(expected.input(i)); + } + for (const auto& e : expected_control) { + if (actual_control.erase(e) == 0) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), + "' missing expected control input '", e, "'"); + } + return false; + } + } + if (!actual_control.empty()) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), + "' has unexpected control input '", + *actual_control.begin(), "'"); + } + return false; + } + + std::unordered_set<string> actual_attr; + for (const auto& a : actual.attr()) { + actual_attr.insert(a.first); + } + for (const auto& e : expected.attr()) { + if (actual_attr.erase(e.first) == 0) { + if (diff != nullptr) { + *diff = strings::StrCat("Node named '", actual.name(), + "' missing expected attr '", e.first, + "' with value: ", SummarizeAttrValue(e.second)); + } + return false; + } + auto iter = actual.attr().find(e.first); + if (!AreAttrValuesEqual(e.second, iter->second)) { + if (diff != nullptr) { + *diff = strings::StrCat( + "Node named '", actual.name(), "' has attr '", e.first, + "' with value: ", SummarizeAttrValue(iter->second), + " that does not match expected: ", SummarizeAttrValue(e.second)); + } + return false; + } + } + if (!actual_attr.empty()) { + if (diff != nullptr) { + *diff = strings::StrCat( + 
"Node named '", actual.name(), "' has unexpected attr '", + *actual_attr.begin(), "' with value: ", + SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second)); + } + return false; + } + + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/equal_graph_def.h b/tensorflow/core/graph/equal_graph_def.h new file mode 100644 index 0000000000..7dd8aab340 --- /dev/null +++ b/tensorflow/core/graph/equal_graph_def.h @@ -0,0 +1,32 @@ +#ifndef TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ +#define TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Determines if actual and expected are equal, ignoring ordering of +// nodes, attrs, and control inputs. If the GraphDefs are different +// and diff != nullptr, *diff is set to an explanation of the +// difference. Note that we use node names to match up nodes between +// the graphs, and so the naming of nodes must be consistent. +bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, + string* diff); + +// Determines if actual and expected are equal, ignoring ordering of +// attrs and control inputs. If the NodeDefs are different and +// diff != nullptr, *diff is set to an explanation of the difference. 
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff); + +#define TF_EXPECT_GRAPH_EQ(expected, actual) \ + do { \ + string diff; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_ diff --git a/tensorflow/core/graph/equal_graph_def_test.cc b/tensorflow/core/graph/equal_graph_def_test.cc new file mode 100644 index 0000000000..3a38b9e522 --- /dev/null +++ b/tensorflow/core/graph/equal_graph_def_test.cc @@ -0,0 +1,279 @@ +#include "tensorflow/core/graph/equal_graph_def.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +REGISTER_OP("Input").Output("o: float"); +REGISTER_OP("Alternate").Output("o: float"); +REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float"); + +Node* Input(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("Input", opts); +} + +Node* Alternate(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("Alternate", opts); +} + +Node* Cross(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("Cross", a, b, opts); +} + +class EqualGraphDefTest : public ::testing::Test { + protected: + EqualGraphDefTest() + : e_(GraphDefBuilder::kFailImmediately), + a_(GraphDefBuilder::kFailImmediately) { + RequireDefaultOps(); + } + + bool Match() { + GraphDef expected; + e_.ToGraphDef(&expected); + GraphDef actual; + a_.ToGraphDef(&actual); + return EqualGraphDef(actual, expected, &diff_); + } + + GraphDefBuilder e_; + GraphDefBuilder a_; + string diff_; +}; + +TEST_F(EqualGraphDefTest, Match) { + Input(e_.opts().WithName("A")); + Input(a_.opts().WithName("A")); + EXPECT_TRUE(Match()) << 
diff_; +} + +TEST_F(EqualGraphDefTest, NoMatch) { + Input(e_.opts().WithName("A")); + Input(a_.opts().WithName("B")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Did not find expected node 'A = Input[]()'", diff_); +} + +TEST_F(EqualGraphDefTest, MissingNode) { + Input(e_.opts().WithName("A")); + Input(e_.opts().WithName("B")); + Input(a_.opts().WithName("A")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Did not find expected node 'B = Input[]()'", diff_); +} + +TEST_F(EqualGraphDefTest, ExtraNode) { + Input(e_.opts().WithName("A")); + Input(a_.opts().WithName("A")); + Input(a_.opts().WithName("B")); + EXPECT_FALSE(Match()); + EXPECT_EQ( + "Found unexpected node 'B = Input[]()' not in expected graph:\n" + "A = Input[]();\n", + diff_); +} + +TEST_F(EqualGraphDefTest, NodeOrder) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Cross(a, b, e_.opts().WithName("C")); + + b = Input(a_.opts().WithName("B")); + a = Input(a_.opts().WithName("A")); + Cross(a, b, a_.opts().WithName("C")); + EXPECT_TRUE(Match()) << diff_; +} + +TEST_F(EqualGraphDefTest, NameMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + // Have to call EqualNodeDef() directly here, since EqualGraphDef() + // only calls EqualNodeDef() with nodes that have matching names. 
+ EXPECT_FALSE(EqualNodeDef(a->def(), b->def(), &diff_)); + EXPECT_EQ("Actual node name 'A' is not expected 'B'", diff_); +} + +TEST_F(EqualGraphDefTest, OpMismatch) { + Input(e_.opts().WithName("A")); + Alternate(a_.opts().WithName("A")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'A' has op 'Alternate' that is not expected 'Input'", + diff_); +} + +TEST_F(EqualGraphDefTest, DeviceMatch) { + Input(e_.opts().WithName("A").WithDevice("/cpu:0")); + Input(a_.opts().WithName("A").WithDevice("/cpu:0")); + EXPECT_TRUE(Match()) << diff_; +} + +TEST_F(EqualGraphDefTest, DeviceMismatch) { + Input(e_.opts().WithName("A").WithDevice("/cpu:0")); + Input(a_.opts().WithName("A").WithDevice("/cpu:1")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'A' has device '/cpu:1' that is not expected '/cpu:0'", + diff_); +} + +TEST_F(EqualGraphDefTest, InputMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Cross(a, a, e_.opts().WithName("C")); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + Cross(b, b, a_.opts().WithName("C")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'", + diff_); +} + +TEST_F(EqualGraphDefTest, InputOrderMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Cross(a, b, e_.opts().WithName("C")); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + Cross(b, a, a_.opts().WithName("C")); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'", + diff_); +} + +TEST_F(EqualGraphDefTest, ControlInputOrder) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Node* d = Input(e_.opts().WithName("D")); + Cross(a, a, e_.opts() + .WithName("E") + .WithControlInput(b) + .WithControlInput(c) + .WithControlInput(d)); + + a = 
Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + d = Input(a_.opts().WithName("D")); + Cross(a, a, a_.opts() + .WithName("E") + .WithControlInput(c) + .WithControlInput(d) + .WithControlInput(b)); + EXPECT_TRUE(Match()) << diff_; +} + +TEST_F(EqualGraphDefTest, ControlInputMismatch) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Node* d = Input(e_.opts().WithName("D")); + Cross(a, a, e_.opts().WithName("E").WithControlInput(b).WithControlInput(c)); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + d = Input(a_.opts().WithName("D")); + Cross(a, a, a_.opts().WithName("E").WithControlInput(b).WithControlInput(d)); + EXPECT_FALSE(Match()); + EXPECT_EQ("Node named 'E' missing expected control input '^C'", diff_); +} + +TEST_F(EqualGraphDefTest, ControlInputAdded) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Cross(a, a, e_.opts().WithName("D").WithControlInput(b)); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + Cross(a, a, a_.opts().WithName("D").WithControlInput(b).WithControlInput(c)); + EXPECT_FALSE(Match()); + EXPECT_EQ( + "Node named 'D' has inputs 'A, A, ^B, ^C' that don't match " + "expected 'A, A, ^B'", + diff_); +} + +TEST_F(EqualGraphDefTest, ControlInputRemoved) { + Node* a = Input(e_.opts().WithName("A")); + Node* b = Input(e_.opts().WithName("B")); + Node* c = Input(e_.opts().WithName("C")); + Cross(a, a, e_.opts().WithName("D").WithControlInput(b).WithControlInput(c)); + + a = Input(a_.opts().WithName("A")); + b = Input(a_.opts().WithName("B")); + c = Input(a_.opts().WithName("C")); + Cross(a, a, a_.opts().WithName("D").WithControlInput(b)); + EXPECT_FALSE(Match()); + EXPECT_EQ( + 
"Node named 'D' has inputs 'A, A, ^B' that don't match " + "expected 'A, A, ^B, ^C'", + diff_); +} + +TEST_F(EqualGraphDefTest, Attr) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef same(a->def()); + AddNodeAttr("foo", "bar", &same); + EXPECT_TRUE(EqualNodeDef(same, same, &diff_)) << diff_; +} + +TEST_F(EqualGraphDefTest, AttrAdded) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef actual(a->def()); + AddNodeAttr("foo", "bar", &actual); + EXPECT_FALSE(EqualNodeDef(actual, a->def(), &diff_)); + EXPECT_EQ("Node named 'A' has unexpected attr 'foo' with value: \"bar\"", + diff_); +} + +TEST_F(EqualGraphDefTest, AttrRemoved) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef expected(a->def()); + AddNodeAttr("foo", "bar", &expected); + EXPECT_FALSE(EqualNodeDef(a->def(), expected, &diff_)); + EXPECT_EQ("Node named 'A' missing expected attr 'foo' with value: \"bar\"", + diff_); +} + +TEST_F(EqualGraphDefTest, AttrOrder) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef actual(a->def()); + AddNodeAttr("foo", "bar", &actual); + AddNodeAttr("baz", 42, &actual); + + NodeDef expected(a->def()); + AddNodeAttr("baz", 42, &expected); + AddNodeAttr("foo", "bar", &expected); + + EXPECT_TRUE(EqualNodeDef(actual, expected, &diff_)) << diff_; +} + +TEST_F(EqualGraphDefTest, AttrMismatch) { + Node* a = Input(e_.opts().WithName("A")); + NodeDef actual(a->def()); + AddNodeAttr("foo", "bar", &actual); + AddNodeAttr("baz", 5, &actual); + + NodeDef expected(a->def()); + AddNodeAttr("baz", 42, &expected); + AddNodeAttr("foo", "bar", &expected); + + EXPECT_FALSE(EqualNodeDef(actual, expected, &diff_)); + EXPECT_EQ( + "Node named 'A' has attr 'baz' with value: 5 that does not match " + "expected: 42", + diff_); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc new file mode 100644 index 0000000000..0c268a51a9 --- /dev/null +++ b/tensorflow/core/graph/graph.cc @@ -0,0 +1,319 @@ +#include 
"tensorflow/core/graph/graph.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Node
+
+// Human-readable description of this node: name, id, and either a
+// source/sink marker or the assigned device plus a NodeDef summary.
+string Node::DebugString() const {
+  // NOTE(review): evaluating `this == nullptr` is undefined behavior in
+  // standard C++ (a member function must never be invoked through a null
+  // pointer, and compilers may fold this test to false).  Callers should
+  // be migrated to check for null before calling DebugString().
+  if (this == nullptr) {
+    return "{nullptr}";
+  }
+  string ret = strings::StrCat("{name:'", name(), "' id:", id_);
+  if (IsSource()) {
+    strings::StrAppend(&ret, " source}");
+  } else if (IsSink()) {
+    strings::StrAppend(&ret, " sink}");
+  } else {
+    strings::StrAppend(&ret, " op device:");
+    strings::StrAppend(&ret, "{", assigned_device_name_, "}");
+    strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}");
+  }
+  return ret;
+}
+
+// Nodes start uninitialized (id -1, no properties); Initialize() must be
+// called before the node is used.
+Node::Node()
+    : id_(-1), cost_id_(-1), props_(nullptr), assigned_device_name_() {}
+
+Node::~Node() {
+  if (props_) {
+    props_->Unref();
+  }
+}
+
+// Adopts a reference to "props" and assigns this node's ids.  May only be
+// called on a freshly constructed or Clear()ed node (enforced by DCHECKs).
+void Node::Initialize(int id, int cost_id, Properties* props) {
+  DCHECK_EQ(id_, -1);
+  DCHECK(in_edges_.empty());
+  DCHECK(out_edges_.empty());
+  id_ = id;
+  cost_id_ = cost_id;
+
+  // Unref the old, assign the new properties.
+ if (props_) { + props_->Unref(); + } + props_ = props; +} + +void Node::Clear() { + in_edges_.clear(); + out_edges_.clear(); + id_ = -1; + cost_id_ = -1; + + if (props_) { + props_->Unref(); + props_ = nullptr; + } + + assigned_device_name_.clear(); +} + +gtl::iterator_range<NeighborIter> Node::out_nodes() const { + return gtl::make_range(NeighborIter(out_edges_.begin(), false), + NeighborIter(out_edges_.end(), false)); +} + +gtl::iterator_range<NeighborIter> Node::in_nodes() const { + return gtl::make_range(NeighborIter(in_edges_.begin(), true), + NeighborIter(in_edges_.end(), true)); +} + +// Node::Properties + +Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, + const DataTypeSlice outputs) + : op_def_(op_def), + node_def_(node_def), + input_types_(inputs.begin(), inputs.end()), + output_types_(outputs.begin(), outputs.end()) {} + +Node::Properties::~Properties() {} + +// Graph + +Graph::Graph(const OpRegistryInterface* ops) + : ops_(ops), arena_(8 << 10 /* 8kB */) { + // Source and sink have no endpoints, just control edges. + NodeDef def; + def.set_name("_SOURCE"); + def.set_op("NoOp"); + Status status; + Node* source = AddNode(def, &status); + TF_CHECK_OK(status); + CHECK_EQ(source->id(), kSourceId); + + def.set_name("_SINK"); + Node* sink = AddNode(def, &status); + TF_CHECK_OK(status); + CHECK_EQ(sink->id(), kSinkId); + + AddControlEdge(source, sink); +} + +Graph::~Graph() { + // Manually call the destructors for all the Nodes we constructed using + // placement new. + for (Node* node : nodes_) { + if (node != nullptr) { + node->~Node(); + } + } + for (Node* node : free_nodes_) { + node->~Node(); + } + // Edges have no destructor, and we arena-allocated them, so no need to + // destroy them. 
+} + +Node* Graph::AddNode(const NodeDef& node_def, Status* status) { + const OpDef* op_def = ops_->LookUp(node_def.op(), status); + if (op_def == nullptr) return nullptr; + + // TODO(vrv,josh11b): Find a location higher in the stack to add these defaults + // to the NodeDef. + NodeDef node_def_with_defaults(node_def); + AddDefaultsToNodeDef(*op_def, &node_def_with_defaults); + + DataTypeVector inputs; + DataTypeVector outputs; + status->Update( + InOutTypesForNode(node_def_with_defaults, *op_def, &inputs, &outputs)); + if (!status->ok()) { + *status = AttachDef(*status, node_def_with_defaults); + return nullptr; + } + + Node* node = AllocateNode( + new Node::Properties(op_def, node_def_with_defaults, inputs, outputs), + nullptr); + return node; +} + +Node* Graph::CopyNode(Node* node) { + DCHECK(!node->IsSource()); + DCHECK(!node->IsSink()); + Node::Properties* props = node->properties(); + props->Ref(); + Node* copy = AllocateNode(props, node); + copy->set_assigned_device_name(node->assigned_device_name()); + return copy; +} + +void Graph::RemoveNode(Node* node) { + DCHECK(IsValidNode(node)) << node->DebugString(); + DCHECK(!node->IsSource()); + DCHECK(!node->IsSink()); + + // Remove any edges involving this node. + while (!node->in_edges_.empty()) { + RemoveEdge(*node->in_edges_.begin()); + } + while (!node->out_edges_.empty()) { + RemoveEdge(*node->out_edges_.begin()); + } + ReleaseNode(node); +} + +const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) { + DCHECK(IsValidNode(source)) << source->DebugString(); + DCHECK(IsValidNode(dest)) << dest->DebugString(); + + // source/sink must only be linked via control slots, and + // control slots must only be linked to control slots. 
+ if (source == source_node() || dest == sink_node() || x == kControlSlot || + y == kControlSlot) { + DCHECK_EQ(x, kControlSlot) << source->DebugString(); + DCHECK_EQ(y, kControlSlot) << dest->DebugString(); + } + + Edge* e = nullptr; + if (free_edges_.empty()) { + e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new + } else { + e = free_edges_.back(); + free_edges_.pop_back(); + } + e->id_ = edges_.size(); + e->src_ = source; + e->dst_ = dest; + e->src_output_ = x; + e->dst_input_ = y; + CHECK(source->out_edges_.insert(e).second); + CHECK(dest->in_edges_.insert(e).second); + edges_.push_back(e); + edge_set_.insert(e); + return e; +} + +void Graph::RemoveEdge(const Edge* e) { + DCHECK(IsValidNode(e->src_)) << e->src_->DebugString(); + DCHECK(IsValidNode(e->dst_)) << e->dst_->DebugString(); + CHECK_EQ(e->src_->out_edges_.erase(e), 1); + CHECK_EQ(e->dst_->in_edges_.erase(e), 1); + CHECK_EQ(e, edges_[e->id_]); + + CHECK_EQ(edge_set_.erase(e), 1); + edges_[e->id_] = nullptr; + + Edge* del = const_cast<Edge*>(e); + del->src_ = nullptr; + del->dst_ = nullptr; + del->id_ = -1; + del->src_output_ = kControlSlot - 1; + del->dst_input_ = kControlSlot - 1; + free_edges_.push_back(del); +} + +namespace { + +void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { + if (src_slot == Graph::kControlSlot) { + dst->add_input(strings::StrCat("^", src_name)); + } else if (src_slot == 0) { + dst->add_input(src_name.data(), src_name.size()); + } else { + dst->add_input(strings::StrCat(src_name, ":", src_slot)); + } +} + +} // namespace + +void Graph::ToGraphDef(GraphDef* graph_def) const { + graph_def->Clear(); + std::vector<const Edge*> + inputs; // Construct this outside the loop for speed. + for (const Node* node : nodes()) { + if (!node->IsOp()) continue; + NodeDef* node_def = graph_def->add_node(); + *node_def = node->def(); + + // Use the node's assigned device, if any, instead of the device requested + // in the NodeDef. 
+ if (!node->assigned_device_name().empty()) { + node_def->set_device(node->assigned_device_name()); + } + + // Get the inputs for this Node. We make sure control inputs are + // after data inputs, as required by GraphDef. + inputs.clear(); + inputs.resize(node->num_inputs(), nullptr); + for (const Edge* edge : node->in_edges()) { + if (edge->IsControlEdge()) { + inputs.push_back(edge); + } else { + DCHECK(inputs[edge->dst_input()] == nullptr); + inputs[edge->dst_input()] = edge; + } + } + node_def->clear_input(); + for (size_t i = 0; i < inputs.size(); ++i) { + const Edge* edge = inputs[i]; + if (edge == nullptr) { + node_def->add_input(node->def().input(i)); + } else { + const Node* src = edge->src(); + if (!src->IsOp()) continue; + AddInput(node_def, src->name(), edge->src_output()); + } + } + } +} + +string Graph::NewName(StringPiece prefix) { + return strings::StrCat(prefix, "/_", name_counter_++); +} + +gtl::iterator_range<NodeIter> Graph::nodes() const { + // Note that NodeId 0 is always valid since we don't let the source + // node be removed from the graph. + return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids())); +} + +bool Graph::IsValidNode(Node* node) const { + if (node == nullptr) return false; + const int id = node->id(); + if (id < 0 || static_cast<size_t>(id) >= nodes_.size()) return false; + return nodes_[id] == node; +} + +Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) { + Node* node = nullptr; + if (free_nodes_.empty()) { + node = new (arena_.Alloc(sizeof(Node))) Node; // placement new + } else { + node = free_nodes_.back(); + free_nodes_.pop_back(); + } + const int id = nodes_.size(); + int cost_id = cost_node ? 
cost_node->cost_id() : id;
+  node->Initialize(id, cost_id, props);
+  nodes_.push_back(node);
+  return node;
+}
+
+// Removes "node" from the id->Node table and returns it to the free list
+// for reuse by AllocateNode().  The Node's storage is arena-owned and is
+// never freed individually; Clear() drops its edges, ids, and properties
+// reference.
+void Graph::ReleaseNode(Node* node) {
+  DCHECK(IsValidNode(node)) << node->DebugString();
+  nodes_[node->id()] = nullptr;
+  free_nodes_.push_back(node);
+  node->Clear();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
new file mode 100644
index 0000000000..030e471bf4
--- /dev/null
+++ b/tensorflow/core/graph/graph.h
@@ -0,0 +1,440 @@
+// A Graph describes a set of computations that are to be
+// performed, as well as the dependencies between those
+// computations. The basic model is a DAG (directed acyclic graph) with
+// * internal nodes representing computational operations to be performed;
+// * edges represent dependencies, indicating the target may only be
+//   executed once the source has completed; and
+// * predefined "source" (start) and "sink" (finish) nodes -- the source
+//   should be the only node that doesn't depend on anything, and the sink
+//   should be the only node that nothing depends on.
+//
+// Note: Node ids are intended to be relatively dense in the
+// 0..max_id range, but there may be gaps since ids won't be reused.
+//
+// Note: Some dependencies between operations are due to one operation
+// consuming the output of another. In fact operations can produce
+// multiple outputs and consume multiple inputs, and some
+// optimizations will care about which specific outputs are connected
+// to which specific inputs.  We therefore represent data dependency
+// between output O of layer A and input I of layer B using
+// "input index" and "output index" labels per edge.
+ +#ifndef TENSORFLOW_GRAPH_GRAPH_H_ +#define TENSORFLOW_GRAPH_GRAPH_H_ + +#include <functional> +#include <string> +#include <vector> +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/edgeset.h" +#include "tensorflow/core/lib/core/arena.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class Edge; +class EdgeSetTest; +class Graph; +class Node; + +class NeighborIter; // Declared below +class NodeIter; // Declared below + +class Node { + public: + string DebugString() const; + int id() const { return id_; } + int cost_id() const { return cost_id_; } + const string& name() const { return props_->node_def_.name(); } + const string& type_string() const { return props_->node_def_.op(); } + const NodeDef& def() const { return props_->node_def_; } + const OpDef& op_def() const { return *props_->op_def_; } + + // input and output types + int num_inputs() const { return props_->input_types_.size(); } + DataType input_type(int i) const { return props_->input_types_[i]; } + const DataTypeVector& input_types() const { return props_->input_types_; } + + int num_outputs() const { return props_->output_types_.size(); } + DataType output_type(int o) const { return props_->output_types_[o]; } + const DataTypeVector& output_types() const { return props_->output_types_; } + + // This gives the device the runtime has assigned this node to. If + // you want the device the user requested, use def().device() instead. + // TODO(josh11b): Validate that the assigned_device, if not empty: + // fully specifies a device, and satisfies def().device(). + // TODO(josh11b): Move device_name outside of Node into a NodeId->DeviceName + // map. 
+ string assigned_device_name() const { return assigned_device_name_; } + void set_assigned_device_name(const string& device_name) { + assigned_device_name_ = device_name; + } + + // Get the neighboring nodes via edges either in or out of this node. + gtl::iterator_range<NeighborIter> in_nodes() const; + gtl::iterator_range<NeighborIter> out_nodes() const; + const EdgeSet& in_edges() const { return in_edges_; } + const EdgeSet& out_edges() const { return out_edges_; } + + // Node type helpers. + bool IsSource() const { return id() == 0; } + bool IsSink() const { return id() == 1; } + // Anything other than the special Source & Sink nodes. + bool IsOp() const { return id() > 1; } + + private: + friend class Graph; + Node(); + ~Node(); + + class Properties : public core::RefCounted { + public: + Properties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs); + + const OpDef* op_def_; // not owned + const NodeDef node_def_; + const DataTypeVector input_types_; + const DataTypeVector output_types_; + + private: + // Destructor invoked when last reference goes away via Unref() + virtual ~Properties(); + TF_DISALLOW_COPY_AND_ASSIGN(Properties); + }; + + Properties* properties() const { return props_; } + + // Initialize() adopts a reference to props, and so is suitable if props was + // just allocated or you call props->Ref() to increment the reference + // count for a props being held by another Node. + void Initialize(int id, int cost_id, Properties* props); + // Releases memory from props_, in addition to restoring *this to its + // uninitialized state. + void Clear(); + + int id_; // -1 until Initialize() is called + int cost_id_; // -1 if there is no corresponding cost accounting node + + EdgeSet in_edges_; + EdgeSet out_edges_; + + Properties* props_; + + // Name of device assigned to perform this computation. 
+ string assigned_device_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(Node); +}; + +class Edge { + public: + Node* src() const { return src_; } + Node* dst() const { return dst_; } + int id() const { return id_; } + + // Return the number of the source output that produces the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int src_output() const { return src_output_; } + + // Return the number of the destination input that consumes the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int dst_input() const { return dst_input_; } + + // Return true iff this is an edge that indicates a control-flow + // (as opposed to a data-flow) dependency. + bool IsControlEdge() const; + + private: + Edge() {} + + friend class EdgeSetTest; + friend class Graph; + Node* src_; + Node* dst_; + int id_; + int src_output_; + int dst_input_; +}; + +// Thread compatible but not thread safe. +class Graph { + public: + // Constructs a graph with a single SOURCE (always id kSourceId) and a + // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. + // + // The graph can hold ops found in registry. + explicit Graph(const OpRegistryInterface* registry); + ~Graph(); + + static const int kControlSlot = -1; + + // Adds a new node to this graph, and returns it. Infers the Op and + // input/output types for the node. *this owns the returned instance. + // Returns nullptr and sets *status on error. + Node* AddNode(const NodeDef& node_def, Status* status); + + // Copies *node, which may belong to another graph, to a new node, + // which is returned. Does not copy any edges. *this owns the + // returned instance. + Node* CopyNode(Node* node); + + // Remove a node from this graph, including all edges from or to it. + // *node should not be accessed after calling this function. 
+  // REQUIRES: node->IsOp()
+  void RemoveNode(Node* node);
+
+  // Add an edge that connects the xth output of "source" to the yth input
+  // of "dest".
+  const Edge* AddEdge(Node* source, int x, Node* dest, int y);
+
+  // Add a control-edge (no data flows along this edge) that
+  // connects "source" to "dest".
+  const Edge* AddControlEdge(Node* source, Node* dest) {
+    return AddEdge(source, kControlSlot, dest, kControlSlot);
+  }
+
+  // Removes edge from the graph.
+  // REQUIRES: The edge must exist.
+  void RemoveEdge(const Edge* edge);
+
+  // Returns one more than the maximum id assigned to any node.
+  int num_node_ids() const { return nodes_.size(); }
+
+  // Serialize to a GraphDef.
+  void ToGraphDef(GraphDef* graph_def) const;
+
+  // Generate new node name with the specified prefix that is unique
+  // across this graph.
+  string NewName(StringPiece prefix);
+
+  // Access to the list of all nodes.  Example usage:
+  //   for (Node* node : graph.nodes()) { ... }
+  gtl::iterator_range<NodeIter> nodes() const;
+
+  // Returns the node associated with an id, or nullptr if no node
+  // with that id (the node with that id was removed and the id has
+  // not yet been re-used). *this owns the returned instance.
+  // REQUIRES: 0 <= id < num_node_ids().
+  Node* FindNodeId(int id) const { return nodes_[id]; }
+
+  // Returns one more than the maximum id assigned to any edge.
+  int num_edge_ids() const { return edges_.size(); }
+
+  // Returns the Edge associated with an id, or nullptr if no edge
+  // with that id (the edge with that id was removed and the id has
+  // not yet been re-used). *this owns the returned instance.
+  // REQUIRES: 0 <= id < num_edge_ids().
+  const Edge* FindEdgeId(int id) const { return edges_[id]; }
+
+  // Access to the set of all edges.  Example usage:
+  //   for (const Edge* e : graph.edges()) { ... }
+  const EdgeSet& edges() const { return edge_set_; }
+
+  // The pre-defined nodes.
+ enum { kSourceId = 0, kSinkId = 1 }; + Node* source_node() const { return FindNodeId(kSourceId); } + Node* sink_node() const { return FindNodeId(kSinkId); } + + const OpRegistryInterface* op_registry() const { return ops_; } + + // TODO(josh11b): uint64 hash() const; + + private: + bool IsValidNode(Node* node) const; + // If cost_node is non-null, then cost accounting (in CostModel) + // will be associated with that node rather than the new one being + // created. + Node* AllocateNode(Node::Properties* props, const Node* cost_node); + void ReleaseNode(Node* node); + + // Registry of all known ops. Not owned. + const OpRegistryInterface* const ops_; + + // Allocator which will give us good locality. + core::Arena arena_; + + // Map from node ids to allocated nodes. nodes_[id] may be nullptr if + // the node with that id was removed from the graph. + std::vector<Node*> nodes_; + + // Map from edge ids to allocated edges. edges_[id] may be nullptr if + // the edge with that id was removed from the graph. + std::vector<Edge*> edges_; + + // For ease of iteration, we currently just keep a set of all live + // edges. May want to optimize by removing this copy. + EdgeSet edge_set_; + + // Allocated but free nodes and edges. + std::vector<Node*> free_nodes_; + std::vector<Edge*> free_edges_; + + // For generating unique names. + int name_counter_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(Graph); +}; + +// TODO(josh11b): We may want to support keeping an index on various +// node/edge attributes in a graph, particularly node names. 
+
+// Helper routines
+
+// The predicates below classify a Node by its op-type string.  They are
+// cheap string comparisons and require no OpDef registry lookup; the
+// "Ref" variants are the reference-typed counterparts of the same ops.
+
+inline bool IsSwitch(const Node* node) {
+  return node->type_string() == "Switch" || node->type_string() == "RefSwitch";
+}
+
+inline bool IsMerge(const Node* node) { return node->type_string() == "Merge"; }
+
+inline bool IsEnter(const Node* node) {
+  return node->type_string() == "Enter" || node->type_string() == "RefEnter";
+}
+
+inline bool IsExit(const Node* node) { return node->type_string() == "Exit"; }
+
+inline bool IsNextIteration(const Node* node) {
+  return node->type_string() == "NextIteration";
+}
+
+inline bool IsLoopCond(const Node* node) {
+  return node->type_string() == "LoopCond";
+}
+
+inline bool IsControlTrigger(const Node* node) {
+  return node->type_string() == "ControlTrigger";
+}
+
+inline bool IsSend(const Node* node) {
+  return node->type_string() == "_Send" || node->type_string() == "_HostSend";
+}
+
+inline bool IsRecv(const Node* node) {
+  return node->type_string() == "_Recv" || node->type_string() == "_HostRecv";
+}
+
+// True for Nodes that mediate the transfer of values between processes.
+inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
+
+inline bool IsConstant(const Node* node) {
+  return node->type_string() == "Const" || node->type_string() == "HostConst";
+}
+
+inline bool IsVariable(const Node* node) {
+  return node->type_string() == "Variable";
+}
+
+inline bool IsIdentity(const Node* node) {
+  return (node->type_string() == "Identity" ||
+          node->type_string() == "RefIdentity");
+}
+
+// Returns true iff 'n' is a control flow node.
+inline bool IsControlFlow(const Node* n) {
+  return IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n) ||
+         IsNextIteration(n);
+}
+
+// Identity and control-flow ops pass their input through unchanged;
+// presumably placement uses this to keep such tensors in host memory --
+// TODO confirm against callers.
+inline bool IsHostMemoryPreserving(const Node* node) {
+  return IsIdentity(node) || IsControlFlow(node);
+}
+
+// Iterator for stepping through the nodes of a graph.
+class NodeIter { + public: + NodeIter(const Graph* graph, int id); + bool operator==(const NodeIter& rhs); + bool operator!=(const NodeIter& rhs); + void operator++(); + Node* operator*(); + Node* operator->(); + + private: + // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr + const Graph* graph_; + int id_; +}; + +// Iterator for stepping through the neighbors of a node. +class NeighborIter { + public: + NeighborIter(EdgeSet::const_iterator iter, bool incoming); + bool operator==(const NeighborIter& rhs); + bool operator!=(const NeighborIter& rhs); + void operator++(); + Node* operator*(); + Node* operator->(); + + private: + EdgeSet::const_iterator iter_; + bool incoming_; +}; + +// IMPLEMENTATION DETAILS, PLEASE IGNORE + +inline NodeIter::NodeIter(const Graph* graph, int id) + : graph_(graph), id_(id) {} + +inline bool NodeIter::operator==(const NodeIter& rhs) { + DCHECK(graph_ == rhs.graph_); + return id_ == rhs.id_; +} + +inline bool NodeIter::operator!=(const NodeIter& rhs) { + return !(*this == rhs); +} + +inline void NodeIter::operator++() { + while (1) { + DCHECK_LE(id_, graph_->num_node_ids()); + ++id_; + if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { + return; + } + } +} + +inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); } + +inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); } + +inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming) + : iter_(iter), incoming_(incoming) {} + +inline bool NeighborIter::operator==(const NeighborIter& rhs) { + return iter_ == rhs.iter_ && incoming_ == rhs.incoming_; +} + +inline bool NeighborIter::operator!=(const NeighborIter& rhs) { + return !(*this == rhs); +} + +inline void NeighborIter::operator++() { ++iter_; } + +inline Node* NeighborIter::operator*() { + const Edge* e = *iter_; + return incoming_ ? 
e->src() : e->dst(); +} + +inline Node* NeighborIter::operator->() { + const Edge* e = *iter_; + return incoming_ ? e->src() : e->dst(); +} + +inline bool Edge::IsControlEdge() const { + // Note that if either src_output_ or dst_input_ is kControlSlot, + // so is the other one (AddEdge checks this). + return src_output_ == Graph::kControlSlot; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_H_ diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc new file mode 100644 index 0000000000..3928348f0a --- /dev/null +++ b/tensorflow/core/graph/graph_constructor.cc @@ -0,0 +1,385 @@ +#include "tensorflow/core/graph/graph_constructor.h" + +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/optimizer_cse.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { +inline bool IsMerge(const NodeDef& node_def) { + return node_def.op() == "Merge"; +} +} // namespace + +namespace { + +class GraphConstructor { + public: + GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef, + Graph* g, Status* status) + : opts_(opts), gdef_(gdef), g_(g), status_(status) { + BuildNodeIndex(); + InitFromEdges(); + Convert(); + } + + private: + void SetError(const string& error); + void SetNodeError(const NodeDef& node_def, const StringPiece& message) { + SetError(strings::StrCat("Node '", node_def.name(), "': ", message)); + } + void BuildNodeIndex(); + void InitFromEdges(); + Node* MakeNode(const NodeDef& node_def); + void Convert(); + // Calls SetError() and 
returns false if the type of the output of + // the source of the edge can't be consumed by destination of the edge. + // REQUIRES: edge must be a data edge, not a control edge. + bool TypeValidateEdge(const Edge* edge); + + // From constructor + const GraphConstructorOptions opts_; + const GraphDef* gdef_; + Graph* g_; + Status* status_; + + // Mapping from node name to the index within gdef_ + struct NodeInfo { + explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} + // std::unordered_map<> requires that we have a default constructor. + NodeInfo() : NodeInfo(-1) {} + int gdef_index; + Node* node; // nullptr until the NodeDef is converted to a Node. + }; + // TODO(vrv): Profile this data structure to see if we should use an + // alternative implementation of std::unordered_map. + std::unordered_map<StringPiece, NodeInfo, StringPiece::Hasher> name_index_; + + // Index of NodeDefs in gdef_ with all inputs already converted. + std::vector<int> ready_; + + // Mapping between index within gdef_ and the number of inputs that + // still need to be converted. + std::vector<int> pending_count_; + + // Mapping between index within gdef_ and the index within gdef_ of + // all nodes it outputs to. + std::vector<gtl::InlinedVector<int, 4>> outputs_; + + // Used in the conversion from gdef_ to g_ to represent the ith input + // of a node. + struct InputInfo { + explicit InputInfo(StringPiece node_name, Node* n, int i) + : name(node_name), node(n), index(i) {} + StringPiece name; + Node* node; + int index; + }; + + // Used in the conversion from gdef_ to g_ to represent an edge from + // the node named 'name' to node 'n'. 
  // Records a "back edge": an input whose source node had not yet been
  // created when the destination node was processed (only legal for Merge
  // nodes).  Resolved by name after all nodes exist.
  // NOTE(review): the enclosing GraphConstructor class declaration begins
  // above this chunk.
  struct EdgeInfo {
    explicit EdgeInfo(StringPiece name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    StringPiece src_name;
    int src_index;   // Source output slot, or -1 for a control edge.
    Node* dst_node;
    int dst_index;   // Destination input slot.
  };
};

// Records "error" into *status_ as an InvalidArgument error.
void GraphConstructor::SetError(const string& error) {
  status_->Update(errors::InvalidArgument(error));
}

// Validates every NodeDef's name/op/device fields and populates
// name_index_ (name -> position in the GraphDef).  Sets *status_ and
// returns early on the first invalid node.
void GraphConstructor::BuildNodeIndex() {
  // Initialized outside the loop for efficiency
  const char* pattern;
  if (opts_.allow_internal_ops) {
    pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*";
  } else {
    pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*";
  }
  RE2 node_name_re(pattern);

  // Validate the node names and add them to name_index_.
  for (int n = 0; n < gdef_->node_size(); ++n) {
    const NodeDef& node_def(gdef_->node(n));
    if (!RE2::FullMatch(node_def.name(), node_name_re)) {
      SetNodeError(node_def, "Node name contains invalid characters");
      return;
    }
    if (!name_index_.insert(std::make_pair(StringPiece(node_def.name()),
                                           NodeInfo(n)))
             .second) {
      SetNodeError(node_def, "Node name is not unique");
      return;
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      SetNodeError(node_def, "Does not specify a type");
      return;
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      SetNodeError(node_def, strings::StrCat("Missing device specification."));
      return;
    }
  }
}

// Builds pending_count_ (number of inputs each node still waits for) and
// outputs_ (consumers of each node), seeding ready_ with nodes that have
// no inputs.  Merge nodes wait only for their control inputs plus one
// data input, which is what allows cycles through Merge.
void GraphConstructor::InitFromEdges() {
  const int num_nodes = gdef_->node_size();
  ready_.reserve(num_nodes);
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def(gdef_->node(n));
    if (IsMerge(node_def)) {
      // for merge only wait for one non-control input.
      int32 num_control_edges = 0;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (StringPiece(input_name).starts_with("^")) {
          num_control_edges++;
        }
      }
      pending_count_.push_back(num_control_edges + 1);
    } else {
      pending_count_.push_back(node_def.input_size());
    }
    if (node_def.input_size() == 0) {
      ready_.push_back(n);
      continue;
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      if (input_name.starts_with("^")) {
        // Control dependence
        input_name.remove_prefix(1);
      }
      TensorId id(ParseTensorName(input_name));
      auto iter = name_index_.find(id.first);
      if (iter == name_index_.end()) {
        SetNodeError(node_def,
                     strings::StrCat("Unknown input node ", node_def.input(i)));
        return;
      }
      outputs_[iter->second.gdef_index].push_back(n);
    }
  }
}

// Adds node_def to the graph, records it in name_index_, and returns the
// new Node (nullptr on error, with *status_ already set via AddNode).
Node* GraphConstructor::MakeNode(const NodeDef& node_def) {
  // Add the node to the graph.
  Node* node = g_->AddNode(node_def, status_);
  if (node == nullptr) return nullptr;
  if (opts_.expect_device_spec) {
    node->set_assigned_device_name(node_def.device());
  }
  name_index_[node_def.name()].node = node;
  return node;
}

// Return the number of nodes in "g"
static int CountNodes(Graph* g) {
  int nodes = 0;
  for (Node* node : g->nodes()) {
    VLOG(1) << node;  // Dummy use to avoid compiler warning
    nodes++;
  }
  return nodes;
}

// Kahn-style topological conversion: pops nodes from ready_, creates them
// and their incoming edges, and decrements consumers' pending counts.
// Back edges (inputs from not-yet-created nodes, Merge only) are recorded
// and added in a second pass.  If any node was never processed, the graph
// contained a (non-Merge) cycle.
void GraphConstructor::Convert() {
  std::vector<InputInfo> inputs;
  std::vector<EdgeInfo> back_edges;
  int processed = 0;
  // Process the NodeDefs in topological order.
  while (!ready_.empty()) {
    int o = ready_.back();
    ready_.pop_back();
    ++processed;
    const NodeDef& node_def(gdef_->node(o));
    inputs.clear();
    bool in_control_dependence = false;
    bool has_data_back_edge = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name(node_def.input(i));
      if (StringPiece(input_name).starts_with("^")) {
        // A control dependence
        in_control_dependence = true;
        input_name.remove_prefix(1);
      } else {
        if (in_control_dependence) {
          SetNodeError(node_def, strings::StrCat(
                                     "Control dependencies must come after ",
                                     "regular dependencies: input ", input_name,
                                     " of source node ", node_def.name()));
          return;
        }
      }
      TensorId id(ParseTensorName(input_name));
      auto iter = name_index_.find(id.first);
      // BuildNodeIndex() already validated all input names.
      DCHECK(iter != name_index_.end());
      Node* src_node = iter->second.node;
      if (in_control_dependence) {
        inputs.push_back(InputInfo(id.first, src_node, -1));
      } else {
        if (src_node == nullptr) {
          // Source not created yet: this is a back edge.
          has_data_back_edge = true;
          inputs.push_back(InputInfo(id.first, src_node, id.second));
        } else {
          if (id.second >= src_node->num_outputs()) {
            SetNodeError(
                node_def,
                strings::StrCat("Connecting to invalid output ", id.second,
                                " of source node ", id.first, " which has ",
                                src_node->num_outputs(), " outputs"));
            return;
          }
          inputs.push_back(InputInfo(id.first, src_node, id.second));
        }
      }
    }
    if (has_data_back_edge && !IsMerge(node_def)) {
      SetError(strings::StrCat(
          node_def.name(),
          " had a back edge. But only Merge can have back edges."));
      return;
    }

    Node* node = MakeNode(node_def);
    if (node == nullptr) return;

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges.push_back(
            EdgeInfo(inputs[i].name, inputs[i].index, node, i));
      } else if (inputs[i].index == -1) {
        g_->AddControlEdge(inputs[i].node, node);
      } else {
        const Edge* edge =
            g_->AddEdge(inputs[i].node, inputs[i].index, node, i);
        if (!TypeValidateEdge(edge)) return;
      }
    }

    // Update pending_count_ for outputs.
    for (size_t i = 0; i < outputs_[o].size(); ++i) {
      const int output = outputs_[o][i];
      pending_count_[output]--;
      if (pending_count_[output] == 0) {
        ready_.push_back(output);
      }
    }
  }

  // Add the back edges after all nodes are created.
  for (auto e : back_edges) {
    Node* src_node = name_index_[e.src_name].node;
    if (e.src_index == -1) {
      g_->AddControlEdge(src_node, e.dst_node);
    } else {
      const Edge* edge =
          g_->AddEdge(src_node, e.src_index, e.dst_node, e.dst_index);
      if (!TypeValidateEdge(edge)) return;
    }

    VLOG(2) << "Add back edge: " << src_node->name() << " -> "
            << e.dst_node->name();
  }

  if (processed < gdef_->node_size()) {
    SetError(
        strings::StrCat(gdef_->node_size() - processed, " nodes in a cycle"));
    return;
  }

  if (status_->ok()) {
    FixupSourceAndSinkEdges(g_);

    if (opts_.optimizer_do_cse) {
      if (!back_edges.empty()) {
        LOG(WARNING) << "Not doing CSE. We need to figure out how to handle "
                     << "loops in the CSE phase.";
      } else {
        VLOG(1) << "Starting CSE: graph of " << CountNodes(g_) << " nodes";
        OptimizeCSE(g_, opts_.cse_consider_function);
        VLOG(1) << "Finished CSE: graph of " << CountNodes(g_) << " nodes";
      }
    }
  }
}

// Checks that the types at the two endpoints of "edge" are compatible.
// On mismatch records an error in *status_ and returns false.
bool GraphConstructor::TypeValidateEdge(const Edge* edge) {
  DataType src_out = edge->src()->output_type(edge->src_output());
  DataType dst_in = edge->dst()->input_type(edge->dst_input());
  if (!TypesCompatible(dst_in, src_out)) {
    SetError(strings::StrCat(
        "Input ", edge->dst_input(), " of node ", edge->dst()->name(),
        " was passed ", DataTypeString(src_out), " from ", edge->src()->name(),
        ":", edge->src_output(), " incompatible with expected ",
        DataTypeString(dst_in), "."));
    return false;
  }
  return true;
}

}  // namespace

// ----------------------------------------------------------------------------
// ConvertGraphDefToGraph
// ----------------------------------------------------------------------------

Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  Status status;
  // All the work happens in the GraphConstructor constructor.
  GraphConstructor constructor(opts, &gdef, g, &status);
  return status;
}

// ----------------------------------------------------------------------------
// CopyGraph
// ----------------------------------------------------------------------------
void CopyGraph(const Graph& src, Graph* dest) {
  for (Node* n : dest->nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
  }

  // Copy the nodes
  std::unordered_map<Node*, Node*>
      node_map;  // "Node in src" -> "Node in *dest"
  node_map[src.source_node()] = dest->source_node();
  node_map[src.sink_node()] = dest->sink_node();
  for (Node* n : src.nodes()) {
    if (n->IsSource() || n->IsSink()) continue;
    CHECK(n->IsOp());
    node_map[n] = dest->CopyNode(n);
  }

  // Copy the edges
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }
}

}  // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
new file mode 100644
index 0000000000..cd1615ef6b
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -0,0 +1,43 @@
#ifndef TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
#define TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// Construct a graph *g out of a GraphDef gdef. Returns non-OK on
// error, in which case *g is left in an incomplete state.
struct GraphConstructorOptions {
  // If true, allows internal ops in the GraphDef.
  bool allow_internal_ops = false;

  // If true, the graph def is expected to have fully specified
  // devices for all nodes. A node in the resulting graph "g" has the
  // device name set accordingly.
  //
  // TODO(zhifengc): if possible, consider removing this option.
  bool expect_device_spec = false;

  // If true, perform common subexpression elimination on the graph.
  // TODO(jeff): Turn this default to true?
  bool optimizer_do_cse = false;

  // If "optimizer_do_cse" is true and "cse_consider_function" is
  // not nullptr, then only consider nodes for CSE for which
  // "cse_consider_function(node)" returns true.
  std::function<bool(const Node*)> cse_consider_function = nullptr;
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                                     const GraphDef& gdef, Graph* g);

// Make a copy of "src" into "*dest".
//
// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
// other than the implicit Source/Sink nodes.
extern void CopyGraph(const Graph& src, Graph* dest);

}  // namespace tensorflow

#endif  // TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
new file mode 100644
index 0000000000..61f4427297
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -0,0 +1,190 @@
#include "tensorflow/core/graph/graph_constructor.h"

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/public/status.h"
#include <gtest/gtest.h>

// TODO(josh11b): Test InitCostModel().
// TODO(josh11b): Test setting the "device" field of a NodeDef.
// TODO(josh11b): Test that feeding won't prune targets.
+ +namespace tensorflow { +namespace { + +class GraphConstructorTest : public ::testing::Test { + protected: + GraphConstructorTest() : g_(new Graph(OpRegistry::Global())) { + RequireDefaultOps(); + } + ~GraphConstructorTest() override {} + + void Convert(const string& gdef_ascii) { + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_)); + } + + void ExpectError(const string& gdef_ascii, const string& expected_error_re) { + Convert(gdef_ascii); + GraphConstructorOptions opts; + Status status = ConvertGraphDefToGraph(opts, gdef_, g_.get()); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(RE2::PartialMatch(status.error_message(), expected_error_re)) + << status; + } + + void ExpectOK(const string& gdef_ascii) { + Convert(gdef_ascii); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get())); + } + + Node* FindNode(const string& name) { + for (Node* n : g_->nodes()) { + if (n->name() == name) return n; + } + return nullptr; + } + + bool HasNode(const string& name) { return FindNode(name) != nullptr; } + + void ExpectNodes(const string& nodes) { + int count = 0; + std::vector<string> actual_nodes; + for (Node* n : g_->nodes()) { + if (n->IsOp()) { + count++; + actual_nodes.push_back(n->name()); + } + } + std::sort(actual_nodes.begin(), actual_nodes.end()); + + LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " "); + + std::vector<string> expected_nodes = str_util::Split(nodes, ','); + std::sort(expected_nodes.begin(), expected_nodes.end()); + for (const string& s : expected_nodes) { + Node* n = FindNode(s); + EXPECT_TRUE(n != nullptr) << s; + } + + EXPECT_TRUE(actual_nodes.size() == expected_nodes.size()) + << "\nActual: " << str_util::Join(actual_nodes, ",") + << "\nExpected: " << str_util::Join(expected_nodes, ","); + } + + bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) { + for (const Edge* e : g_->edges()) { + if (e->src()->name() == src && e->src_output() == src_out && + 
e->dst()->name() == dst && e->dst_input() == src_out) + return true; + } + return false; + } + bool HasControlEdge(const string& src, const string& dst) { + return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot); + } + + private: + GraphDef gdef_; + std::unique_ptr<Graph> g_; +}; + +REGISTER_OP("ABC"); +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("TestInt").Input("a: int32"); + +TEST_F(GraphConstructorTest, InvalidNodeName) { + ExpectError("node { name: 'a:b' op: 'ABC' }", + "Node 'a:b': Node name contains invalid characters"); + ExpectError("node { name: '_abc' op: 'ABC' }", + // Can't start with '_' + "Node '_abc': Node name contains invalid characters"); + ExpectOK("node { name: 'a-bc_' op: 'ABC' }"); +} + +TEST_F(GraphConstructorTest, InvalidSourceNodeName) { + ExpectError( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: 'W999' input: 'input' }", + + "Unknown input node.*W999"); +} + +TEST_F(GraphConstructorTest, InvalidSourceNodeIndex) { + ExpectError( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1:1', 'input:1' ] }", + + "Connecting to invalid output 1 of source node W1"); +} + +TEST_F(GraphConstructorTest, GraphWithCycle) { + ExpectError( + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }", + + "cycle"); +} + +TEST_F(GraphConstructorTest, TypeMismatch) { + ExpectError( + "node { name: 'input' op: 'TestInput' }" + "node { name: 'int' op: 'TestInt' input: [ 'input' ] }", + + "Input 0 of node int was passed float from input:0 incompatible with " + "expected int32."); +} + +TEST_F(GraphConstructorTest, EmptyGraph) 
{ ExpectOK(""); } + +TEST_F(GraphConstructorTest, SimpleModel) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"); + EXPECT_TRUE(HasNode("W1")); + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasEdge("W1", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("input", 1, "t1", 0)); +} + +TEST_F(GraphConstructorTest, SimpleModelWithControlEdges) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W1', 'input:1', '^t1' ] }"); + EXPECT_TRUE(HasNode("W1")); + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasNode("t2")); + EXPECT_TRUE(HasEdge("W1", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("input", 1, "t1", 0)); + EXPECT_TRUE(HasEdge("W1", 0, "t2", 0)); + EXPECT_TRUE(HasEdge("input", 1, "t2", 0)); + EXPECT_TRUE(HasControlEdge("W1", "input")); + EXPECT_TRUE(HasControlEdge("t1", "t2")); +} + +TEST_F(GraphConstructorTest, Error_ControlEdgeBeforeRealInput) { + ExpectError( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W1', '^t1', 'input:1' ] }", + "Node 't2': Control dependencies must come after regular dependencies"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc new file mode 100644 index 0000000000..979604f948 --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -0,0 +1,121 @@ +#include "tensorflow/core/graph/graph_def_builder.h" + +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/errors.h" + 
namespace tensorflow {

GraphDefBuilder::Options::Options(Graph* graph, Status* status)
    : graph_(graph), status_(status) {}
GraphDefBuilder::Options::~Options() {}

// Each With*() method returns a modified copy of *this, leaving the
// original Options unchanged (value-semantics builder pattern).
GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
    StringPiece name) const {
  return Options(*this).WithNameImpl(name);
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
    StringPiece device) const {
  return Options(*this).WithDeviceImpl(device);
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
    Node* control_input) const {
  return Options(*this).WithControlInputImpl(control_input);
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
    gtl::ArraySlice<Node*> control_inputs) const {
  return Options(*this).WithControlInputsImpl(control_inputs);
}
// The *Impl methods mutate *this in place; they are only ever invoked on
// the temporary copies made above.
GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
    StringPiece name) {
  name_ = name.ToString();
  return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
    StringPiece device) {
  device_ = device.ToString();
  return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
    Node* control_input) {
  control_inputs_.push_back(control_input);
  return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
    gtl::ArraySlice<Node*> control_inputs) {
  control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
                         control_inputs.end());
  return *this;
}

// Serializes the accumulated graph into *graph_def if no error occurred.
Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
  if (status_.ok()) {
    graph_.ToGraphDef(graph_def);
  }
  return status_;
}

// Round-trips through GraphDef so the result is cleaned up and validated
// (see header comment for why).
Status GraphDefBuilder::ToGraph(Graph* graph) const {
  if (status_.ok()) {
    GraphDef graph_def;
    graph_.ToGraphDef(&graph_def);
    GraphConstructorOptions opts;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph));
  }
  return status_;
}

// Uses the explicit WithName() value if set; otherwise derives a fresh,
// unique name from the op type.
string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
  if (name_.empty()) return graph_->NewName(op);
  return name_;
}

Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
  builder->ControlInputs(control_inputs_);
  if (!device_.empty()) builder->Device(device_);
  for (const auto& attr : attrs_) {
    builder->Attr(attr.first, attr.second);
  }

  Node* returned_node;
  UpdateStatus(builder->Finalize(graph_, &returned_node));
  return returned_node;
}

void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
  if (status_ == nullptr) {
    // No status to accumulate into: fail hard on error.
    TF_CHECK_OK(status);
  } else {
    status_->Update(status);
  }
}

namespace ops {

Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  return opts.FinalizeBuilder(&node_builder);
}

Node* UnaryOp(const string& op_name, NodeOut input,
              const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  node_builder.Input(input);
  return opts.FinalizeBuilder(&node_builder);
}

Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
               const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
                           opts.op_registry());
  node_builder.Input(a).Input(b);
  return opts.FinalizeBuilder(&node_builder);
}

}  // end namespace ops
}  // end namespace tensorflow
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
new file mode 100644
index 0000000000..bb72f9eea6
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -0,0 +1,181 @@
#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_

#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace tensorflow {

// Given a function like:
//   namespace ops {
//   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
//     if (opts.HaveError()) return nullptr;
//     static const string kOpName = "Identity";
//     NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
//                              opts.op_registry());
//     node_builder.Input(input);
//     return opts.FinalizeBuilder(&node_builder);
//   }
//   }  // namespace ops
//
//   // Or, alternatively:
//   namespace ops {
//   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
//     static const string kOpName = "Identity";
//     return UnaryOp(kOpName, input, opts);
//   }
//   }  // namespace ops
//
// You call it like:
//   GraphDefBuilder b;
//   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
//   Node* a = Const(7, b.opts());
//   // Note: WithName() returns a copy, opts is unchanged.
//   Node* b = Const(5, b.opts().WithName("control-input"));
//   Node* c = Identity(a, b.opts().WithControlInput(b));
//   GraphDef graph_def;
//   Status status = b.ToGraphDef(&graph_def);
//   if (!status.ok()) { /* Handle error */ }
//
// In tests you can skip the status handling via:
//   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
//   ...
//   b.ToGraphDef(&graph_def);

class GraphDefBuilder {
 public:
  // Options for adding a Node to a Graph.
  class Options {
   public:
    // Sets the Graph (that Nodes will be added to) and the status. The
    // status may be set to nullptr, in which case errors cause CHECK
    // failures. The graph and status must outlive *this.
    Options(Graph* graph, Status* status);
    ~Options();

    // Methods for setting options. These are const methods: they
    // return a copy of *this with the option set.
    Options WithName(StringPiece name) const;
    Options WithDevice(StringPiece device) const;
    Options WithControlInput(Node* control_input) const;
    Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;

    // Override the default value for an optional attr.
    template <class T>
    Options WithAttr(StringPiece attr_name, T&& value) const {
      return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
    }
    // Note: overload needed to allow {...} expressions for value.
    template <class T>
    Options WithAttr(StringPiece attr_name,
                     std::initializer_list<T> value) const {
      return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
    }

    // Methods for using options from a function that creates a Node.

    // Returns true if the status associated with *this has an error.
    // Use this to skip processing that may depend on prior results.
    bool HaveError() const { return status_ != nullptr && !status_->ok(); }

    // Given the Op type name, return a name for a node of that type.
    // Uses the value set in WithName() if that has been called. Otherwise,
    // returns a name built out of the Op type name.
    string GetNameForOp(StringPiece op) const;

    // Sets the device, adds control inputs, adds attrs, and calls Finalize().
    // If Finalize returns an error, it is saved and this function returns
    // nullptr.
    Node* FinalizeBuilder(NodeBuilder* builder) const;

    // Updates the associated status, if any, or calls TF_CHECK_OK if none.
    void UpdateStatus(const Status& status) const;

    // Accessor
    const OpRegistryInterface* op_registry() const {
      return graph_->op_registry();
    }

   private:
    Options WithNameImpl(StringPiece name);
    Options WithDeviceImpl(StringPiece device);
    Options WithControlInputImpl(Node* control_input);
    Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
    template <class T>
    Options WithAttrImpl(StringPiece name, T&& value) {
      attrs_.emplace_back(name.ToString(), AttrValue());
      SetAttrValue(std::forward<T>(value), &attrs_.back().second);
      return *this;
    }

    Graph* const graph_;
    Status* const status_;
    string name_;
    string device_;
    std::vector<Node*> control_inputs_;
    std::vector<std::pair<string, AttrValue>> attrs_;
  };

  // Start building a new graph.
  explicit GraphDefBuilder(
      const OpRegistryInterface* op_registry = OpRegistry::Global())
      : graph_(op_registry), opts_(&graph_, &status_) {}

  // For use in tests, where you want to fail immediately on error instead
  // of checking the status at the end.
  enum TestFailImmediatelyType { kFailImmediately };
  explicit GraphDefBuilder(
      TestFailImmediatelyType,
      const OpRegistryInterface* op_registry = OpRegistry::Global())
      : graph_(op_registry), opts_(&graph_, nullptr) {}

  // Gets the Options with the associated Graph and Status.
  const Options& opts() const { return opts_; }

  // Once all the nodes have been added, call this to get whether it was
  // successful, and if so fill *graph_def.
  Status ToGraphDef(GraphDef* graph_def) const;

  // Like ToGraphDef(), but converts to a Graph (using the default
  // GraphConstructorOptions).
  // TODO(josh11b): Make this faster; right now it converts
  // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
  // edges from the source and to the sink node, resolves back edges
  // by name), and makes sure the resulting graph is valid.
  Status ToGraph(Graph* graph) const;

 private:
  Graph graph_;
  Status status_;
  Options opts_;
};

namespace ops {

// A NodeOut may either be a regular input or back input. Regular
// inputs are specified via either a Node* or a Node* and an output
// index. Back inputs are specified by a node name, output index, and
// output type.
typedef NodeBuilder::NodeOut NodeOut;

// For adding an Op with no inputs to a GraphDefBuilder.
Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);

// For adding an Op with one input to a GraphDefBuilder.
Node* UnaryOp(const string& op_name, NodeOut input,
              const GraphDefBuilder::Options& opts);

// For adding an Op with two inputs to a GraphDefBuilder.
Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
               const GraphDefBuilder::Options& opts);

}  // namespace ops
}  // namespace tensorflow

#endif  // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
new file mode 100644
index 0000000000..1571790e59
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -0,0 +1,1050 @@
#include "tensorflow/core/graph/graph_partition.h"

#include <deque>
#include <unordered_map>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {

// Key identifying a recv node so that multiple consumers of the same
// (src node, src slot) in the same destination subgraph can share one recv.
struct DupRecvKey {
  int src_node_id;           // Edge's src node id
  int src_output_slot;       // Edge's src node output slot
  GraphDef* dst_graph;       // Edge's dst node is in this subgraph
  bool recv_output_on_host;  // The output of recv is on host
};
+struct DupRecvKeyHash { + size_t operator()(const DupRecvKey& k) const { + size_t h = Hash64(reinterpret_cast<const char*>(&k.src_node_id), + sizeof(k.src_node_id), k.src_output_slot); + h = Hash64(reinterpret_cast<const char*>(&k.dst_graph), sizeof(k.dst_graph), + h); + h = Hash64(reinterpret_cast<const char*>(&k.recv_output_on_host), + sizeof(k.recv_output_on_host), h); + return h; + } +}; + +struct DupRecvKeyEq { + bool operator()(const DupRecvKey& x, const DupRecvKey& y) const { + return (x.src_node_id == y.src_node_id) && + (x.src_output_slot == y.src_output_slot) && + (x.dst_graph == y.dst_graph) && + (x.recv_output_on_host == y.recv_output_on_host); + } +}; + +// struct used to store the recvs, so that start times can be properly updated +struct RecvInfo { + NodeDef* recv; + NodeDef* real_recv; + int64 start_time; +}; + +typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq> + DupRecvTable; + +// Control flow info for a graph node. +struct ControlFlowInfo { + const Node* frame = nullptr; // frame of a node + const Node* parent_frame = nullptr; // parent frame of a node + string frame_name; // frame name of a node + int iter_level = -1; // level of a node +}; + +struct PairIntHash { + public: + std::size_t operator()(const std::pair<int, int>& x) const { + return std::hash<int>()(x.first) ^ std::hash<int>()(x.second); + } +}; +// A map used to store memory types for the inputs/outputs of every node. +// The key is a pair of ints consisting of a node id and input/output index. +typedef std::unordered_map<std::pair<int, int>, MemoryType, PairIntHash> + MemoryTypeMap; + +// We collect the following information about the graph before performing +// graph partitioning. 
+struct GraphInfo { + std::vector<DeviceType> device_types; + MemoryTypeMap input_types; + MemoryTypeMap output_types; + std::vector<ControlFlowInfo> cf_info; +}; + +DataType EdgeType(const Edge* e) { + if (e->IsControlEdge()) { + return DT_FLOAT; + } else { + return e->dst()->input_type(e->dst_input()); + } +} + +// Return true iff we need to add a same device send/recv for 'edge'. +bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) { + if (edge->IsControlEdge()) { + return false; + } + + Node* src = edge->src(); + Node* dst = edge->dst(); + if (src->assigned_device_name() == dst->assigned_device_name()) { + int src_port = edge->src_output(); + int dst_port = edge->dst_input(); + if (info.device_types[src->id()] == DEVICE_GPU) { + auto src_it = info.output_types.find({src->id(), src_port}); + DCHECK(src_it != info.output_types.end()); + auto dst_it = info.input_types.find({dst->id(), dst_port}); + DCHECK(dst_it != info.input_types.end()); + return src_it->second != dst_it->second; + } + } + return false; +} + +// Return true iff (dst, dst_input) is specified on host memory. +bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) { + Node* dst = edge->dst(); + int dst_port = edge->dst_input(); + if (info.device_types[dst->id()] == DEVICE_GPU) { + if (edge->IsControlEdge()) return false; + auto dst_it = info.input_types.find({dst->id(), dst_port}); + DCHECK(dst_it != info.input_types.end()); + return dst_it->second == HOST_MEMORY; + } + return true; +} + +// Add an input to dst that comes from the "src_slot" output of the +// node named by "src_name". +void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { + if (src_slot == Graph::kControlSlot) { + dst->add_input(strings::StrCat("^", src_name)); + } else if (src_slot == 0) { + dst->add_input(src_name.data(), src_name.size()); + } else { + dst->add_input(strings::StrCat(src_name, ":", src_slot)); + } +} + +// Add a control edge from each input to each recv. 
// Adds a "^input" control dependency on every given recv for every given
// input name.
void AddReadControl(const std::vector<NodeDef*>& recvs,
                    const std::vector<string>& inputs) {
  for (NodeDef* recv : recvs) {
    for (const string& input : inputs) {
      recv->add_input(strings::StrCat("^", input));
    }
  }
}

// Sets the attrs that a matching _Send/_Recv pair must agree on:
// rendezvous key pieces (tensor_name, devices, incarnation) and
// client_terminated.
void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
                      NodeDefBuilder* builder) {
  builder->Attr("tensor_name",
                strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
  builder->Attr("send_device", edge->src()->assigned_device_name());
  builder->Attr("send_device_incarnation",
                static_cast<int64>(
                    opts.get_incarnation(edge->src()->assigned_device_name())));
  builder->Attr("recv_device", edge->dst()->assigned_device_name());
  builder->Attr("client_terminated", false);
}

// Adds a _Send (or _HostSend) node for "edge" to "gdef", optionally
// preceded by a Cast to opts.should_cast(edge)'s type.  On error sets
// *status and returns nullptr.
NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge,
                 NodeDefBuilder::NodeOut send_from, int64 start_time,
                 Status* status) {
  const DataType dtype = send_from.data_type;
  const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
  const Node* src = edge->src();
  const int src_port = edge->src_output();

  // host_memory = true iff we need to use HostSend/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto src_it = g_info.output_types.find({src->id(), src_port});
    DCHECK(src_it != g_info.output_types.end());
    host_memory = (src_it->second == HOST_MEMORY);
  }

  // Add a cast node that casts dtype to cast_dtype.
  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
    cast_builder.Device(src->assigned_device_name()).Input(send_from);
    if (opts.scheduling_for_recvs) {
      cast_builder.Attr("_start_time", start_time);
    }
    cast_builder.Attr("DstT", cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;

    // Connect the Send op to the cast.
    send_from.Reset(cast->name(), 0, cast_dtype);
  }

  // Add the send node.
  const string send_op = (host_memory) ? "_HostSend" : "_Send";
  NodeDefBuilder send_builder(opts.new_name(src->name()), send_op);
  SetSendRecvAttrs(opts, edge, &send_builder);
  send_builder.Device(src->assigned_device_name()).Input(send_from);
  if (opts.scheduling_for_recvs) {
    send_builder.Attr("_start_time", start_time);
  }
  NodeDef* send = gdef->add_node();
  *status = send_builder.Finalize(send);
  return send;
}

// Adds a _Recv (or _HostRecv) node for "edge" to "gdef".  Returns the
// node consumers should connect to: a trailing Cast when the wire type
// differs from the edge type, an Identity for control edges (so the
// control dependency has a node to hang off), otherwise the recv itself.
// *real_recv always receives the underlying recv node.
NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
  }

  // Add the recv node.
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op);
  SetSendRecvAttrs(opts, edge, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity");
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}

// Adds an empty float Const on the src's device; used as a stand-in
// value node during partitioning.
NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
                       const Edge* edge, Status* status) {
  const Node* src = edge->src();
  Tensor tensor(DT_FLOAT, TensorShape({0}));
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
                .Device(src->assigned_device_name())
                .Attr("dtype", DT_FLOAT)
                .Attr("value", tensor)
                .Finalize(result);
  return result;
}

// A dummy node for scheduling.
+NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, + const string& assigned_device_name, int64 epoch, + int64 starttime, Status* status) { + NodeDef* result = gdef->add_node(); + *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)), + "ControlTrigger") + .Device(assigned_device_name) + .Attr("_start_time", starttime) + .Finalize(result); + return result; +} + +// Assign to each node the name of the frame and the level it belongs to. +// We check the well-formedness of the graph: All inputs to a node must +// come from the same frame and have the same "static" iteration level. +// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level +// 0. This essentially means there can't be multiple serial Nexts in +// an iteration, which all sane front-ends should satisfy. +Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) { + info->clear(); + info->resize(g->num_node_ids()); + + Node* src_node = g->source_node(); + ControlFlowInfo& src_info = (*info)[src_node->id()]; + src_info.frame = src_node; + src_info.parent_frame = src_node; + src_info.iter_level = 0; + + string frame_name; + std::deque<const Node*> ready; + ready.push_back(src_node); + while (!ready.empty()) { + const Node* curr_node = ready.front(); + ready.pop_front(); + const ControlFlowInfo& curr_info = (*info)[curr_node->id()]; + const Node* frame = curr_info.frame; + const Node* parent = curr_info.parent_frame; + frame_name = curr_info.frame_name; + int iter_level = curr_info.iter_level; + + if (IsExit(curr_node)) { + const ControlFlowInfo& parent_info = (*info)[parent->id()]; + frame = parent_info.frame; + parent = parent_info.parent_frame; + frame_name = parent_info.frame_name; + iter_level = parent_info.iter_level; + } + + for (const Edge* out_edge : curr_node->out_edges()) { + const Node* out = out_edge->dst(); + int out_id = out->id(); + ControlFlowInfo* out_info = &(*info)[out_id]; + const Node* out_parent = 
out_info->parent_frame; + bool is_visited = (out_info->iter_level != -1); + + // Skip Sink/Source nodes. + if (!out->IsOp()) continue; + + // Add to ready queue if not seen. + if (!is_visited) { + ready.push_back(out); + } + + // Process the node 'out'. + if (IsEnter(out)) { + if (is_visited) { + const string& parent_name = (*info)[out_parent->id()].frame_name; + if (parent_name != frame_name || iter_level != out_info->iter_level) { + return errors::InvalidArgument( + "All inputs to Enter must be from the same frame and level."); + } + } else { + out_info->frame = out; + out_info->parent_frame = frame; + TF_RETURN_IF_ERROR( + GetNodeAttr(out->def(), "frame_name", &out_info->frame_name)); + if (out_info->frame_name.empty()) { + return errors::InvalidArgument( + "Enter must have a non-empty frame name."); + } + out_info->iter_level = 0; + } + } else if (IsNextIteration(out)) { + if (is_visited) { + if (out_info->frame_name != frame_name || + out_info->iter_level != (iter_level + 1)) { + return errors::InvalidArgument( + "All inputs to NextIteration must be from the same frame " + "and level."); + } + } else { + out_info->frame = frame; + out_info->parent_frame = parent; + out_info->frame_name = frame_name; + out_info->iter_level = iter_level + 1; + } + } else { + if (is_visited) { + if (out_info->frame_name != frame_name) { + return errors::InvalidArgument( + "All inputs to a node must be from the same frame."); + } + } else { + out_info->frame = frame; + out_info->parent_frame = parent; + out_info->frame_name = frame_name; + out_info->iter_level = iter_level; + } + } + } + } + + return Status::OK(); +} + +string ControlLoopName(const string& name) { + return strings::StrCat("_cloop", name); +} + +bool IsControlLoop(const Node* node) { + const string& name = node->def().name(); + return StringPiece(name).starts_with("_cloop"); +} + +// An enter node for control flow. 
+Node* AddControlEnter(Graph* g, const string& node_name, + const string& device_name, const string& frame_name, + const int parallel_iterations, Status* status) { + NodeBuilder node_builder(node_name, "Enter", g->op_registry()); + node_builder.Input({"dummy", 0, DT_FLOAT}); + node_builder.Attr("frame_name", frame_name); + node_builder.Attr("parallel_iterations", parallel_iterations); + Node* res_node; + *status = node_builder.Finalize(g, &res_node); + if (!status->ok()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A merge node for control flow. +Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, + const string& node_name, const string& device_name, + Status* status) { + NodeBuilder node_builder(node_name, "Merge", g->op_registry()); + node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}}); + Node* res_node; + *status = node_builder.Finalize(g, &res_node); + if (!status->ok()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A switch node for control flow. +Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, + const string& device_name, + const GraphDefBuilder::Options& bopts) { + Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts); + if (bopts.HaveError()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A next_iteration node for control flow. 
+Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name, + const GraphDefBuilder::Options& bopts) { + Node* res_node = ops::UnaryOp("NextIteration", input, bopts); + if (bopts.HaveError()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +Node* EmptyConst(const GraphDefBuilder::Options& options) { + if (options.HaveError()) return nullptr; + NodeBuilder node_builder(options.GetNameForOp("Const"), "Const", + options.op_registry()); + const DataType dt = DataTypeToEnum<float>::v(); + TensorProto proto; + proto.set_dtype(dt); + TensorShape empty_shape({0}); + empty_shape.AsProto(proto.mutable_tensor_shape()); + node_builder.Attr("dtype", dt).Attr("value", proto); + return options.FinalizeBuilder(&node_builder); +} + +// A dummy const node for control flow. +Node* AddControlConst(const string& device_name, + const GraphDefBuilder::Options& bopts) { + Node* res_node = EmptyConst(bopts); + if (bopts.HaveError()) return nullptr; + res_node->set_assigned_device_name(device_name); + return res_node; +} + +// A synthetic loop, made up of dummy nodes. It performs control-flow actions +// on behalf of a leader on a different device. +struct ControlLoop { + Node* enter = nullptr; + Node* merge = nullptr; + Node* switch_node = nullptr; +}; + +// Add the control flow info of a new node added during partitioning. +// The new node has the same control flow info as edge->src(). +void AddControlFlowInfo(const Node* node, const Node* src, + std::vector<ControlFlowInfo>* cf_info) { + int id = node->id(); + if (static_cast<size_t>(id) >= cf_info->size()) { + cf_info->resize(id + 1); + } + const ControlFlowInfo& src_info = (*cf_info)[src->id()]; + ControlFlowInfo* info = &(*cf_info)[id]; + info->frame = src_info.frame; + info->parent_frame = src_info.parent_frame; + info->frame_name = src_info.frame_name; + info->iter_level = src_info.iter_level; +} + +// Constructs a control loop. 
Returns a struct containing the newly created +// enter, merge, and switch nodes. The enter and merge nodes are used in the +// recursive construction of control loops for nested frames (loops). The +// switch node will be connected to the LoopCond node. The merge node will +// be connected to all the recvs of the same frame by control edges when +// the actual partitioning happens. +Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, + const Edge* edge, Node* loop_cond, + std::vector<ControlFlowInfo>* cf_info, + ControlLoop* loop) { + Status status; + GraphDefBuilder::Options bopts(g, &status); + const ControlFlowInfo& src_info = (*cf_info)[src->id()]; + const string& device_name = edge->dst()->assigned_device_name(); + const string& frame_name = src_info.frame_name; + int parallel_iterations; + status = GetNodeAttr(src_info.frame->def(), "parallel_iterations", + ¶llel_iterations); + if (!status.ok()) return status; + + // The names of the nodes to be added. + const string& enter_name = + ControlLoopName(opts.new_name(edge->dst()->name())); + const string& merge_name = + ControlLoopName(opts.new_name(edge->dst()->name())); + const string& switch_name = + ControlLoopName(opts.new_name(edge->dst()->name())); + const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name())); + + // Add the nodes to the graph g. 
+ Node* enter = AddControlEnter(g, enter_name, device_name, frame_name, + parallel_iterations, &status); + if (!status.ok()) return status; + Node* merge = AddControlMerge(enter_name, next_name, g, merge_name, + device_name, &status); + if (!status.ok()) return status; + Node* switch_node = AddControlSwitch(merge, loop_cond, device_name, + bopts.WithName(switch_name)); + if (!status.ok()) return status; + Node* next = + AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name)); + if (!status.ok()) return status; + + // Add control flow info for these new nodes: + AddControlFlowInfo(enter, src, cf_info); + AddControlFlowInfo(merge, src, cf_info); + AddControlFlowInfo(switch_node, src, cf_info); + AddControlFlowInfo(next, src, cf_info); + + // Add input edges for the newly created merge node: + g->AddEdge(enter, 0, merge, 0); + g->AddEdge(next, 0, merge, 1); + + loop->enter = enter; + loop->merge = merge; + loop->switch_node = switch_node; + return Status::OK(); +} + +// Build memory and device type info for every node in the graph. +// TODO(yuanbyu): It might be simpler if we convert MemoryType to +// DeviceType for the inputs/outputs of each node. +Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { + Status status; + MemoryTypeVector input_memory_types; + MemoryTypeVector output_memory_types; + + info->device_types.resize(g.num_node_ids(), DEVICE_CPU); + for (const Node* node : g.nodes()) { + if (!node->IsOp()) continue; // Skip Sink/Source nodes. 
+ + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(), + &parsed)) { + return errors::Internal("Malformed assigned device '", + node->assigned_device_name(), "'"); + } + + input_memory_types.clear(); + input_memory_types.resize(node->num_inputs()); + output_memory_types.clear(); + output_memory_types.resize(node->num_outputs()); + status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type), + node->def(), &input_memory_types, + &output_memory_types); + if (!status.ok()) return status; + + int node_id = node->id(); + info->device_types[node_id] = DeviceType(parsed.type); + for (size_t i = 0; i < input_memory_types.size(); ++i) { + info->input_types[{node_id, i}] = input_memory_types[i]; + } + for (size_t i = 0; i < output_memory_types.size(); ++i) { + info->output_types[{node_id, i}] = output_memory_types[i]; + } + } + return status; +} + +// Each participating device needs to decide a) if there is a next iteration, +// and b) if the loop terminates. We take the approach to encode this control +// flow logic in the dataflow graph. There are at least two possible encodings. +// In a completely decentralized encoding, the participants communicate peer +// to peer. The other encoding uses a frame leader (the participant who owns +// the pivot termination predicate) to broadcast the termination condition to +// all the participants. For now we take the latter because it is simpler. +// +// TODO(yuanbyu): The correctness of this construction is rather subtle. I got +// it wrong many times so it would be nice to write a proof to be sure. +Status AddControlFlow(const PartitionOptions& opts, Graph* g, + GraphInfo* g_info) { + Status status; + GraphDefBuilder::Options bopts(g, &status); + std::vector<ControlFlowInfo>& cf_info = g_info->cf_info; + + // Build the control flow info for every node. 
+  status = BuildControlFlowInfo(g, &cf_info);
+  if (!status.ok()) return status;
+
+  // The map from frames to their LoopCond nodes.
+  std::unordered_map<string, Node*> frame_cond_map;
+  int num_node_ids = g->num_node_ids();
+  for (int i = 0; i < num_node_ids; ++i) {
+    Node* node = g->FindNodeId(i);
+    if (node == nullptr) continue;
+
+    if (IsLoopCond(node)) {
+      const string& frame_name = cf_info[node->id()].frame_name;
+      DCHECK(!frame_name.empty());
+      frame_cond_map[frame_name] = node;
+    }
+  }
+
+  // Add all control loops for cross-device frames.
+  // A control loop is added only when there is a cross-device edge in a
+  // non-root frame. Nothing is added if there are no loops. We also don't
+  // add anything for a frame that is completely local to a device. For
+  // nested loops, we stack the control loops together by connecting
+  // the merge of the outer loop to the enter of the inner loop.
+  //
+  // A map from <frame_name, device_name> to ControlLoop.
+  std::unordered_map<string, ControlLoop> control_loops;
+  int num_edge_ids = g->num_edge_ids();
+  for (int i = 0; i < num_edge_ids; ++i) {
+    const Edge* edge = g->FindEdgeId(i);
+    if (edge == nullptr) continue;
+
+    const Node* src = edge->src();
+    const Node* dst = edge->dst();
+    // Skip Sink/Source nodes.
+    if (!src->IsOp() || !dst->IsOp()) continue;
+
+    const string& src_device = src->assigned_device_name();
+    const string& dst_device = dst->assigned_device_name();
+    // Skip local edges.
+    if (src_device == dst_device) continue;
+
+    const string& src_frame = cf_info[src->id()].frame_name;
+    const string& dst_frame = cf_info[dst->id()].frame_name;
+    // Skip if src and dst are not in the same frame.
+    if (src_frame.empty() || src_frame != dst_frame) {
+      continue;
+    }
+
+    // Add the control loop. Start by adding the control loop for the
+    // current frame if needed, and recursively adding the control loop
+    // for its outer frame when nested.
+ ControlLoop child_loop; + while (true) { + const string& curr_frame = cf_info[src->id()].frame_name; + if (curr_frame.empty()) { + // We have reached the root frame. + if (child_loop.merge != nullptr) { + const string& node_name = opts.new_name(edge->dst()->name()); + const string& device_name = edge->dst()->assigned_device_name(); + Node* const_node = + AddControlConst(device_name, bopts.WithName(node_name)); + if (!status.ok()) return status; + AddControlFlowInfo(const_node, src, &cf_info); + g->AddEdge(const_node, 0, child_loop.enter, 0); + } + break; + } + + const string& cl_key = strings::StrCat(curr_frame, "$$", dst_device); + auto it = control_loops.find(cl_key); + if (it != control_loops.end()) { + if (child_loop.enter != nullptr) { + g->AddEdge(it->second.merge, 0, child_loop.enter, 0); + } + break; + } + + // Get the frame's LoopCond. + auto cond_it = frame_cond_map.find(curr_frame); + if (cond_it == frame_cond_map.end()) { + return errors::InvalidArgument( + "A cross-device loop must have a pivot predicate: ", curr_frame); + } + Node* loop_cond = cond_it->second; + + // Add the control loop. + ControlLoop curr_loop; + status = + AddControlLoop(opts, g, src, edge, loop_cond, &cf_info, &curr_loop); + if (!status.ok()) return status; + control_loops[cl_key] = curr_loop; + + if (child_loop.enter != nullptr) { + // Connect the merge of the outer loop to the enter of the inner. + g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0); + } + src = cf_info[src->id()].parent_frame; + child_loop = curr_loop; + } + } + + // For a cross-device edge, on the dst device, add a control edge + // from the merge node of the control loop to dst. If a send/recv is + // introduced for this edge in future partitioning, we delete this + // control edge and add a new control edge from the merge to the recv. 
+ num_edge_ids = g->num_edge_ids(); + for (int i = 0; i < num_edge_ids; ++i) { + const Edge* edge = g->FindEdgeId(i); + if (edge == nullptr) continue; + + const Node* src = edge->src(); + Node* dst = edge->dst(); + // Skip Sink/Source nodes. + if (!src->IsOp() || !dst->IsOp()) continue; + + const string& src_device = src->assigned_device_name(); + const string& dst_device = dst->assigned_device_name(); + if (src_device != dst_device) { + const string& src_frame = cf_info[src->id()].frame_name; + const string& dst_frame = cf_info[dst->id()].frame_name; + if (!src_frame.empty() && src_frame == dst_frame) { + const string& cl_key = strings::StrCat(dst_frame, "$$", dst_device); + ControlLoop loop = control_loops[cl_key]; + DCHECK(loop.enter != nullptr); + g->AddControlEdge(loop.merge, dst); + } + } + } + return Status::OK(); +} + +} // end namespace + +Status AddControlEdges(const PartitionOptions& opts, + std::unordered_map<string, GraphDef>* partitions) { + Status status; + // TODO(yuanbyu): Very naive for now. To be improved. + const int num_epochs = 100; + const int prefetch = 6; + + typedef std::pair<const NodeDef*, int64> NodeStartTime; + for (auto& part : *partitions) { + GraphDef* gdef = &part.second; + + std::vector<NodeStartTime> start_times; + start_times.resize(gdef->node_size()); + for (int n = 0; n < gdef->node_size(); ++n) { + const NodeDef& ndef = gdef->node(n); + int64 start_time; + status = GetNodeAttr(ndef, "_start_time", &start_time); + if (!status.ok()) { + return status; + } + start_times[n] = std::make_pair(&ndef, start_time); + } + + // Sort the nodes based on their start times. + std::sort( + start_times.begin(), start_times.end(), + [](NodeStartTime x, NodeStartTime y) { return x.second < y.second; }); + + // Add a dummy node for every epoch, and add a control edge from the + // "last" node in the preceding epoch to the dummy node. 
+ string device_name = gdef->node(0).device(); + int64 makespan = start_times.back().second; + int64 resolution = (makespan / num_epochs) + 1; + + int i = 0; + int j = 0; + std::vector<NodeDef*> dummys; + while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) { + if (i * resolution > start_times[j].second) { + j++; + } else { + NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i, + i * resolution, &status); + if (!status.ok()) { + return status; + } + dummys.push_back(dummy); + if (j > 0) { + string src_name = start_times[j - 1].first->name(); + AddInput(dummy, src_name, Graph::kControlSlot); + } + i++; + } + } + + // Finally, add the control edges to recvs. + for (int n = 0; n < gdef->node_size(); ++n) { + NodeDef* ndef = gdef->mutable_node(n); + if (ndef->op() == "_Recv") { + int64 start_time; + status = GetNodeAttr(*ndef, "_start_time", &start_time); + if (!status.ok()) { + return status; + } + int recv_epoch = start_time / resolution; + if (recv_epoch >= prefetch) { + NodeDef* dummy = dummys[recv_epoch - prefetch]; + AddInput(ndef, dummy->name(), Graph::kControlSlot); + } + } + } + } + return Status::OK(); +} + +Status Partition(const PartitionOptions& opts, Graph* g, + std::unordered_map<string, GraphDef>* partitions) { + Status status; + partitions->clear(); + + GraphInfo g_info; + if (!opts.control_flow_added) { + // Add the "code" for distributed execution of control flow. Code is + // added only for the frames that are placed on multiple devices. The + // new graph is an equivalent transformation of the original graph and + // has the property that it can be subsequently partitioned arbitrarily + // (down to the level of individual device) for distributed execution. + status = AddControlFlow(opts, g, &g_info); + if (!status.ok()) return status; + } + // At this point, all the graph mutations have been done. Build memory + // and device type info for every node and edge in the graph. 
+ status = BuildMemoryDeviceInfo(*g, &g_info); + if (!status.ok()) return status; + + string dstp; + std::vector<const Edge*> inputs; + DupRecvTable dup_recv(3); + // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref + // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref + // edge to dst. We will add a control edge for every pair in + // (ref_recvs x ref_control_inputs). + std::vector<NodeDef*> ref_recvs; + std::vector<string> ref_control_inputs; + + int32 num_data = 0; + int32 num_control = 0; + for (const Node* dst : g->nodes()) { + if (!dst->IsOp()) continue; // Skip Sink/Source nodes. + + dstp = opts.node_to_loc(dst); + GraphDef* dst_graph = &(*partitions)[dstp]; + NodeDef* dst_def = dst_graph->add_node(); + *dst_def = dst->def(); + dst_def->set_device(dst->assigned_device_name()); + dst_def->clear_input(); // Inputs are filled below + if (opts.need_to_record_start_times) { + int64 start_time = opts.start_times[dst->id()].value(); + AddNodeAttr("_start_time", start_time, dst_def); + } + + // Arrange the incoming edges to dst so that input[i] holds the + // input flowing into slot numbered i. Trailing entries in input[] + // hold control edges. + inputs.clear(); + inputs.resize(dst->num_inputs(), nullptr); + ref_recvs.clear(); + ref_control_inputs.clear(); + const Edge* control_flow_edge = nullptr; + for (const Edge* edge : dst->in_edges()) { + if (edge->IsControlEdge()) { + if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { + // This is one of the control edges added for control flow. There + // can be multiple such edges as the dest node may have multiple + // remote inputs. We will just take one and ignore the others. + control_flow_edge = edge; + } else { + inputs.push_back(edge); + } + } else { + DCHECK(inputs[edge->dst_input()] == nullptr); + inputs[edge->dst_input()] = edge; + } + } + + // Process in order so that all data edges are added as inputs to + // dst in Edge::dst_input() order. 
+ bool recv_added = false; + for (const Edge* edge : inputs) { + const Node* src = edge->src(); + if (!src->IsOp()) continue; // Skip Sink/Source nodes. + + GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)]; + if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) { + // Same partition and compatible memory types: + AddInput(dst_def, src->name(), edge->src_output()); + if (edge->IsControlEdge() || + !IsRefType(src->output_type(edge->src_output()))) { + ref_control_inputs.push_back(src->name()); + } + continue; + } + + int64 send_start_time = 0; + int64 recv_start_time = 0; + if (opts.scheduling_for_recvs) { + if (opts.need_to_record_start_times) { + send_start_time = opts.start_times[src->id()].value(); + recv_start_time = opts.start_times[dst->id()].value(); + } else { + status = GetNodeAttr(src->def(), "_start_time", &send_start_time); + if (!status.ok()) { + return status; + } + status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time); + if (!status.ok()) { + return status; + } + } + } + + // Check whether there is already a send/recv pair transferring + // the same tensor/control from the src to dst partition. + const bool on_host = IsDstInputOnHost(edge, g_info); + DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host}; + auto iter = dup_recv.find(key); + if (iter != dup_recv.end()) { + // We found one. Reuse the data/control transferred already. + const string& recv_node_name = iter->second.recv->name(); + if (edge->IsControlEdge()) { + AddInput(dst_def, recv_node_name, Graph::kControlSlot); + } else { + AddInput(dst_def, recv_node_name, 0); + } + // We want the start_time for the recv to be the smallest of the start + // times of it's consumers. 
So we update this whenever we use a recv, + // and write it out to the attribute at the end of the subroutine + if (iter->second.start_time > recv_start_time) { + iter->second.start_time = recv_start_time; + } + continue; + } + + NodeDefBuilder::NodeOut send_from; + if (edge->IsControlEdge()) { + // Insert a dummy const node that will generate a tiny + // data element to be sent from send to recv. + VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "[" + << src->name() << "] -> " << dst->assigned_device_name() << "[" + << dst->name() << "]"; + NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status); + if (!status.ok()) return status; + // Set the start time for this dummy node. + if (opts.scheduling_for_recvs) { + AddNodeAttr("_start_time", send_start_time, dummy); + } + AddInput(dummy, src->name(), Graph::kControlSlot); + send_from.Reset(dummy->name(), 0, DT_FLOAT); + } else { + send_from.Reset(src->name(), edge->src_output(), EdgeType(edge)); + } + + // Need to split edge by placing matching send/recv nodes on + // the src/dst sides of the edge. + NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from, + send_start_time, &status); + if (!status.ok()) return status; + + NodeDef* real_recv = nullptr; + NodeDef* recv = + AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status); + if (!status.ok()) return status; + + // Fix up the control flow edge. Redirect it to the recv. + // NOTE(yuanbyu): 'real_recv' must be the real recv node. + recv_added = true; + if (control_flow_edge != nullptr) { + AddInput(real_recv, control_flow_edge->src()->name(), + Graph::kControlSlot); + } + + // For same device send/recv, add a control edge from send to recv. + // This prevents the asynchronous recv kernel from being scheduled + // immediately. 
+      if (src_graph == dst_graph) {
+        AddInput(real_recv, send->name(), Graph::kControlSlot);
+      }
+
+      if (!edge->IsControlEdge() &&
+          IsRefType(src->output_type(edge->src_output()))) {
+        // If src is of ref type and the edge is not a control edge, dst has
+        // read semantics and therefore we must control the recv.
+        ref_recvs.push_back(real_recv);
+      } else {
+        // Memorize the send/recv pair, only if this is not a "ref" edge.
+        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
+        // for now we don't do it.
+        dup_recv[key] = {recv, real_recv, recv_start_time};
+        ref_control_inputs.push_back(recv->name());
+      }
+
+      if (edge->IsControlEdge()) {
+        ++num_control;
+        AddInput(dst_def, recv->name(), Graph::kControlSlot);
+      } else {
+        ++num_data;
+        AddInput(dst_def, recv->name(), 0);
+      }
+    }
+
+    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
+    // NOTE(yuanbyu): Adding these control edges should not introduce
+    // deadlocks. 'dst' has implicit "read" nodes that, when we split
+    // across devices, are made explicit; Retargeting the dependencies
+    // to 'dst' to those nodes would not introduce cycles if there isn't
+    // one before the transformation.
+    // NOTE(yuanbyu): This may impact performance because it defers the
+    // execution of recvs until all the other inputs become available.
+    AddReadControl(ref_recvs, ref_control_inputs);
+
+    // Add back this control edge for control flow if not used.
+    if (!recv_added && (control_flow_edge != nullptr)) {
+      AddInput(dst_def, control_flow_edge->src()->name(), Graph::kControlSlot);
+    }
+  }
+
+  // Set the start times for recvs at the very end.
+  if (opts.scheduling_for_recvs) {
+    for (auto& it : dup_recv) {
+      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
+      if (it.second.real_recv != it.second.recv) {
+        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
+      }
+    }
+  }
+
+  VLOG(1) << "Added send/recv: controls=" << num_control
+          << ", data=" << num_data;
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
new file mode 100644
index 0000000000..eb88ff71b1
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/costmodel.h"
+
+namespace tensorflow {
+
+struct PartitionOptions {
+  // A function that returns a location for the execution of a given
+  // Node.
+  typedef std::function<string(const Node*)> NodeToLocFunc;
+  NodeToLocFunc node_to_loc = nullptr;
+
+  // A function that returns a unique graph node name with the given
+  // prefix.
+  typedef std::function<string(const string&)> NewNameFunc;
+  NewNameFunc new_name = nullptr;
+
+  // A function that returns the incarnation of a device given the
+  // device's fullname. If not found, GetIncarnationFunc should return
+  // kIllegalIncarnation.
+  static const uint64 kIllegalIncarnation = 0;
+  typedef std::function<uint64(const string&)> GetIncarnationFunc;
+  GetIncarnationFunc get_incarnation = nullptr;
+
+  // True if all the control flow "code" has already been added. The
+  // control flow code needs to be added when we still have the entire
+  // graph before any partitioning. So this flag should be false for
+  // the first partitioning but true for all subsequent partitioning.
+ // + // TODO(yuanbyu): We could also make the addition of the control + // flow code incremental based on 'node_to_loc'. This makes the + // communication a broadcast tree, which could be more efficient when + // the number of participating devices is large. + bool control_flow_added; + + // A function that returns the data type into which the tensor + // should be cast before sent over the wire. + typedef std::function<DataType(const Edge*)> ShouldCastFunc; + ShouldCastFunc should_cast = nullptr; + + // Schedule the execution of the recvs based on their start times + // computed by some scheduling algorithm. The recvs are divided into + // epochs based on their start times. A recv is enabled only when + // execution reaches its epoch - N for some predefined N. + bool scheduling_for_recvs = false; + // The start time for each node in the graph computed by some scheduling + // algorithm. If 'need_to_record_start_times' is true, we record them + // in the graph as a node attribute. + bool need_to_record_start_times = false; + std::vector<Microseconds> start_times; +}; + +// Partition "input" graph into a set of graphs, one per location. +// The location for node n is derived by calling opts.node_to_loc(n). +// New nodes added by Partition use "opts.new_name(old_name)" to +// generate node names. +// +// Stores the partitions in *partitions. +Status Partition(const PartitionOptions& opts, Graph* input, + std::unordered_map<string, GraphDef>* partitions); + +// Add control edges to the partitions to control the ordering +// and timing of the recv nodes based on the start times calculated +// using some scheduling algorithm. 
+Status AddControlEdges(const PartitionOptions& opts, + std::unordered_map<string, GraphDef>* partitions); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_ diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc new file mode 100644 index 0000000000..d912c94025 --- /dev/null +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -0,0 +1,316 @@ +#include "tensorflow/core/graph/graph_partition.h" + +#include <unordered_map> + +#include <gtest/gtest.h> +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/random_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/equal_graph_def.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace { + +const char gpu_device[] = "/job:a/replica:0/task:0/gpu:0"; + +string SplitByDevice(const Node* node) { return node->assigned_device_name(); } + +string DeviceName(const Node* node) { + char first = node->name()[0]; + if (first == 'G') { + return gpu_device; + } else { + const string cpu_prefix = "/job:a/replica:0/task:0/cpu:"; + int index = first - 'A'; + return strings::StrCat(cpu_prefix, index); + } +} + +void Partition(const GraphDef& graph_def, + std::unordered_map<string, GraphDef>* partitions) { + Graph g(OpRegistry::Global()); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g)); + + // Assigns devices to each node. Uses 1st letter of the node name as + // the device index. 
+ for (Node* node : g.nodes()) { + node->set_assigned_device_name(DeviceName(node)); + } + + PartitionOptions popts; + popts.node_to_loc = SplitByDevice; + popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); }; + popts.get_incarnation = [](const string& name) { + return (name[0] - 'A') + 100; + }; + popts.control_flow_added = false; + Status s = Partition(popts, &g, partitions); + CHECK(s.ok()) << s; +} + +void CheckLoopConstruction(const GraphDef& graph_def) { + std::unordered_map<string, GraphDef> partitions; + Partition(graph_def, &partitions); + GraphConstructorOptions opts; + for (const auto& kv : partitions) { + const GraphDef& gdef = kv.second; + bool has_control_enter = false; + bool has_control_merge = false; + bool has_control_switch = false; + bool has_control_next = false; + for (const NodeDef& ndef : gdef.node()) { + // _recvs must have a control input + if (ndef.op() == "_Recv") { + bool has_control = false; + for (const string& input_name : ndef.input()) { + if (StringPiece(input_name).starts_with("^")) { + has_control = true; + break; + } + } + EXPECT_TRUE(has_control); + } + // Must have a control loop + if (StringPiece(ndef.name()).starts_with("_cloop")) { + if (ndef.op() == "Enter") { + has_control_enter = true; + } + if (ndef.op() == "Merge") { + has_control_merge = true; + } + if (ndef.op() == "Switch") { + has_control_switch = true; + } + if (ndef.op() == "NextIteration") { + has_control_next = true; + } + } + } + EXPECT_TRUE(has_control_enter); + EXPECT_TRUE(has_control_merge); + EXPECT_TRUE(has_control_switch); + EXPECT_TRUE(has_control_next); + } +} + +REGISTER_OP("Input").Output("o: float"); +REGISTER_OP("BoolInput").Output("o: bool"); +REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float"); + +Node* Input(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("Input", opts); +} + +Node* BoolInput(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("BoolInput", opts); +} + +Node* 
Cross(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("Cross", a, b, opts); +} + +class GraphPartitionTest : public ::testing::Test { + protected: + GraphPartitionTest() + : in_(GraphDefBuilder::kFailImmediately), + builder_a_(GraphDefBuilder::kFailImmediately), + builder_b_(GraphDefBuilder::kFailImmediately), + a_opts_(builder_a_.opts().WithDevice("/job:a/replica:0/task:0/cpu:0")), + b_opts_(builder_b_.opts().WithDevice("/job:a/replica:0/task:0/cpu:1")) { + RequireDefaultOps(); + } + + const GraphDef& ToGraphDef() { + in_.ToGraphDef(&in_graph_def_); + return in_graph_def_; + } + + void ExpectMatchA() { + GraphDef graph_def; + builder_a_.ToGraphDef(&graph_def); + string a = "/job:a/replica:0/task:0/cpu:0"; + TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]); + } + + void ExpectMatchB() { + GraphDef graph_def; + builder_b_.ToGraphDef(&graph_def); + string b = "/job:a/replica:0/task:0/cpu:1"; + TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]); + } + + GraphDefBuilder in_; + GraphDef in_graph_def_; + GraphDefBuilder builder_a_; + GraphDefBuilder builder_b_; + GraphDefBuilder::Options a_opts_; + GraphDefBuilder::Options b_opts_; + std::unordered_map<string, GraphDef> partitions_; +}; + +TEST_F(GraphPartitionTest, SingleDevice) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Cross(a1, a1, in_.opts().WithName("A2")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(1, partitions_.size()); + + a1 = Input(a_opts_.WithName("A1")); + Cross(a1, a1, a_opts_.WithName("A2")); + ExpectMatchA(); +} + +TEST_F(GraphPartitionTest, CrossDeviceData) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = 
"/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0")); + ExpectMatchA(); + + b1 = Input(b_opts_.WithName("B1")); + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1")); + Cross(recv, b1, b_opts_.WithName("B2")); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceControl) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1)); + _Send(c, "edge_3_A1", a, 82, b, a_opts_.WithName("A1/_1")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id = Identity(recv, b_opts_.WithName("A1/_3")); + b1 = Input(b_opts_.WithName("B1")); + Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + Cross(a1, a1, in_.opts().WithName("B3")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1")); + b1 = 
Input(b_opts_.WithName("B1")); + Cross(recv, b1, b_opts_.WithName("B2")); + Cross(recv, recv, b_opts_.WithName("B3")); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1)); + Input(in_.opts().WithName("B3").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1)); + _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id = Identity(recv, b_opts_.WithName("A1/_3")); + b1 = Input(b_opts_.WithName("B1")); + Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id)); + Input(b_opts_.WithName("B3").WithControlInput(id)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDevice_DataControl) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + Input(in_.opts().WithName("B3").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1)); + // NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could + // use A1/_0 -> A1/_4 as the control as a minor optimization. 
+ _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1")); + _Send(a1, "edge_2_A1", a, 82, b, a_opts_.WithName("A1/_4")); + ExpectMatchA(); + + Node* recv1 = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id1 = Identity(recv1, b_opts_.WithName("A1/_3")); + Node* recv2 = + _Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5")); + b1 = Input(b_opts_.WithName("B1")); + Cross(recv2, b1, b_opts_.WithName("B2")); + Input(b_opts_.WithName("B3").WithControlInput(id1)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceLoop) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = BoolInput(in_.opts().WithName("A1")); + Node* a2 = Enter(a1, "foo", in_.opts().WithName("A2")); + Node* a3 = Merge({a2, {"A5", 0, DT_BOOL}}, in_.opts().WithName("A3")); + LoopCond(a3, in_.opts().WithName("A4")); + Node* b1 = Identity(a3, in_.opts().WithName("B1")); + NextIteration(b1, in_.opts().WithName("A5")); + + CheckLoopConstruction(ToGraphDef()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc new file mode 100644 index 0000000000..f7a8ffde89 --- /dev/null +++ b/tensorflow/core/graph/graph_test.cc @@ -0,0 +1,252 @@ +#include "tensorflow/core/graph/graph.h" + +#include <set> +#include <gtest/gtest.h> +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +class GraphTest : public ::testing::Test { + protected: + GraphTest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); } + ~GraphTest() 
override {} + + static void VerifyNodes(Node* node, std::vector<Node*> expected_in, + std::vector<Node*> expected_out) { + std::vector<Node*> in; + for (const Edge* e : node->in_edges()) { + in.push_back(e->src()); + } + EXPECT_EQ(Stringify(expected_in), Stringify(in)); + + std::vector<Node*> out; + for (const Edge* e : node->out_edges()) { + out.push_back(e->dst()); + } + EXPECT_EQ(Stringify(expected_out), Stringify(out)); + } + + Node* AddNodeWithName(const string& name) { + Node* node; + TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node)); + return node; + } + + Graph graph_; + + private: + // Convert a list of nodes to a sorted list of strings so failure messages + // are readable. + static std::vector<string> Stringify(const std::vector<Node*>& nodes) { + std::vector<string> result; + for (Node* n : nodes) { + result.push_back(n->DebugString()); + } + std::sort(result.begin(), result.end()); + return result; + } +}; + +TEST_F(GraphTest, Constructor) { + Node* source = graph_.source_node(); + EXPECT_NE(source, nullptr); + Node* sink = graph_.sink_node(); + EXPECT_NE(sink, nullptr); + VerifyNodes(source, {}, {sink}); + VerifyNodes(sink, {source}, {}); + EXPECT_EQ(2, graph_.num_node_ids()); +} + +TEST_F(GraphTest, RemoveThenAdd) { + AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + const int b_id = b->id(); + AddNodeWithName("C"); + EXPECT_EQ(5, graph_.num_node_ids()); + graph_.RemoveNode(b); + EXPECT_EQ(5, graph_.num_node_ids()); + Node* d = AddNodeWithName("D"); + EXPECT_NE(b_id, d->id()); // Ids should not be reused. 
+ EXPECT_EQ(6, graph_.num_node_ids()); +} + +TEST_F(GraphTest, InNodesAndOutNodes) { + Node* a = AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + Node* c = AddNodeWithName("C"); + graph_.RemoveNode(b); + Node* d = AddNodeWithName("D"); + + const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a); + graph_.AddControlEdge(a, graph_.sink_node()); + graph_.AddEdge(a, 0, c, 0); + graph_.AddControlEdge(c, graph_.sink_node()); + + EXPECT_EQ("A", a->name()); + VerifyNodes(a, {graph_.source_node()}, {c, graph_.sink_node()}); + + EXPECT_EQ("C", c->name()); + VerifyNodes(c, {a}, {graph_.sink_node()}); + + EXPECT_EQ("D", d->name()); + VerifyNodes(d, {}, {}); + + VerifyNodes(graph_.source_node(), {}, {a, graph_.sink_node()}); + VerifyNodes(graph_.sink_node(), {a, c, graph_.source_node()}, {}); + + graph_.RemoveEdge(source_to_a); + VerifyNodes(a, {}, {c, graph_.sink_node()}); + VerifyNodes(graph_.source_node(), {}, {graph_.sink_node()}); // no more a + + graph_.RemoveNode(c); + VerifyNodes(a, {}, {graph_.sink_node()}); // no more c + VerifyNodes(graph_.sink_node(), {a, graph_.source_node()}, {}); // no more c + EXPECT_EQ(6, graph_.num_node_ids()); + EXPECT_EQ(5, graph_.num_edge_ids()); +} + +TEST_F(GraphTest, NodeIteration) { + // Set up the graph with some holes due to removals. 
+ Node* a = AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + Node* c = AddNodeWithName("C"); + graph_.RemoveNode(b); + Node* d = AddNodeWithName("D"); + const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a); + graph_.AddControlEdge(a, graph_.sink_node()); + graph_.AddEdge(a, 0, c, 0); + graph_.AddControlEdge(c, graph_.sink_node()); + graph_.RemoveEdge(source_to_a); + graph_.RemoveNode(c); + + // expected = set of all node DebugStrings we expect in the graph + std::set<string> expected; + expected.insert(graph_.source_node()->DebugString()); + expected.insert(a->DebugString()); + expected.insert(d->DebugString()); + expected.insert(graph_.sink_node()->DebugString()); + + // Verify that iterating through ids gets the same set of nodes. + std::set<string> actual; + for (int id = 0; id < graph_.num_node_ids(); ++id) { + Node* node = graph_.FindNodeId(id); + if (node != nullptr) { + actual.insert(node->DebugString()); + } + } + EXPECT_EQ(expected, actual); + + // Verify that range-based for loop gets the same set of nodes. + actual.clear(); + for (Node* node : graph_.nodes()) { + actual.insert(node->DebugString()); + } + EXPECT_EQ(expected, actual); +} + +static void CheckType(Node* node, bool b) { + EXPECT_TRUE(b) << node->DebugString(); + // Make sure none of the other IsFoo() methods return true. + int count = 0; + if (node->IsSource()) count++; + if (node->IsSink()) count++; + if (node->IsOp()) count++; + EXPECT_EQ(1, count) << node->DebugString(); +} + +TEST_F(GraphTest, Type) { + Node* op = AddNodeWithName("A"); + CheckType(graph_.source_node(), graph_.source_node()->IsSource()); + CheckType(graph_.sink_node(), graph_.sink_node()->IsSink()); + CheckType(op, op->IsOp()); +} + +// Convert edge iteration results into a sorted string. 
+static string EdgeIter(const Graph& g) { + std::vector<std::pair<int, int> > edges; + for (const Edge* e : g.edges()) { + edges.push_back(std::make_pair(e->src()->id(), e->dst()->id())); + } + std::sort(edges.begin(), edges.end()); + string result; + for (auto& p : edges) { + strings::StrAppend(&result, p.first, "->", p.second, ";"); + } + return result; +} + +TEST_F(GraphTest, EdgeIteration) { + EXPECT_EQ("0->1;", EdgeIter(graph_)); + + Node* a = AddNodeWithName("A"); + Node* b = AddNodeWithName("B"); + EXPECT_EQ("0->1;", EdgeIter(graph_)); // Since a,b are currently disconnected + + graph_.AddEdge(a, 0, b, 0); + EXPECT_EQ("0->1;2->3;", EdgeIter(graph_)); + + graph_.AddControlEdge(graph_.source_node(), a); + graph_.AddControlEdge(b, graph_.sink_node()); + EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_)); + + graph_.AddEdge(a, 1, a, 0); + EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_)); +} + +TEST_F(GraphTest, NewName) { + string a1 = graph_.NewName("A"); + string a2 = graph_.NewName("A"); + string b1 = graph_.NewName("B"); + EXPECT_NE(a1, a2); + EXPECT_NE(a1, b1); + EXPECT_NE(a2, b1); + EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1; +} + +REGISTER_OP("Input").Output("o: float"); +REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("o: float"); + +static void BM_InEdgeIteration(int iters, int num_nodes) { + testing::StopTiming(); + string s; + for (int in = 0; in < 10; in++) { + s += strings::Printf("node { name: 'in%04d' op: 'Input' }", in); + } + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int op = 0; op < num_nodes; op++) { + s += strings::Printf( + "node { name: 'op%04d' op: 'In2Out1' input: ['in%04d', 'in%04d' ] }", + op, rnd.Uniform(10), rnd.Uniform(10)); + } + + Graph graph(OpRegistry::Global()); + GraphDef graph_def; + CHECK(protobuf::TextFormat::ParseFromString(s, &graph_def)); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + + int64 sum 
= 0; + testing::StartTiming(); + for (int i = 0; i < iters; i += graph.num_node_ids()) { + for (const Node* node : graph.nodes()) { + for (auto e : node->in_edges()) { + sum += e->id(); + } + } + } + VLOG(1) << sum; +} +BENCHMARK(BM_InEdgeIteration)->Range(10, 100000); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc new file mode 100644 index 0000000000..8c34323dbe --- /dev/null +++ b/tensorflow/core/graph/node_builder.cc @@ -0,0 +1,115 @@ +#include "tensorflow/core/graph/node_builder.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +NodeBuilder::NodeBuilder(const string& name, const string& op_name, + const OpRegistryInterface* op_registry) + : def_builder_(name, op_name, op_registry) {} + +NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def) + : def_builder_(name, op_def) {} + +NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) { + inputs_.emplace_back(src_node, src_index); + DataType dt; + if (GetOutputType(src_node, src_index, &dt)) { + def_builder_.Input(src_node->name(), src_index, dt); + } + return *this; +} + +NodeBuilder& NodeBuilder::Input(NodeOut src) { + if (src.error) { + AddIndexError(src.node, src.index); + } else { + inputs_.emplace_back(src.node, src.index); + def_builder_.Input(src.name, src.index, src.dt); + } + return *this; +} + +NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) { + std::vector<NodeDefBuilder::NodeOut> srcs; + srcs.reserve(src_list.size()); + for (const auto& node_out : src_list) { + if (node_out.error) { + AddIndexError(node_out.node, node_out.index); + } else { + srcs.emplace_back(node_out.name, node_out.index, node_out.dt); + inputs_.emplace_back(node_out.node, node_out.index); + } + } + def_builder_.Input(srcs); + return *this; +} + +NodeBuilder& NodeBuilder::ControlInput(Node* src_node) { + 
control_inputs_.emplace_back(src_node); + def_builder_.ControlInput(src_node->name()); + return *this; +} + +NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) { + control_inputs_.insert(control_inputs_.end(), src_nodes.begin(), + src_nodes.end()); + for (Node* src_node : src_nodes) { + def_builder_.ControlInput(src_node->name()); + } + return *this; +} + +NodeBuilder& NodeBuilder::Device(const string& device_spec) { + def_builder_.Device(device_spec); + return *this; +} + +Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { + // In case of error, set *created_node to nullptr. + if (created_node != nullptr) *created_node = nullptr; + if (!errors_.empty()) { + return errors::InvalidArgument(str_util::Join(errors_, "\n")); + } + + NodeDef node_def; + TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def)); + TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); + Status status; + Node* node = graph->AddNode(node_def, &status); + if (!status.ok()) return status; + + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs_[i].node != nullptr) { // Skip back edges. 
+ graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i); + } + } + for (Node* control_input : control_inputs_) { + graph->AddControlEdge(control_input, node); + } + if (created_node != nullptr) *created_node = node; + return Status::OK(); +} + +void NodeBuilder::AddIndexError(Node* node, int i) { + if (node == nullptr) { + errors_.emplace_back( + strings::StrCat("Attempt to add nullptr Node to node with type", + def_builder_.op_def().name())); + } else { + errors_.emplace_back( + strings::StrCat("Attempt to add output ", i, " of ", node->name(), + " not in range [0, ", node->num_outputs(), + ") to node with type ", def_builder_.op_def().name())); + } +} + +bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) { + bool error; + *dt = SafeGetOutput(node, i, &error); + if (error) AddIndexError(node, i); + return !error; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h new file mode 100644 index 0000000000..dd34b97f23 --- /dev/null +++ b/tensorflow/core/graph/node_builder.h @@ -0,0 +1,146 @@ +#ifndef TENSORFLOW_GRAPH_NODE_BUILDER_H_ +#define TENSORFLOW_GRAPH_NODE_BUILDER_H_ + +#include <vector> +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// This is a helper for creating a Node and adding it to a Graph. +// Internally, it uses a NodeDefBuilder to automatically set attrs +// that can be inferred from the inputs, and use default values +// (where they exist) for unspecified attrs. Example usage: +// +// Node* node; +// Status status = NodeBuilder(node_name, op_name) +// .Input(...) +// .Attr(...) +// .Finalize(&graph, &node); +// if (!status.ok()) return status; +// // Use node here. 
+class NodeBuilder { + public: + // For specifying the output of a Node to provide to one of the Input() + // functions below. It supports both regular inputs (where you are + // connecting to an existing Node*), and inputs from outside the graph + // (or haven't been added to the graph yet, like back edges, where + // you don't have a Node*). Both types can be mixed, e.g. in an + // ArraySlice. + struct NodeOut { + // For referencing an existing Node. + NodeOut(Node* n, int i = 0) // NOLINT(runtime/explicit) + : node(n), + error(false), + name(node != nullptr ? node->name() : (error = true, "")), + index(i), + dt(SafeGetOutput(node, i, &error)) {} + + // For referencing Nodes not in the graph being built. It is + // useful when preparing a graph for ExtendSession or creating a + // back edge to a node that hasn't been added to the graph yet, + // but will be. + NodeOut(const string& name, int i, DataType t) + : node(nullptr), error(false), name(name), index(i), dt(t) {} + + // Default constructor for std::vector<NodeOut>. + NodeOut() {} + + Node* node = nullptr; + // error is set to true if: + // * the NodeOut was default constructed and never overwritten, + // * a nullptr Node* was passed to the NodeOut constructor, or + // * an out-of-range index was passed to the NodeOut constructor. + bool error = true; + string name; + int index = 0; + DataType dt = DT_FLOAT; + }; + + // Specify the name and the Op (either via an OpDef or the name of + // the Op plus a registry) for the Node. Other fields are + // specified by calling the methods below. + // REQUIRES: The OpDef must satisfy ValidateOpDef(). + NodeBuilder(const string& name, const string& op_name, + const OpRegistryInterface* op_registry = OpRegistry::Global()); + NodeBuilder(const string& name, const OpDef* op_def); + + // You must call one Input() function per input_arg in the Op, + // *and in the same order as the input_args appear in the OpDef.* + + // For inputs that take a single tensor. 
+ NodeBuilder& Input(Node* src_node, int src_index = 0); + NodeBuilder& Input(NodeOut src); + + // For inputs that take a list of tensors. + NodeBuilder& Input(gtl::ArraySlice<NodeOut> src_list); + + // Require that this node run after src_node(s). + NodeBuilder& ControlInput(Node* src_node); + NodeBuilder& ControlInputs(gtl::ArraySlice<Node*> src_nodes); + + // Sets the "requested device spec" in the NodeDef (not the + // "assigned device" in the Node). + NodeBuilder& Device(const string& device_spec); + + // Set the value of an attr. attr_name must match the name of one of + // attrs defined by the Op, and value must have the corresponding type + // (see SetAttrValue() in ../framework/attr_value_util.h for legal + // types for value). Note that attrs will be set automatically if + // they can be determined by the inputs. + template <class T> + NodeBuilder& Attr(const string& attr_name, T&& value); + template <class T> + NodeBuilder& Attr(const string& attr_name, std::initializer_list<T> value); + + // Validates the described node and adds it to *graph, adding edges + // for all (non-back) inputs. If created_node is not nullptr, + // *created_node will be set to the new node (or nullptr on error). + Status Finalize(Graph* graph, Node** created_node) const; + + private: + static DataType SafeGetOutput(Node* node, int i, bool* error) { + if (node != nullptr && i >= 0 && i < node->num_outputs()) { + *error = false; + return node->output_type(i); + } else { + *error = true; + return DT_FLOAT; + } + } + + // If SafeGetOutput indicates a range error, add it to errors_. + void AddIndexError(Node* node, int i); + + // Set *dt and returns true if i is in range. Combines + // SafeGetOutput() and AddIndexError(). 
+ bool GetOutputType(Node* node, int i, DataType* dt); + + NodeDefBuilder def_builder_; + std::vector<NodeOut> inputs_; + std::vector<Node*> control_inputs_; + std::vector<string> errors_; +}; + +// IMPLEMENTATION ------------------------------------------------------------- + +template <class T> +inline NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) { + def_builder_.Attr(attr_name, std::forward<T>(value)); + return *this; +} + +template <class T> +NodeBuilder& NodeBuilder::Attr(const string& attr_name, + std::initializer_list<T> value) { + def_builder_.Attr(attr_name, value); + return *this; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_NODE_BUILDER_H_ diff --git a/tensorflow/core/graph/node_builder_test.cc b/tensorflow/core/graph/node_builder_test.cc new file mode 100644 index 0000000000..9f667d00e4 --- /dev/null +++ b/tensorflow/core/graph/node_builder_test.cc @@ -0,0 +1,59 @@ +#include "tensorflow/core/graph/node_builder.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +REGISTER_OP("Source").Output("o: out_types").Attr("out_types: list(type)"); +REGISTER_OP("Sink").Input("i: T").Attr("T: type"); + +TEST(NodeBuilderTest, Simple) { + RequireDefaultOps(); + Graph graph(OpRegistry::Global()); + Node* source_node; + EXPECT_OK(NodeBuilder("source_op", "Source") + .Attr("out_types", {DT_INT32, DT_STRING}) + .Finalize(&graph, &source_node)); + ASSERT_TRUE(source_node != nullptr); + + // Try connecting to each of source_node's outputs. + EXPECT_OK(NodeBuilder("sink1", "Sink") + .Input(source_node) + .Finalize(&graph, nullptr)); + EXPECT_OK(NodeBuilder("sink2", "Sink") + .Input(source_node, 1) + .Finalize(&graph, nullptr)); + + // Generate an error if the index is out of range. 
+ EXPECT_FALSE(NodeBuilder("sink3", "Sink") + .Input(source_node, 2) + .Finalize(&graph, nullptr) + .ok()); + EXPECT_FALSE(NodeBuilder("sink4", "Sink") + .Input(source_node, -1) + .Finalize(&graph, nullptr) + .ok()); + EXPECT_FALSE(NodeBuilder("sink5", "Sink") + .Input({source_node, -1}) + .Finalize(&graph, nullptr) + .ok()); + + // Generate an error if the node is nullptr. This can happen when using + // GraphDefBuilder if there was an error creating the input node. + EXPECT_FALSE(NodeBuilder("sink6", "Sink") + .Input(nullptr) + .Finalize(&graph, nullptr) + .ok()); + EXPECT_FALSE(NodeBuilder("sink7", "Sink") + .Input(NodeBuilder::NodeOut(nullptr, 0)) + .Finalize(&graph, nullptr) + .ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc new file mode 100644 index 0000000000..2fa6f075c0 --- /dev/null +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -0,0 +1,220 @@ +// This module implements a common subexpression elimination pass. We +// process the nodes in the graph in reverse postorder +// (i.e. inputs before their downstream dependencies). 
The rough algorithm is +// as follows: +// +// std::unordered_map<size_t, Node*> available +// for each node n in forward topological order: +// h = NodeHash(n) +// if available[h] exists and Equivalent(available(h), h) +// redirect downstream uses of outputs of n to available[h] +// remove n from graph +// else +// if available[h] does not exist +// available[h] = n +// +// This is similar to the global value number algorithm describe in this +// paper: +// +// "Global code motion/global value numbering", Cliff Click, PLDI '95 +// Proceedings of the ACM SIGPLAN 1995 conference on Programming +// language design and implementation, Pages 246-257 +// http://dl.acm.org/citation.cfm?id=207154 + +#include "tensorflow/core/graph/optimizer_cse.h" + +#include <unordered_map> + +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class OptimizerCSE { + public: + explicit OptimizerCSE(Graph* g) : g_(g) {} + + void Optimize(std::function<bool(const Node*)> consider_fn); + + private: + struct Scratch; + + static size_t NodeHash(const Node* n); + static bool Equivalent(const Node* a, const Node* b, Scratch* s); + static bool EqualAttrs(const Node* a, const Node* b, Scratch* s); + + Graph* g_; +}; + +static void FillInputs(const Node* n, + gtl::InlinedVector<Node*, 4>* control_edges, + gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { + DCHECK_EQ(in->size(), n->num_inputs()); + control_edges->clear(); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + control_edges->push_back(e->src()); + } else { + (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); + } + } + std::sort(control_edges->begin(), control_edges->end()); + if (n->op_def().is_commutative()) { + // For commutative inputs, we sort the input by the input Node* + // to get a canonical ordering (so that add(a,b) and add(b, a) 
will + // hash to the same value if is_commutative is true for 'add'). + std::sort(in->begin(), in->end()); + } +} + +static size_t kIllegalNodeHash = 0; + +size_t OptimizerCSE::NodeHash(const Node* n) { + const DataTypeVector& out = n->output_types(); + string str_to_hash = strings::StrCat(n->type_string(), out.size()); + for (DataType dt : out) { + strings::StrAppend(&str_to_hash, dt); + } + + const int N_in = n->num_inputs(); + strings::StrAppend(&str_to_hash, N_in); + gtl::InlinedVector<Node*, 4> control_edges; + gtl::InlinedVector<std::pair<Node*, int>, 4> in(N_in); + FillInputs(n, &control_edges, &in); + for (const auto& edge : in) { + strings::StrAppend(&str_to_hash, edge.first->id(), edge.second); + } + + size_t h = Hash64(str_to_hash); + +#if !defined(__ANDROID__) && !defined(ANDROID) + // Hash the attrs. For example, this makes sure different constants + // end up in different hash buckets. + string tmp; + for (const auto& attr : n->def().attr()) { + tmp = attr.first; + attr.second.AppendToString(&tmp); + // Add hashes of attrs, so the order of attrs doesn't matter. + h += Hash32(tmp.data(), tmp.size(), 0x87341245); + } +#endif + + if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1; + return h; +} + +struct OptimizerCSE::Scratch { + // For EqualAttrs(): + string a; + string b; +}; + +bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) { + if (a->def().attr_size() != b->def().attr_size()) return false; + + for (const auto& attr : b->def().attr()) { + auto iter = a->def().attr().find(attr.first); + if (iter == a->def().attr().end()) return false; + // Note: it should be safe to compare proto serializations of the attr + // values since at most one field should be set in each (indeed, it + // should be the same field). 
+ iter->second.SerializeToString(&scratch->a); + attr.second.SerializeToString(&scratch->b); + if (scratch->a != scratch->b) return false; + } + return true; +} + +static bool HasRefInput(const Node* n) { + for (auto dt : n->input_types()) { + if (IsRefType(dt)) return true; + } + return false; +} + +bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) { + // Different op names are different + if (a->type_string() != b->type_string()) return false; + + // Never consider stateful nodes (such as non-const inputs) equivalent. + if (a->op_def().is_stateful()) return false; + + // For now, we consider any node that takes a ref input to not be + // equivalent to any other node. + if (HasRefInput(a) || HasRefInput(b)) return false; + + // Compare attrs. Note that equal attrs implies equal input and + // output types. + if (!EqualAttrs(a, b, scratch)) return false; + + // Compare input sources + if (a->num_inputs() != b->num_inputs()) return false; + const int N_in = a->num_inputs(); + gtl::InlinedVector<Node*, 4> a_control_edges; + gtl::InlinedVector<Node*, 4> b_control_edges; + gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); + gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in); + FillInputs(a, &a_control_edges, &a_in); + FillInputs(b, &b_control_edges, &b_in); + if (a_in != b_in) return false; + if (a_control_edges != b_control_edges) return false; + + return true; +} + +void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) { + // This very simple implementation works if the whole graph is one + // giant basic block (because we just traverse nodes in a + // topological order). We'll need to do something more + // sophisticated when we have control flow/loops/etc. + + // TODO(jeff): We need to handle Update nodes specially, but dealing + // with more general control flow will also solve this issue, and for + // now, our updates are almost always the most downstream nodes in + // the graph. 
  std::vector<Node*> order;
  GetReversePostOrder(*g_, &order);

  // Our value is just a single Node*, meaning we keep just a single
  // candidate for a given node hash value. This may cause us to
  // (rarely) lose some optimization opportunities if there are
  // hash collisions, but it allows us to avoid having the value
  // be a set<Node*> (or equivalent).
  std::unordered_map<size_t, Node*> available;

  // Scratch space for Equivalent calls. Allocated here and passed in to
  // Equivalent to avoid allocation inside the loop below.
  Scratch scratch;
  for (Node* n : order) {
    // Skip the synthetic source/sink nodes.
    if (!n->IsOp()) continue;

    // See if we should consider this node at all
    if (consider_fn != nullptr && !consider_fn(n)) continue;

    size_t h = NodeHash(n);
    Node** candidate = &available[h];
    if (*candidate == nullptr) {
      // No existing match: insert "n" into the hash table under "h"
      *candidate = n;
    } else if (Equivalent(*candidate, n, &scratch)) {
      VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and "
              << n->name();
      // *candidate and n are equivalent. Therefore, we can replace
      // n with *candidate by fixing up outgoing edges from "n" to instead
      // come from "*candidate", and then delete n from the graph.
      // (RemoveNode below also drops n's now-redundant old edges.)
      for (const Edge* e : n->out_edges()) {
        g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
      }
      g_->RemoveNode(n);
    }
  }
}

// Public entry point: runs the pass over "*g", optionally filtered by
// "consider_fn".
void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
  OptimizerCSE opt(g);
  opt.Optimize(consider_fn);
}

}  // namespace tensorflow
diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h
new file mode 100644
index 0000000000..430c97a449
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse.h
@@ -0,0 +1,19 @@
// An optimization pass that performs common subexpression elimination.
+ +#ifndef TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_ +#define TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_ + +#include <sys/types.h> +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Perform common-subexpression elimination on the graph "*g". If +// "consider_fn" is not nullptr, then only nodes for which +// consider_fn(node) returns true will be considered for combining +// during the common subexpression elimination. +extern void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_ diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc new file mode 100644 index 0000000000..ebbb948fdc --- /dev/null +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -0,0 +1,365 @@ +#include "tensorflow/core/graph/optimizer_cse.h" + +#include <gtest/gtest.h> +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace { + +static void InitGraph(const string& s, Graph* graph) { + GraphDef graph_def; + + auto parser = protobuf::TextFormat::Parser(); + // parser.AllowRelaxedWhitespace(true); + CHECK(parser.MergeFromString(s, &graph_def)) << s; + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); +} + +class OptimizerCSETest : public ::testing::Test { + public: + OptimizerCSETest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); } + + void 
InitGraph(const string& s) { + ::tensorflow::InitGraph(s, &graph_); + original_ = CanonicalGraphString(&graph_); + } + + static bool IncludeNode(const Node* n) { return n->IsOp(); } + + static string EdgeId(const Node* n, int index) { + if (index == 0) { + return n->name(); + } else if (index == Graph::kControlSlot) { + return strings::StrCat(n->name(), ":control"); + } else { + return strings::StrCat(n->name(), ":", index); + } + } + + string CanonicalGraphString(Graph* g) { + std::vector<string> nodes; + std::vector<string> edges; + for (const Node* n : g->nodes()) { + if (IncludeNode(n)) { + nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")")); + } + } + for (const Edge* e : g->edges()) { + if (IncludeNode(e->src()) && IncludeNode(e->dst())) { + edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", + EdgeId(e->dst(), e->dst_input()))); + } + } + // Canonicalize + std::sort(nodes.begin(), nodes.end()); + std::sort(edges.begin(), edges.end()); + return strings::StrCat(str_util::Join(nodes, ";"), "|", + str_util::Join(edges, ";")); + } + + string DoCSE(std::function<bool(const Node*)> consider_fn = nullptr) { + string before = CanonicalGraphString(&graph_); + LOG(ERROR) << "Before rewrites: " << before; + + OptimizeCSE(&graph_, consider_fn); + + string result = CanonicalGraphString(&graph_); + LOG(ERROR) << "After rewrites: " << result; + return result; + } + + const string& OriginalGraph() const { return original_; } + + Graph graph_; + string original_; +}; + +REGISTER_OP("Input").Output("o: float").SetIsStateful(); + +// Note that the "rules" in these tests are not meant to be logically correct +TEST_F(OptimizerCSETest, Simple) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + 
  // C and D are identical, so one of them is eliminated.
  EXPECT_EQ(DoCSE(),
            "A(Input);B(Input);D(Mul)|"
            "A->D;B->D:1");
}

TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  // All three Muls collapse into a single survivor.
  EXPECT_EQ(DoCSE(),
            "A(Input);B(Input);E(Mul)|"
            "A->E;B->E:1");
}

TEST_F(OptimizerCSETest, Simple_WithFixups) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['C', 'D'] }");
  // C and D merge, and E's two inputs are both rewired to the survivor D.
  EXPECT_EQ(DoCSE(),
            "A(Input);B(Input);D(Mul);E(Mul)|"
            "A->D;B->D:1;D->E;D->E:1");
}

TEST_F(OptimizerCSETest, Simple_Commutative) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }");
  // Mul is commutative, so A*B and B*A are recognized as equivalent.
  EXPECT_EQ(DoCSE(),
            "A(Input);B(Input);D(Mul)|"
            "A->D:1;B->D");
}

static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; }

// Like Simple_Commutative, but with a consider_fn that rejects Mul nodes,
// so no CSE is performed and the graph is left unchanged.
TEST_F(OptimizerCSETest, Simple_Filtered) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }");
  EXPECT_EQ(DoCSE(IsNotMultiply), OriginalGraph());
}

TEST_F(OptimizerCSETest, Simple_NotCommutative) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['B', 'A'] }");
  // Sub is not commutative: A-B and B-A must NOT be merged.
  EXPECT_EQ(DoCSE(), OriginalGraph());
}

TEST_F(OptimizerCSETest, NotEquivalent_Ops) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }");
  // Different op types are never merged.
  EXPECT_EQ(DoCSE(), OriginalGraph());
}

TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) {
  // Should still do CSE for ops with attrs if they match.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] attr { key: 'shape'"
      "    value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] attr { key: 'shape'"
      "    value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }");
  EXPECT_EQ(DoCSE(),
            "A(Input);B(Input);D(Mul)|"
            "A->D;B->D:1");
}

TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) {
  // Should still do CSE for ops with attrs if they match, even if they
  // are not in the same order.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 'a' value { i: 3 } }"
      " attr { key: 't' value { type: DT_INT32 } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 't' value { type: DT_INT32 } }"
      " attr { key: 'a' value { i: 3 } } }");
  EXPECT_EQ(DoCSE(),
            "A(Input);B(Input);D(Mul)|"
            "A->D;B->D:1");
}

TEST_F(OptimizerCSETest, SameConstants) {
  // Should still do CSE for ops with constants if the values are identical
  InitGraph(
      "node { name: 'A' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoCSE(),
            "B(Const);D(Mul)|"
            "B->D;B->D:1");
}

TEST_F(OptimizerCSETest, DifferentConstants) {
  // Should NOT do CSE for constants whose values differ.
  InitGraph(
      "node { name: 'A' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 0 } } } }"
      "node { name: 'B' op: 'Const' "
      "  attr { key: 'dtype' value { type: DT_INT32 } }"
      "  attr { key: 'value' value {"
      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
      "    int_val: 100000 } } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
      " input: ['A', 'B'] }");
  EXPECT_EQ(DoCSE(),
            "A(Const);B(Const);D(Mul)|"
            "A->D;B->D:1");
}

TEST_F(OptimizerCSETest, SameOps_DifferentAttrs1) {
  // Same op type but attr 'a' differs (3 vs 4): must not merge.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 'a' value { i: 3 } }"
      " attr { key: 't' value { type: DT_INT32 } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 't' value { type: DT_INT32 } }"
      " attr { key: 'a' value { i: 4 } } }");
  EXPECT_EQ(DoCSE(), OriginalGraph());
}

TEST_F(OptimizerCSETest, SameOps_DifferentAttrs2) {
  // Same op type but attr 't' differs (DT_FLOAT vs DT_INT32): must not merge.
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 'a' value { i: 3 } }"
      " attr { key: 't' value { type: DT_FLOAT } } }"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B']"
      " attr { key: 't' value { type: DT_INT32 } }"
      " attr { key: 'a' value { i: 3 } } }");
  EXPECT_EQ(DoCSE(), OriginalGraph());
}

TEST_F(OptimizerCSETest, NotEquivalent_Inputs) {
  InitGraph(
      "node { name: 'A' op: 'Input'}"
      "node { name: 'B' op: 'Input'}"
      "node { name: 'C' op: 'Input'}"
      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }"
      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'C'] }");
  // Different second inputs (B vs C): must not merge.
  EXPECT_EQ(DoCSE(), OriginalGraph());
}

TEST_F(OptimizerCSETest, Constant_Dedup) {
  Tensor a(DT_FLOAT, TensorShape({1}));
  a.flat<float>()(0) = 1.0;
  Tensor b(DT_DOUBLE, TensorShape({1}));  // Different type
  b.flat<double>()(0) = 1.0;
  Tensor c(DT_FLOAT, TensorShape({1, 1}));  // Different shape
  c.flat<float>()(0) = 1.0;
  Tensor d(DT_FLOAT, TensorShape({1}));  // Different value
  d.flat<float>()(0) = 2.0;

  // A graph contains a bunch of constants.
  Graph g(OpRegistry::Global());
  for (auto val : {a, b, c, d, d, c, b, a}) {
    test::graph::Constant(&g, val);  // Node name is n/_0, n/_1, ...
  }
  GraphDef gdef;
  test::graph::ToGraphDef(&g, &gdef);
  InitGraph(gdef.DebugString());

  EXPECT_EQ(OriginalGraph(),
            "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);"
            "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
  // In theory, there are 2^4 possible correct output of CSE. In this
  // test, it happens to eliminate the first 4 nodes.
  EXPECT_EQ(DoCSE(), "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
}

// Benchmarks CSE over a randomly wired graph of "op_nodes" Mul nodes
// reading from 10 Input nodes. Benchmark units are graph nodes, not
// whole graphs.
static void BM_CSE(int iters, int op_nodes) {
  testing::StopTiming();
  string s;
  for (int in = 0; in < 10; in++) {
    s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
  }
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  for (int op = 0; op < op_nodes; op++) {
    s += strings::Printf(
        "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
        "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
        op, rnd.Uniform(10), rnd.Uniform(10));
  }

  bool first = true;
  while (iters > 0) {
    // Rebuild a fresh graph each pass since OptimizeCSE mutates it.
    Graph* graph = new Graph(OpRegistry::Global());
    InitGraph(s, graph);
    int N = graph->num_node_ids();
    if (first) {
      testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
      first = false;
    }
    {
      testing::StartTiming();
      OptimizeCSE(graph, nullptr);
      testing::StopTiming();
    }
    iters -= N;  // Our benchmark units are individual graph nodes,
                 // not whole graphs
    delete graph;
  }
}
BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000);

}  // namespace
}  // namespace tensorflow
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
new file mode 100644
index 0000000000..7910511dfb
--- /dev/null
+++ b/tensorflow/core/graph/subgraph.cc
@@ -0,0 +1,258 @@
#include "tensorflow/core/graph/subgraph.h"

#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// ----------------------------------------------------------------------------
// Subgraph construction-related routines
// ----------------------------------------------------------------------------
// TODO(vrv): Profile the unordered_set and unordered_map use in this file to
// see if we should use an alternative implementation.

namespace {

typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;

// Rewrite graph by replacing the output tensors specified in
// "fed_outputs" with special feed nodes for each specified output
// tensor, and removing any nodes that are now disconnected from the
// part of the graph that reaches the sink node.
// The set of special
// feed nodes added to the graph are returned in "*feed_nodes".
// (NOTE(review): this function takes no *feed_nodes parameter — the
// comment above appears stale; confirm against callers.)
//
// Returns OK on success. On error, returns a non-OK Status describing
// the problem (and *g is left in an indeterminate state).
static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
                         const gtl::ArraySlice<string>& fed_outputs,
                         NameIndex* name_index) {
  for (const string& t : fed_outputs) {
    // "t" is of the form "<node_name>[:<output_index>]".
    TensorId id(ParseTensorName(t));

    auto iter = name_index->find(id.first);
    if (iter == name_index->end()) {
      return errors::NotFound("FeedInputs: unable to find feed output ", t);
    }
    const Node* n = iter->second;
    DCHECK_EQ(n->name(), id.first);
    if (id.second >= n->num_outputs()) {
      return errors::InvalidArgument(
          "FeedInputs: ", t, " should have output index < ", n->num_outputs());
    }

    // Build a client-terminated _Recv node that will supply the fed value.
    Node* recv_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
                    "_Recv")
            .Attr("tensor_type", BaseType(n->output_type(id.second)))
            .Attr("tensor_name", t)
            .Attr("send_device", device_info.name())
            .Attr("recv_device", device_info.name())
            .Attr("send_device_incarnation",
                  static_cast<int64>(device_info.incarnation()))
            .Attr("client_terminated", true)
            .Finalize(g, &recv_node));
    recv_node->set_assigned_device_name(device_info.name());

    // Update name_index
    (*name_index)[recv_node->name()] = recv_node;
    g->AddControlEdge(g->source_node(), recv_node);

    // Look through edges coming out of "n" for edges whose src_output() index
    // matches "output_index". If found, replace the edges with a connection
    // from the special feed node.
    std::vector<const Edge*> to_remove;
    for (const Edge* e : n->out_edges()) {
      if (e->src_output() == id.second) {
        to_remove.emplace_back(e);
      } else if (e->src_output() == Graph::kControlSlot &&
                 n->def().op() == "Placeholder") {
        // When feeding a Placeholder node, any outgoing control edges
        // will be replaced with a control edge from the replacement
        // recv_node.
        // TODO(josh11b,mrry): Come up with a more elegant way of addressing
        // the general version of this problem.
        to_remove.emplace_back(e);
      }
    }

    // Collected first, removed second: mutating n->out_edges() while
    // iterating over it would invalidate the iteration above.
    for (const Edge* e : to_remove) {
      if (e->src_output() == id.second) {
        g->AddEdge(recv_node, 0, e->dst(), e->dst_input());
      } else {
        CHECK_EQ(Graph::kControlSlot, e->src_output());
        g->AddControlEdge(recv_node, e->dst());
      }
      g->RemoveEdge(e);
    }
  }
  return Status::OK();
}

// Augment "*g" by adding special "fetch" nodes that connect to the
// tensor outputs specified in "fetch_outputs" to retrieve the output
// of the tensors. The new nodes added are set up to execute on
// the device described by device_info, and are returned in "*fetch_nodes".
//
// Returns OK on success. On error, returns a non-OK Status describing
// the problem (and *g is left in an indeterminate state).
static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
                           const gtl::ArraySlice<string>& fetch_outputs,
                           NameIndex* name_index,
                           std::vector<Node*>* fetch_nodes) {
  fetch_nodes->clear();
  for (const string& t : fetch_outputs) {
    // Parse t into node_name and output_index.
    TensorId id(ParseTensorName(t));

    // Find node in graph with that name.
    auto iter = name_index->find(id.first);
    if (iter == name_index->end()) {
      return errors::NotFound("FetchOutputs node ", t, ": not found");
    }
    Node* n = iter->second;
    DCHECK_EQ(n->name(), id.first);
    VLOG(2) << "Found fetch node for " << t;

    // Validate output_index
    if (id.second >= n->num_outputs()) {
      return errors::InvalidArgument("FetchOutputs ", t,
                                     ": output index too large, must be < ",
                                     n->num_outputs());
    }

    // Create the fetch Node and connect it up
    Node* send_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
                    "_Send")
            .Input(n, id.second)
            .Attr("tensor_name", t)
            .Attr("send_device", device_info.name())
            .Attr("recv_device", device_info.name())
            .Attr("send_device_incarnation",
                  static_cast<int64>(device_info.incarnation()))
            .Attr("client_terminated", true)
            .Finalize(g, &send_node));
    send_node->set_assigned_device_name(device_info.name());
    VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def());

    // Update the index.
    (*name_index)[send_node->name()] = send_node;

    // Keep the fetch alive through the sink so pruning preserves it.
    g->AddControlEdge(send_node, g->sink_node());
    fetch_nodes->push_back(send_node);
  }

  return Status::OK();
}

// Resolves "node_or_tensor_name" against "name_index" and, on success,
// inserts the node into *targets. Returns false if the name is unknown
// or names a tensor ("name:index") rather than a plain node.
static bool AddNodeToTargets(const string& node_or_tensor_name,
                             const NameIndex& name_index,
                             std::unordered_set<const Node*>* targets) {
  TensorId id = ParseTensorName(node_or_tensor_name);
  auto iter = name_index.find(id.first);
  if (iter == name_index.end()) {
    return false;
  }
  const Node* n = iter->second;
  // Rejects tensor-style names: the full string must equal the node name.
  if (n->name() != node_or_tensor_name) {
    return false;
  }

  targets->insert(n);
  return true;
}

// Prunes *g down to the nodes needed to reach every fetch node and every
// name in target_nodes. Returns NotFound listing any unresolved targets.
static Status PruneForTargets(Graph* g, const NameIndex& name_index,
                              const std::vector<Node*>& fetch_nodes,
                              const gtl::ArraySlice<string>& target_nodes) {
  string not_found;
  std::unordered_set<const Node*> targets;
  for (Node* n : fetch_nodes) {
    if (!AddNodeToTargets(n->name(), name_index, &targets)) {
      strings::StrAppend(&not_found, n->name(), " ");
    }
  }
  for (const string& s : target_nodes) {
    if (!AddNodeToTargets(s, name_index, &targets)) {
      strings::StrAppend(&not_found, s, " ");
    }
  }
  if (!not_found.empty()) {
    return errors::NotFound("PruneForTargets: Some target nodes not found: ",
                            not_found);
  }
  PruneForReverseReachability(g, targets);

  return Status::OK();
}

}  // namespace

namespace subgraph {

Status RewriteGraphForExecution(
    Graph* g, const gtl::ArraySlice<string>& fed_outputs,
    const gtl::ArraySlice<string>& fetch_outputs,
    const gtl::ArraySlice<string>& target_node_names,
    const DeviceAttributes& device_info) {
  // Feeding and fetching the same endpoint is ambiguous: reject it.
  std::unordered_set<string> endpoints(fed_outputs.begin(), fed_outputs.end());
  for (const auto& fetch : fetch_outputs) {
    if (endpoints.count(fetch) > 0) {
      return errors::InvalidArgument(fetch, " is both fed and fetched.");
    }
  }

  // A separate index mapping name to Node*, for use by FeedInputs,
  // FetchOutputs, and PruneForTargets
  NameIndex name_index;
  for (Node* n : g->nodes()) {
    name_index[n->name()] = n;
  }

  // Add the feeds. This may replace nodes in the graph, including the nodes
  // currently listed in "fetch_nodes". We pass "name_index" so the index is
  // kept up to date.
  if (!fed_outputs.empty()) {
    TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index));
  }

  // Add the fetch nodes, also updating "name_index".
  std::vector<Node*> fetch_nodes;
  if (!fetch_outputs.empty()) {
    TF_RETURN_IF_ERROR(
        FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes));
  }

  // Prune the graph to only compute what is needed for the fetch nodes and the
  // targets nodes.
  if (!fetch_nodes.empty() || !target_node_names.empty()) {
    TF_RETURN_IF_ERROR(
        PruneForTargets(g, name_index, fetch_nodes, target_node_names));
  }

  return Status::OK();
}

}  // namespace subgraph

}  // namespace tensorflow
diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h
new file mode 100644
index 0000000000..d2e138e8ae
--- /dev/null
+++ b/tensorflow/core/graph/subgraph.h
@@ -0,0 +1,49 @@
#ifndef TENSORFLOW_GRAPH_SUBGRAPH_H_
#define TENSORFLOW_GRAPH_SUBGRAPH_H_

#include <string>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {
namespace subgraph {

// Rewrite the graph structure of "*g" to deal with feeding node
// outputs, fetching node outputs, and only running a subset of the
// graph. "fed_outputs" and "fetch_outputs" are both lists of
// output tensor identifiers in the form of
// "<name>[:<optional_output_index>]", and "target_node_names" is a
// list of target node names in "*g".
//
// In the resulting graph "*g", output edges in "fed_outputs" have
// been redirected to special "_recv" nodes introduced into the graph.
// If these fed nodes are not needed in order to compute the effects
// of the nodes in "targets_nodes" and "fetch_outputs", then these may
// be omitted from the graph.
//
// In the resulting graph "*g", additional "_send" nodes are connected
// to every output in "fetch_outputs". These "_send" nodes are set up
// to execute on the device described by device_info.
//
// On success, returns OK, and sets "*g" to a version of "*g"
// that represents the portions of the graph necessary for producing
// the output of all nodes listed in "target_node_names" and fetching the
// specific node outputs specified in "fetch_outputs".
//
// On failure, returns the error status. Possible errors include:
//    - fed output "node:output_index" does not exist in "*g"
//    - fetch output "node:output_index" does not exist in "*g"
//    - target node "node" does not exist in "*g"
Status RewriteGraphForExecution(
    Graph* g, const gtl::ArraySlice<string>& fed_outputs,
    const gtl::ArraySlice<string>& fetch_outputs,
    const gtl::ArraySlice<string>& target_node_names,
    const DeviceAttributes& device_info);

}  // namespace subgraph
}  // namespace tensorflow

#endif  // TENSORFLOW_GRAPH_SUBGRAPH_H_
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
new file mode 100644
index 0000000000..ffb3e6e403
--- /dev/null
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -0,0 +1,305 @@
#include "tensorflow/core/graph/subgraph.h"

#include <string>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include
"tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/status.h" +#include <gtest/gtest.h> + +// TODO(josh11b): Test setting the "device" field of a NodeDef. +// TODO(josh11b): Test that feeding won't prune targets. + +namespace tensorflow { +namespace { + +class SubgraphTest : public ::testing::Test { + protected: + SubgraphTest() : g_(new Graph(OpRegistry::Global())) { + RequireDefaultOps(); + device_info_.set_name("/job:a/replica:0/task:0/cpu:0"); + device_info_.set_device_type(DeviceType(DEVICE_CPU).type()); + device_info_.set_incarnation(0); + } + + ~SubgraphTest() override {} + + void ExpectOK(const string& gdef_ascii) { + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_)); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get())); + } + + Node* FindNode(const string& name) { + for (Node* n : g_->nodes()) { + if (n->name() == name) return n; + } + return nullptr; + } + + bool HasNode(const string& name) { return FindNode(name) != nullptr; } + + void ExpectNodes(const string& nodes) { + int count = 0; + std::vector<string> actual_nodes; + for (Node* n : g_->nodes()) { + if (n->IsOp()) { + count++; + actual_nodes.push_back(n->name()); + } + } + std::sort(actual_nodes.begin(), actual_nodes.end()); + + LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " "); + + std::vector<string> expected_nodes = str_util::Split(nodes, ','); + std::sort(expected_nodes.begin(), expected_nodes.end()); + for (const string& s : expected_nodes) { + Node* n = FindNode(s); + EXPECT_TRUE(n != nullptr) << s; + if (n->def().op() == "_Send" || n->def().op() == "_Recv") { + EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s; + } + } + + EXPECT_TRUE(actual_nodes.size() == expected_nodes.size()) + << "\nActual: " << str_util::Join(actual_nodes, ",") + << "\nExpected: " << str_util::Join(expected_nodes, ","); + } + + bool HasEdge(const string& src, int 
src_out, const string& dst, int dst_in) { + for (const Edge* e : g_->edges()) { + if (e->src()->name() == src && e->src_output() == src_out && + e->dst()->name() == dst && e->dst_input() == dst_in) + return true; + } + return false; + } + bool HasControlEdge(const string& src, const string& dst) { + return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot); + } + + string Subgraph(const string& fed_str, const string& fetch_str, + const string& targets_str) { + Graph* subgraph = new Graph(OpRegistry::Global()); + CopyGraph(*g_, subgraph); + std::vector<string> fed = + str_util::Split(fed_str, ',', str_util::SkipEmpty()); + std::vector<string> fetch = + str_util::Split(fetch_str, ',', str_util::SkipEmpty()); + std::vector<string> targets = + str_util::Split(targets_str, ',', str_util::SkipEmpty()); + + Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, + targets, device_info_); + if (!s.ok()) { + delete subgraph; + return s.ToString(); + } + + // Replace the graph with the subgraph for the rest of the display program + g_.reset(subgraph); + return "OK"; + } + + Graph* graph() { return g_.get(); } + + private: + GraphDef gdef_; + std::unique_ptr<Graph> g_; + DeviceAttributes device_info_; +}; + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestRelu").Input("i: float").Output("o: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); + +TEST_F(SubgraphTest, Targets1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "", "t1")); + ExpectNodes("W1,input,t1"); +} + +TEST_F(SubgraphTest, 
Targets2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }" + "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a")); + ExpectNodes("W1,W2,input,t1,t2,t3_a"); +} + +TEST_F(SubgraphTest, FedOutputs1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("input:1", "", "t2")); + ExpectNodes("W1,W2,_recv_input_1,t1,t2"); +} + +TEST_F(SubgraphTest, FedRefNode) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }"); + EXPECT_EQ("OK", Subgraph("W1:0", "", "t1")); + ExpectNodes("_recv_W1_0,W2,t1"); + Node* n = FindNode("_recv_W1_0"); + EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); +} + +TEST_F(SubgraphTest, FedOutputs2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + // We feed input:1, but nothing connects to it, so the _recv(input:1) + // node also disappears. 
+ EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2")); + ExpectNodes("_recv_t1_0,_recv_W2_0,t2"); +} + +TEST_F(SubgraphTest, FetchOutputs1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2")); + ExpectNodes( + "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0"); +} + +TEST_F(SubgraphTest, FetchOutputs2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "t3_a", "t2")); + ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0"); +} + +TEST_F(SubgraphTest, ChainOfFools) { + ExpectOK( + "node { name: 'a' op: 'TestParams' }" + "node { name: 'b' op: 'TestRelu' input: 'a'}" + "node { name: 'c' op: 'TestRelu' input: 'b'}" + "node { name: 'd' op: 'TestRelu' input: 'c'}" + "node { name: 'e' op: 'TestRelu' input: 'd'}" + "node { name: 'f' op: 'TestRelu' input: 'e'}"); + EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", "")); + ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0"); + EXPECT_TRUE(HasEdge("a", 0, "b", 0)); + EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0)); + EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0)); + EXPECT_TRUE(HasEdge("d", 0, "e", 0)); + EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0)); +} + +static bool HasSubstr(const string& base, const string& substr) { + bool ok = StringPiece(base).contains(substr); + EXPECT_TRUE(ok) << base << ", 
expected substring " << substr; + return ok; +} + +TEST_F(SubgraphTest, Errors) { + ExpectOK( + "node { name: 'a' op: 'TestParams' }" + "node { name: 'b' op: 'TestRelu' input: 'a'}" + "node { name: 'c' op: 'TestRelu' input: 'b'}" + "node { name: 'd' op: 'TestRelu' input: 'c'}" + "node { name: 'e' op: 'TestRelu' input: 'd'}" + "node { name: 'f' op: 'TestRelu' input: 'e'}"); + // Duplicated feed and fetch + EXPECT_TRUE( + HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched")); + // Feed not found. + EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "", ""), "unable to find")); + // Fetch not found. + EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found")); + // Target not found. + EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found")); +} + +REGISTER_OP("In").Output("o: float"); +REGISTER_OP("Op").Input("i: float").Output("o: float"); + +static void BM_Subgraph(int iters, int num_nodes) { + DeviceAttributes device_info; + device_info.set_name("/job:a/replica:0/task:0/cpu:0"); + device_info.set_device_type(DeviceType(DEVICE_CPU).type()); + device_info.set_incarnation(0); + + testing::StopTiming(); + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* last_node = nullptr; + for (int i = 0; i < num_nodes; i++) { + string name = strings::StrCat("N", i); + if (i > 0) { + last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name)); + } else { + last_node = ops::SourceOp("In", b.opts().WithName(name)); + } + } + TF_CHECK_OK(b.ToGraph(&g)); + } + + std::vector<string> fed; + if (num_nodes > 1000) { + fed.push_back(strings::StrCat("N", num_nodes - 1000)); + } + std::vector<string> fetch; + std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)}; + testing::StartTiming(); + while (--iters > 0) { + Graph* subgraph = new Graph(OpRegistry::Global()); + CopyGraph(g, subgraph); + TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch, + targets, device_info)); + delete subgraph; + } +} +BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc new file mode 100644 index 0000000000..f789110ff3 --- /dev/null +++ b/tensorflow/core/graph/tensor_id.cc @@ -0,0 +1,41 @@ +#include "tensorflow/core/graph/tensor_id.h" + +#include <string> + +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +TensorId ParseTensorName(const string& name) { + return ParseTensorName(StringPiece(name.data(), name.size())); +} + +TensorId ParseTensorName(StringPiece name) { + // Parse either a name, or a name:digits. To do so, we go backwards + // from the end of the string, skipping over a run of digits. If + // we hit a ':' character, then we know we are in the 'name:digits' + // regime. Otherwise, the output index is implicitly 0, and the whole + // name string forms the first part of the tensor name. 
+ // + // Equivalent to matching with this regexp: ([^:]+):(\\d+) + const char* base = name.data(); + const char* p = base + name.size() - 1; + int index = 0; + int mul = 1; + while (p > base && (*p >= '0' && *p <= '9')) { + index += ((*p - '0') * mul); + mul *= 10; + p--; + } + TensorId id; + if (p > base && *p == ':' && mul > 1) { + id.first = StringPiece(base, p - base); + id.second = index; + } else { + id.first = name; + id.second = 0; + } + return id; +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h new file mode 100644 index 0000000000..f1f3846875 --- /dev/null +++ b/tensorflow/core/graph/tensor_id.h @@ -0,0 +1,28 @@ +#ifndef TENSORFLOW_GRAPH_TENSOR_ID_H_ +#define TENSORFLOW_GRAPH_TENSOR_ID_H_ + +#include <string> + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +// Identifier for a tensor within a step. +// first == operation_name, second == output_index +// Note: does not own backing storage for name. +struct TensorId : public std::pair<StringPiece, int> { + typedef std::pair<StringPiece, int> Base; + + // Inherit the set of constructors. 
+ using Base::pair; + + string ToString() const { return strings::StrCat(first, ":", second); } +}; + +TensorId ParseTensorName(const string& name); +TensorId ParseTensorName(StringPiece name); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_TENSOR_ID_H_ diff --git a/tensorflow/core/graph/tensor_id_test.cc b/tensorflow/core/graph/tensor_id_test.cc new file mode 100644 index 0000000000..b945774cc3 --- /dev/null +++ b/tensorflow/core/graph/tensor_id_test.cc @@ -0,0 +1,77 @@ +#include "tensorflow/core/graph/tensor_id.h" +#include <gtest/gtest.h> +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +static string ParseHelper(const string& n) { + TensorId id = ParseTensorName(n); + return strings::StrCat(id.first, ":", id.second); +} + +TEST(TensorIdTest, ParseTensorName) { + EXPECT_EQ(ParseHelper("W1"), "W1:0"); + EXPECT_EQ(ParseHelper("weights:0"), "weights:0"); + EXPECT_EQ(ParseHelper("W1:1"), "W1:1"); + EXPECT_EQ(ParseHelper("W1:17"), "W1:17"); + EXPECT_EQ(ParseHelper("xyz1_17"), "xyz1_17:0"); +} + +static uint32 Skewed(random::SimplePhilox* rnd, int max_log) { + const uint32 space = 1 << (rnd->Rand32() % (max_log + 1)); + return rnd->Rand32() % space; +} + +static void BM_ParseTensorName(int iters, int arg) { + testing::StopTiming(); + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + std::vector<string> names; + for (int i = 0; i < 100; i++) { + string name; + switch (arg) { + case 0: { // Generate random names + size_t len = Skewed(&rnd, 4); + while (name.size() < len) { + name += rnd.OneIn(4) ? 
'0' : 'a'; + } + if (rnd.OneIn(3)) { + strings::StrAppend(&name, ":", rnd.Uniform(12)); + } + break; + } + case 1: + name = "W1"; + break; + case 2: + name = "t0003"; + break; + case 3: + name = "weights"; + break; + case 4: + name = "weights:17"; + break; + default: + LOG(FATAL) << "Unexpected arg"; + break; + } + names.push_back(name); + } + testing::StartTiming(); + TensorId id; + int index = 0; + int sum = 0; + while (--iters > 0) { + id = ParseTensorName(names[index++ % names.size()]); + sum += id.second; + } + VLOG(2) << sum; // Prevent compiler from eliminating loop body +} +BENCHMARK(BM_ParseTensorName)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc new file mode 100644 index 0000000000..e49d5e819a --- /dev/null +++ b/tensorflow/core/graph/testlib.cc @@ -0,0 +1,299 @@ +#include "tensorflow/core/graph/testlib.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace test { +namespace graph { + +Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send") + .Input(input, 0) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast<int64>(sender_incarnation)) + .Attr("recv_device", receiver) + .Finalize(g, &ret)); + return ret; +} + +Node* Recv(Graph* g, const string& tensor, const string& type, + const string& sender, const uint64 sender_incarnation, + const string& receiver) { + Node* 
ret; + DataType dtype; + CHECK(DataTypeFromString(type, &dtype)); + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv") + .Attr("tensor_type", dtype) + .Attr("tensor_name", tensor) + .Attr("send_device", sender) + .Attr("send_device_incarnation", + static_cast<int64>(sender_incarnation)) + .Attr("recv_device", receiver) + .Finalize(g, &ret)); + return ret; +} + +Node* Constant(Graph* g, const Tensor& tensor) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const") + .Attr("dtype", tensor.dtype()) + .Attr("value", tensor) + .Finalize(g, &ret)); + return ret; +} + +Node* Constant(Graph* g, const Tensor& tensor, const string& name) { + Node* ret; + TF_CHECK_OK(NodeBuilder(name, "Const") + .Attr("dtype", tensor.dtype()) + .Attr("value", tensor) + .Finalize(g, &ret)); + return ret; +} + +Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable") + .Attr("dtype", dtype) + .Attr("shape", shape) + .Finalize(g, &ret)); + return ret; +} + +Node* Assign(Graph* g, Node* var, Node* val) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign") + .Input(var) + .Input(val) + .Attr("use_locking", true) + .Finalize(g, &ret)); + return ret; +} + +Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, + bool keep_dims) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce) + .Input(data) + .Input(axes) + .Attr("keep_dims", keep_dims) + .Finalize(g, &ret)); + return ret; +} + +Node* QuantizeToUINT8(Graph* g, Node* data) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize") + .Input(data) + .Attr("T", DT_QUINT8) + .Attr("max_range", 1.0f) + .Attr("min_range", -1.0f) + .Finalize(g, &ret)); + return ret; +} + +Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, + bool transpose_b) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul") + .Input(in0) + .Input(in1) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", 
transpose_b) + .Finalize(g, &ret)); + return ret; +} + +Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, + DataType dtype) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), op) + .Input(input) + .Attr("dtype", dtype) + .Attr("seed", 0) + .Finalize(g, &ret)); + return ret; +} + +Node* RandomUniform(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomUniform", g, input, dtype); +} + +Node* RandomGaussian(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomStandardNormal", g, input, dtype); +} + +Node* RandomParameters(Graph* g, Node* input, DataType dtype) { + return RandomNumberGenerator("RandomParameters", g, input, dtype); +} + +Node* Unary(Graph* g, const string& func, Node* input, int index) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret)); + return ret; +} + +Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), func) + .Input(in0) + .Input(in1) + .Finalize(g, &ret)); + return ret; +} + +Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) { + Node* ret; + auto b = NodeBuilder(g->NewName("n"), func); + for (Node* n : ins) b = b.Input(n); + TF_CHECK_OK(b.Finalize(g, &ret)); + return ret; +} + +Node* Identity(Graph* g, Node* input, int index) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity") + .Input(input, index) + .Finalize(g, &ret)); + return ret; +} + +Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); } + +Node* Error(Graph* g, Node* input, const string& errmsg) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") + .Input(input) + .Attr("message", errmsg) + .Finalize(g, &ret)); + return ret; +} + +Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) { + DCHECK(out_type != invalid_type); + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), 
"InvalidRefType") + .Attr("TIn", out_type) + .Attr("TOut", invalid_type) + .Finalize(g, &ret)); + return ret; +} + +Node* Delay(Graph* g, Node* input, Microseconds delay_micros) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay") + .Input(input) + .Attr("micros", delay_micros.value()) + .Finalize(g, &ret)); + return ret; +} + +Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp") + .ControlInputs(control_inputs) + .Finalize(g, &ret)); + return ret; +} + +Node* Switch(Graph* g, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch") + .Input(in0) + .Input(in1) + .Finalize(g, &ret)); + return ret; +} + +Node* Enter(Graph* g, Node* input, const string& frame_name) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter") + .Input(input) + .Attr("frame_name", frame_name) + .Finalize(g, &ret)); + return ret; +} + +Node* Exit(Graph* g, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* Merge(Graph* g, Node* in0, Node* in1) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge") + .Input({in0, in1}) + .Finalize(g, &ret)); + return ret; +} + +Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) { + std::vector<NodeBuilder::NodeOut> inputs; + inputs.reserve(remaining_in.size() + 1); + inputs.emplace_back(in0); + for (const string& in_name : remaining_in) { + inputs.emplace_back(in_name, 0, inputs[0].dt); + } + + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret)); + return ret; +} + +Node* Next(Graph* g, const string& name, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* LoopCond(Graph* g, Node* input) { + Node* ret; + TF_CHECK_OK( + NodeBuilder(g->NewName("n"), 
"LoopCond").Input(input).Finalize(g, &ret)); + return ret; +} + +Node* Less(Graph* g, Node* in0, Node* in1) { + return Binary(g, "Less", in0, in1); +} + +Node* Select(Graph* g, Node* c, Node* inx, Node* iny) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select") + .Input(c) + .Input(inx) + .Input(iny) + .Finalize(g, &ret)); + return ret; +} + +Node* Cast(Graph* g, Node* in, DataType dst) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast") + .Input(in) + .Attr("DstT", dst) + .Finalize(g, &ret)); + return ret; +} + +void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } + +} // end namespace graph +} // end namespace test +} // end namespace tensorflow diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h new file mode 100644 index 0000000000..11905bbf6a --- /dev/null +++ b/tensorflow/core/graph/testlib.h @@ -0,0 +1,141 @@ +// DEPRECATED: Use GraphDefBuilder instead. + +#ifndef TENSORFLOW_GRAPH_TESTLIB_H_ +#define TENSORFLOW_GRAPH_TESTLIB_H_ + +#include <string> +#include <vector> + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { +namespace test { +namespace graph { + +// Converts "g" into its corresponding GraphDef "def". +// DEPRECATED: call g->ToGraphDef(def) instead. +void ToGraphDef(Graph* g, GraphDef* def); + +// A few helpers to construct a graph. + +// Adds a node in "g" producing a constant "tensor". +Node* Constant(Graph* g, const Tensor& tensor); +Node* Constant(Graph* g, const Tensor& tensor, const string& name); + +// Adds a variable in "g" of the given "shape" and "dtype". +Node* Var(Graph* g, const DataType dtype, const TensorShape& shape); + +// Adds an assign node in "g" which assigns "val" into "var". 
+Node* Assign(Graph* g, Node* var, Node* val); + +// Adds a send node "g" sending "input" as a named "tensor" from +// "sender" to "receiver". +Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver); + +// Adds a recv node in "g" receiving a named "tensor" from "sender" +// to "receiver". +Node* Recv(Graph* g, const string& tensor, const string& type, + const string& sender, const uint64 sender_incarnation, + const string& receiver); + +// Adds a reduction "node" in "g" doing sum(data, axes). "reduce" is +// a reduction, e.g., Sum, Max, Min, Mean, etc. +Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, + bool keep_dims = false); + +// Adds a Matmul node in g doing in0.contract(in1). +Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, + bool transpose_b); + +// Adds a Quantize node into g that quantize floats into QUINT8. The range of +// the input float tensor is assumed to be [-1, 1]. +Node* QuantizeToUINT8(Graph* g, Node* data); + +// Adds a unary function "func" "node" in "g" taking "input". +Node* Unary(Graph* g, const string& func, Node* input, int index = 0); + +// Adds an identity node in "g" taking "input" and producing an +// identity copy. +Node* Identity(Graph* g, Node* input, int index = 0); + +// Adds a binary function "func" node in "g" taking "in0" and "in1". +// Requires that "func" name an attr-style Op. +Node* Binary(Graph* g, const string& func, Node* in0, Node* in1); + +// Adds a function "func" node in "g" taking inputs "ins". +// Requires that "func" name an attr-style Op. +Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins); + +// Adds a binary add node in "g" doing in0 + in1. +Node* Add(Graph* g, Node* in0, Node* in1); + +// Generates random unit uniform distribution of the input shape. +Node* RandomUniform(Graph* g, Node* input, DataType dtype); + +// Generates random unit normal distribution of the input shape. 
+Node* RandomGaussian(Graph* g, Node* input, DataType dtype); + +// Generates random parameters from the truncated standard normal distribution +// of the nput shape +Node* RandomParameters(Graph* g, Node* input, DataType dtype); + +// Adds an error node in "g". The node's computation always +// generates an error with the given error message "errmsg". +Node* Error(Graph* g, Node* input, const string& errmsg); + +// Adds a node that generates a invalid ref output. +Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type); + +// Adds a node in "g". Its Compute() sleeps a while and outputs the +// input (i.e., same as identity). +Node* Delay(Graph* g, Node* input, Microseconds delay_micros); + +// Adds a no-op "node" in "g", with control inputs from all nodes in +// control_inputs vector. +Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs); + +// Adds a Switch node in "g". If "in1" is true, it forwards "in0" to +// output 1. Otherwise, it forwards "in0" to output 0. +Node* Switch(Graph* g, Node* in0, Node* in1); + +// Adds an Enter node in "g", which enters a new frame. +Node* Enter(Graph* g, Node* input, const string& frame_name); + +// Adds an Exit node in "g", which exits a frame. +Node* Exit(Graph* g, Node* input); + +// Adds a Merge node in "g" with two inputs "in0" and "in1". +Node* Merge(Graph* g, Node* in0, Node* in1); + +// Adds a Merge node in "g". The first input is "in0", the remaining +// inputs are only given by their names in remaining_in. +Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in); + +// Adds a NextIteration node in "g", which makes its input available +// to the next iteration. +Node* Next(Graph* g, const string& name, Node* input); + +// Adds a LoopCond node in "g", representing the "pivot" termination +// condition of a loop. +Node* LoopCond(Graph* g, Node* input); + +// Adds a less node in "g", which returns true iff "in0" < "in1". 
+Node* Less(Graph* g, Node* in0, Node* in1); + +// Adds a select node in "g", which outputs either "inx" or "iny" +// depending on the boolean value of "c". +Node* Select(Graph* g, Node* c, Node* inx, Node* iny); + +// Casts "in" into data type "dst". +Node* Cast(Graph* g, Node* in, DataType dst); + +} // end namespace graph +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPH_TESTLIB_H_ diff --git a/tensorflow/core/graph/types.h b/tensorflow/core/graph/types.h new file mode 100644 index 0000000000..41400611a9 --- /dev/null +++ b/tensorflow/core/graph/types.h @@ -0,0 +1,17 @@ +#ifndef TENSORFLOW_GRAPH_TYPES_H_ +#define TENSORFLOW_GRAPH_TYPES_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/lib/gtl/int_type.h" + +namespace tensorflow { + +// We model running time in microseconds. +TF_LIB_GTL_DEFINE_INT_TYPE(Microseconds, int64); + +// We model size in bytes. +TF_LIB_GTL_DEFINE_INT_TYPE(Bytes, int64); + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_TYPES_H_ |