From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/core/graph/graph.cc | 319 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 319 insertions(+) create mode 100644 tensorflow/core/graph/graph.cc (limited to 'tensorflow/core/graph/graph.cc') diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc new file mode 100644 index 0000000000..0c268a51a9 --- /dev/null +++ b/tensorflow/core/graph/graph.cc @@ -0,0 +1,319 @@ +#include "tensorflow/core/graph/graph.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Node + +string Node::DebugString() const { + if (this == nullptr) { + return "{nullptr}"; + } + string ret = strings::StrCat("{name:'", name(), "' id:", id_); + if (IsSource()) { + strings::StrAppend(&ret, " source}"); + } else if (IsSink()) { + strings::StrAppend(&ret, " sink}"); + } else { + strings::StrAppend(&ret, " op device:"); + strings::StrAppend(&ret, "{", assigned_device_name_, "}"); + strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}"); + } + return ret; +} + +Node::Node() + : id_(-1), cost_id_(-1), props_(nullptr), assigned_device_name_() {} + +Node::~Node() { + if (props_) { + props_->Unref(); + } +} + +void Node::Initialize(int id, int cost_id, Properties* props) { + DCHECK_EQ(id_, -1); + DCHECK(in_edges_.empty()); + DCHECK(out_edges_.empty()); + id_ = id; + cost_id_ = cost_id; + + // Unref the old, assign the new properties. + if (props_) { + props_->Unref(); + } + props_ = props; +} + +void Node::Clear() { + in_edges_.clear(); + out_edges_.clear(); + id_ = -1; + cost_id_ = -1; + + if (props_) { + props_->Unref(); + props_ = nullptr; + } + + assigned_device_name_.clear(); +} + +gtl::iterator_range Node::out_nodes() const { + return gtl::make_range(NeighborIter(out_edges_.begin(), false), + NeighborIter(out_edges_.end(), false)); +} + +gtl::iterator_range Node::in_nodes() const { + return gtl::make_range(NeighborIter(in_edges_.begin(), true), + NeighborIter(in_edges_.end(), true)); +} + +// Node::Properties + +Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, + const DataTypeSlice outputs) + : op_def_(op_def), + node_def_(node_def), + input_types_(inputs.begin(), inputs.end()), + output_types_(outputs.begin(), outputs.end()) {} + +Node::Properties::~Properties() {} + +// Graph + +Graph::Graph(const OpRegistryInterface* ops) + : ops_(ops), arena_(8 << 10 /* 8kB */) { + // Source and sink have no endpoints, just control edges. + NodeDef def; + def.set_name("_SOURCE"); + def.set_op("NoOp"); + Status status; + Node* source = AddNode(def, &status); + TF_CHECK_OK(status); + CHECK_EQ(source->id(), kSourceId); + + def.set_name("_SINK"); + Node* sink = AddNode(def, &status); + TF_CHECK_OK(status); + CHECK_EQ(sink->id(), kSinkId); + + AddControlEdge(source, sink); +} + +Graph::~Graph() { + // Manually call the destructors for all the Nodes we constructed using + // placement new. + for (Node* node : nodes_) { + if (node != nullptr) { + node->~Node(); + } + } + for (Node* node : free_nodes_) { + node->~Node(); + } + // Edges have no destructor, and we arena-allocated them, so no need to + // destroy them. +} + +Node* Graph::AddNode(const NodeDef& node_def, Status* status) { + const OpDef* op_def = ops_->LookUp(node_def.op(), status); + if (op_def == nullptr) return nullptr; + + // TODO(vrv,josh11b): Find a location higher in the stack to add these defaults + // to the NodeDef. + NodeDef node_def_with_defaults(node_def); + AddDefaultsToNodeDef(*op_def, &node_def_with_defaults); + + DataTypeVector inputs; + DataTypeVector outputs; + status->Update( + InOutTypesForNode(node_def_with_defaults, *op_def, &inputs, &outputs)); + if (!status->ok()) { + *status = AttachDef(*status, node_def_with_defaults); + return nullptr; + } + + Node* node = AllocateNode( + new Node::Properties(op_def, node_def_with_defaults, inputs, outputs), + nullptr); + return node; +} + +Node* Graph::CopyNode(Node* node) { + DCHECK(!node->IsSource()); + DCHECK(!node->IsSink()); + Node::Properties* props = node->properties(); + props->Ref(); + Node* copy = AllocateNode(props, node); + copy->set_assigned_device_name(node->assigned_device_name()); + return copy; +} + +void Graph::RemoveNode(Node* node) { + DCHECK(IsValidNode(node)) << node->DebugString(); + DCHECK(!node->IsSource()); + DCHECK(!node->IsSink()); + + // Remove any edges involving this node. + while (!node->in_edges_.empty()) { + RemoveEdge(*node->in_edges_.begin()); + } + while (!node->out_edges_.empty()) { + RemoveEdge(*node->out_edges_.begin()); + } + ReleaseNode(node); +} + +const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) { + DCHECK(IsValidNode(source)) << source->DebugString(); + DCHECK(IsValidNode(dest)) << dest->DebugString(); + + // source/sink must only be linked via control slots, and + // control slots must only be linked to control slots. + if (source == source_node() || dest == sink_node() || x == kControlSlot || + y == kControlSlot) { + DCHECK_EQ(x, kControlSlot) << source->DebugString(); + DCHECK_EQ(y, kControlSlot) << dest->DebugString(); + } + + Edge* e = nullptr; + if (free_edges_.empty()) { + e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new + } else { + e = free_edges_.back(); + free_edges_.pop_back(); + } + e->id_ = edges_.size(); + e->src_ = source; + e->dst_ = dest; + e->src_output_ = x; + e->dst_input_ = y; + CHECK(source->out_edges_.insert(e).second); + CHECK(dest->in_edges_.insert(e).second); + edges_.push_back(e); + edge_set_.insert(e); + return e; +} + +void Graph::RemoveEdge(const Edge* e) { + DCHECK(IsValidNode(e->src_)) << e->src_->DebugString(); + DCHECK(IsValidNode(e->dst_)) << e->dst_->DebugString(); + CHECK_EQ(e->src_->out_edges_.erase(e), 1); + CHECK_EQ(e->dst_->in_edges_.erase(e), 1); + CHECK_EQ(e, edges_[e->id_]); + + CHECK_EQ(edge_set_.erase(e), 1); + edges_[e->id_] = nullptr; + + Edge* del = const_cast(e); + del->src_ = nullptr; + del->dst_ = nullptr; + del->id_ = -1; + del->src_output_ = kControlSlot - 1; + del->dst_input_ = kControlSlot - 1; + free_edges_.push_back(del); +} + +namespace { + +void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { + if (src_slot == Graph::kControlSlot) { + dst->add_input(strings::StrCat("^", src_name)); + } else if (src_slot == 0) { + dst->add_input(src_name.data(), src_name.size()); + } else { + dst->add_input(strings::StrCat(src_name, ":", src_slot)); + } +} + +} // namespace + +void Graph::ToGraphDef(GraphDef* graph_def) const { + graph_def->Clear(); + std::vector + inputs; // Construct this outside the loop for speed. + for (const Node* node : nodes()) { + if (!node->IsOp()) continue; + NodeDef* node_def = graph_def->add_node(); + *node_def = node->def(); + + // Use the node's assigned device, if any, instead of the device requested + // in the NodeDef. + if (!node->assigned_device_name().empty()) { + node_def->set_device(node->assigned_device_name()); + } + + // Get the inputs for this Node. We make sure control inputs are + // after data inputs, as required by GraphDef. + inputs.clear(); + inputs.resize(node->num_inputs(), nullptr); + for (const Edge* edge : node->in_edges()) { + if (edge->IsControlEdge()) { + inputs.push_back(edge); + } else { + DCHECK(inputs[edge->dst_input()] == nullptr); + inputs[edge->dst_input()] = edge; + } + } + node_def->clear_input(); + for (size_t i = 0; i < inputs.size(); ++i) { + const Edge* edge = inputs[i]; + if (edge == nullptr) { + node_def->add_input(node->def().input(i)); + } else { + const Node* src = edge->src(); + if (!src->IsOp()) continue; + AddInput(node_def, src->name(), edge->src_output()); + } + } + } +} + +string Graph::NewName(StringPiece prefix) { + return strings::StrCat(prefix, "/_", name_counter_++); +} + +gtl::iterator_range Graph::nodes() const { + // Note that NodeId 0 is always valid since we don't let the source + // node be removed from the graph. + return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids())); +} + +bool Graph::IsValidNode(Node* node) const { + if (node == nullptr) return false; + const int id = node->id(); + if (id < 0 || static_cast(id) >= nodes_.size()) return false; + return nodes_[id] == node; +} + +Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) { + Node* node = nullptr; + if (free_nodes_.empty()) { + node = new (arena_.Alloc(sizeof(Node))) Node; // placement new + } else { + node = free_nodes_.back(); + free_nodes_.pop_back(); + } + const int id = nodes_.size(); + int cost_id = cost_node ? cost_node->cost_id() : id; + node->Initialize(id, cost_id, props); + nodes_.push_back(node); + return node; +} + +void Graph::ReleaseNode(Node* node) { + DCHECK(IsValidNode(node)) << node->DebugString(); + nodes_[node->id()] = nullptr; + free_nodes_.push_back(node); + node->Clear(); +} + +} // namespace tensorflow -- cgit v1.2.3