aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/graph/graph.cc
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc319
1 files changed, 319 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
new file mode 100644
index 0000000000..0c268a51a9
--- /dev/null
+++ b/tensorflow/core/graph/graph.cc
@@ -0,0 +1,319 @@
+#include "tensorflow/core/graph/graph.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Node
+
+string Node::DebugString() const {
+ if (this == nullptr) {
+ return "{nullptr}";
+ }
+ string ret = strings::StrCat("{name:'", name(), "' id:", id_);
+ if (IsSource()) {
+ strings::StrAppend(&ret, " source}");
+ } else if (IsSink()) {
+ strings::StrAppend(&ret, " sink}");
+ } else {
+ strings::StrAppend(&ret, " op device:");
+ strings::StrAppend(&ret, "{", assigned_device_name_, "}");
+ strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}");
+ }
+ return ret;
+}
+
+Node::Node()
+ : id_(-1), cost_id_(-1), props_(nullptr), assigned_device_name_() {}
+
+Node::~Node() {
+ if (props_) {
+ props_->Unref();
+ }
+}
+
+void Node::Initialize(int id, int cost_id, Properties* props) {
+ DCHECK_EQ(id_, -1);
+ DCHECK(in_edges_.empty());
+ DCHECK(out_edges_.empty());
+ id_ = id;
+ cost_id_ = cost_id;
+
+ // Unref the old, assign the new properties.
+ if (props_) {
+ props_->Unref();
+ }
+ props_ = props;
+}
+
+void Node::Clear() {
+ in_edges_.clear();
+ out_edges_.clear();
+ id_ = -1;
+ cost_id_ = -1;
+
+ if (props_) {
+ props_->Unref();
+ props_ = nullptr;
+ }
+
+ assigned_device_name_.clear();
+}
+
+gtl::iterator_range<NeighborIter> Node::out_nodes() const {
+ return gtl::make_range(NeighborIter(out_edges_.begin(), false),
+ NeighborIter(out_edges_.end(), false));
+}
+
+gtl::iterator_range<NeighborIter> Node::in_nodes() const {
+ return gtl::make_range(NeighborIter(in_edges_.begin(), true),
+ NeighborIter(in_edges_.end(), true));
+}
+
+// Node::Properties
+
+Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def,
+ const DataTypeSlice inputs,
+ const DataTypeSlice outputs)
+ : op_def_(op_def),
+ node_def_(node_def),
+ input_types_(inputs.begin(), inputs.end()),
+ output_types_(outputs.begin(), outputs.end()) {}
+
+Node::Properties::~Properties() {}
+
+// Graph
+
+Graph::Graph(const OpRegistryInterface* ops)
+ : ops_(ops), arena_(8 << 10 /* 8kB */) {
+ // Source and sink have no endpoints, just control edges.
+ NodeDef def;
+ def.set_name("_SOURCE");
+ def.set_op("NoOp");
+ Status status;
+ Node* source = AddNode(def, &status);
+ TF_CHECK_OK(status);
+ CHECK_EQ(source->id(), kSourceId);
+
+ def.set_name("_SINK");
+ Node* sink = AddNode(def, &status);
+ TF_CHECK_OK(status);
+ CHECK_EQ(sink->id(), kSinkId);
+
+ AddControlEdge(source, sink);
+}
+
+Graph::~Graph() {
+ // Manually call the destructors for all the Nodes we constructed using
+ // placement new.
+ for (Node* node : nodes_) {
+ if (node != nullptr) {
+ node->~Node();
+ }
+ }
+ for (Node* node : free_nodes_) {
+ node->~Node();
+ }
+ // Edges have no destructor, and we arena-allocated them, so no need to
+ // destroy them.
+}
+
+Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
+ const OpDef* op_def = ops_->LookUp(node_def.op(), status);
+ if (op_def == nullptr) return nullptr;
+
+ // TODO(vrv,josh11b): Find a location higher in the stack to add these defaults
+ // to the NodeDef.
+ NodeDef node_def_with_defaults(node_def);
+ AddDefaultsToNodeDef(*op_def, &node_def_with_defaults);
+
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ status->Update(
+ InOutTypesForNode(node_def_with_defaults, *op_def, &inputs, &outputs));
+ if (!status->ok()) {
+ *status = AttachDef(*status, node_def_with_defaults);
+ return nullptr;
+ }
+
+ Node* node = AllocateNode(
+ new Node::Properties(op_def, node_def_with_defaults, inputs, outputs),
+ nullptr);
+ return node;
+}
+
+Node* Graph::CopyNode(Node* node) {
+ DCHECK(!node->IsSource());
+ DCHECK(!node->IsSink());
+ Node::Properties* props = node->properties();
+ props->Ref();
+ Node* copy = AllocateNode(props, node);
+ copy->set_assigned_device_name(node->assigned_device_name());
+ return copy;
+}
+
+void Graph::RemoveNode(Node* node) {
+ DCHECK(IsValidNode(node)) << node->DebugString();
+ DCHECK(!node->IsSource());
+ DCHECK(!node->IsSink());
+
+ // Remove any edges involving this node.
+ while (!node->in_edges_.empty()) {
+ RemoveEdge(*node->in_edges_.begin());
+ }
+ while (!node->out_edges_.empty()) {
+ RemoveEdge(*node->out_edges_.begin());
+ }
+ ReleaseNode(node);
+}
+
+const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
+ DCHECK(IsValidNode(source)) << source->DebugString();
+ DCHECK(IsValidNode(dest)) << dest->DebugString();
+
+ // source/sink must only be linked via control slots, and
+ // control slots must only be linked to control slots.
+ if (source == source_node() || dest == sink_node() || x == kControlSlot ||
+ y == kControlSlot) {
+ DCHECK_EQ(x, kControlSlot) << source->DebugString();
+ DCHECK_EQ(y, kControlSlot) << dest->DebugString();
+ }
+
+ Edge* e = nullptr;
+ if (free_edges_.empty()) {
+ e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new
+ } else {
+ e = free_edges_.back();
+ free_edges_.pop_back();
+ }
+ e->id_ = edges_.size();
+ e->src_ = source;
+ e->dst_ = dest;
+ e->src_output_ = x;
+ e->dst_input_ = y;
+ CHECK(source->out_edges_.insert(e).second);
+ CHECK(dest->in_edges_.insert(e).second);
+ edges_.push_back(e);
+ edge_set_.insert(e);
+ return e;
+}
+
+void Graph::RemoveEdge(const Edge* e) {
+ DCHECK(IsValidNode(e->src_)) << e->src_->DebugString();
+ DCHECK(IsValidNode(e->dst_)) << e->dst_->DebugString();
+ CHECK_EQ(e->src_->out_edges_.erase(e), 1);
+ CHECK_EQ(e->dst_->in_edges_.erase(e), 1);
+ CHECK_EQ(e, edges_[e->id_]);
+
+ CHECK_EQ(edge_set_.erase(e), 1);
+ edges_[e->id_] = nullptr;
+
+ Edge* del = const_cast<Edge*>(e);
+ del->src_ = nullptr;
+ del->dst_ = nullptr;
+ del->id_ = -1;
+ del->src_output_ = kControlSlot - 1;
+ del->dst_input_ = kControlSlot - 1;
+ free_edges_.push_back(del);
+}
+
+namespace {
+
+void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
+ if (src_slot == Graph::kControlSlot) {
+ dst->add_input(strings::StrCat("^", src_name));
+ } else if (src_slot == 0) {
+ dst->add_input(src_name.data(), src_name.size());
+ } else {
+ dst->add_input(strings::StrCat(src_name, ":", src_slot));
+ }
+}
+
+} // namespace
+
+void Graph::ToGraphDef(GraphDef* graph_def) const {
+ graph_def->Clear();
+ std::vector<const Edge*>
+ inputs; // Construct this outside the loop for speed.
+ for (const Node* node : nodes()) {
+ if (!node->IsOp()) continue;
+ NodeDef* node_def = graph_def->add_node();
+ *node_def = node->def();
+
+ // Use the node's assigned device, if any, instead of the device requested
+ // in the NodeDef.
+ if (!node->assigned_device_name().empty()) {
+ node_def->set_device(node->assigned_device_name());
+ }
+
+ // Get the inputs for this Node. We make sure control inputs are
+ // after data inputs, as required by GraphDef.
+ inputs.clear();
+ inputs.resize(node->num_inputs(), nullptr);
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ inputs.push_back(edge);
+ } else {
+ DCHECK(inputs[edge->dst_input()] == nullptr);
+ inputs[edge->dst_input()] = edge;
+ }
+ }
+ node_def->clear_input();
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ const Edge* edge = inputs[i];
+ if (edge == nullptr) {
+ node_def->add_input(node->def().input(i));
+ } else {
+ const Node* src = edge->src();
+ if (!src->IsOp()) continue;
+ AddInput(node_def, src->name(), edge->src_output());
+ }
+ }
+ }
+}
+
+string Graph::NewName(StringPiece prefix) {
+ return strings::StrCat(prefix, "/_", name_counter_++);
+}
+
+gtl::iterator_range<NodeIter> Graph::nodes() const {
+ // Note that NodeId 0 is always valid since we don't let the source
+ // node be removed from the graph.
+ return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
+}
+
+bool Graph::IsValidNode(Node* node) const {
+ if (node == nullptr) return false;
+ const int id = node->id();
+ if (id < 0 || static_cast<size_t>(id) >= nodes_.size()) return false;
+ return nodes_[id] == node;
+}
+
+Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) {
+ Node* node = nullptr;
+ if (free_nodes_.empty()) {
+ node = new (arena_.Alloc(sizeof(Node))) Node; // placement new
+ } else {
+ node = free_nodes_.back();
+ free_nodes_.pop_back();
+ }
+ const int id = nodes_.size();
+ int cost_id = cost_node ? cost_node->cost_id() : id;
+ node->Initialize(id, cost_id, props);
+ nodes_.push_back(node);
+ return node;
+}
+
+void Graph::ReleaseNode(Node* node) {
+ DCHECK(IsValidNode(node)) << node->DebugString();
+ nodes_[node->id()] = nullptr;
+ free_nodes_.push_back(node);
+ node->Clear();
+}
+
+} // namespace tensorflow