diff options
Diffstat (limited to 'tensorflow/core/graph/graph.h')
-rw-r--r-- | tensorflow/core/graph/graph.h | 440 |
1 files changed, 440 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h new file mode 100644 index 0000000000..030e471bf4 --- /dev/null +++ b/tensorflow/core/graph/graph.h @@ -0,0 +1,440 @@ +// A Graph describes a set of computations that are to be +// performed, as well as the dependencies between those +// compuations. The basic model is a DAG (directed acyclic graph) with +// * internal nodes representing computational operations to be performed; +// * edges represent dependencies, indicating the target may only be +// executed once the source has completed; and +// * predefined "source" (start) and "sink" (finish) nodes -- the source +// should be the only node that doesn't depend on anything, and the sink +// should be the only node that nothing depends on. +// +// Note: Node ids are intended to be relatively dense in the +// 0..max_id range, but there may be gaps since ids won't be reused. +// +// Note: Some dependencies between operations are due to one operation +// consuming the output of another. In fact operations can produce +// multiple outputs and consume multiple inputs, and some +// optimizations will care about which specific outputs are connected +// to which specific inputs. We therefore represent data dependency +// between output O of layer A and input I of layer B using +// "input index" and "output index" labels per edge. + +#ifndef TENSORFLOW_GRAPH_GRAPH_H_ +#define TENSORFLOW_GRAPH_GRAPH_H_ + +#include <functional> +#include <string> +#include <vector> +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/edgeset.h" +#include "tensorflow/core/lib/core/arena.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +class Edge; +class EdgeSetTest; +class Graph; +class Node; + +class NeighborIter; // Declared below +class NodeIter; // Declared below + +class Node { + public: + string DebugString() const; + int id() const { return id_; } + int cost_id() const { return cost_id_; } + const string& name() const { return props_->node_def_.name(); } + const string& type_string() const { return props_->node_def_.op(); } + const NodeDef& def() const { return props_->node_def_; } + const OpDef& op_def() const { return *props_->op_def_; } + + // input and output types + int num_inputs() const { return props_->input_types_.size(); } + DataType input_type(int i) const { return props_->input_types_[i]; } + const DataTypeVector& input_types() const { return props_->input_types_; } + + int num_outputs() const { return props_->output_types_.size(); } + DataType output_type(int o) const { return props_->output_types_[o]; } + const DataTypeVector& output_types() const { return props_->output_types_; } + + // This gives the device the runtime has assigned this node to. If + // you want the device the user requested, use def().device() instead. + // TODO(josh11b): Validate that the assigned_device, if not empty: + // fully specifies a device, and satisfies def().device(). + // TODO(josh11b): Move device_name outside of Node into a NodeId->DeviceName + // map. + string assigned_device_name() const { return assigned_device_name_; } + void set_assigned_device_name(const string& device_name) { + assigned_device_name_ = device_name; + } + + // Get the neighboring nodes via edges either in or out of this node. + gtl::iterator_range<NeighborIter> in_nodes() const; + gtl::iterator_range<NeighborIter> out_nodes() const; + const EdgeSet& in_edges() const { return in_edges_; } + const EdgeSet& out_edges() const { return out_edges_; } + + // Node type helpers. + bool IsSource() const { return id() == 0; } + bool IsSink() const { return id() == 1; } + // Anything other than the special Source & Sink nodes. + bool IsOp() const { return id() > 1; } + + private: + friend class Graph; + Node(); + ~Node(); + + class Properties : public core::RefCounted { + public: + Properties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs); + + const OpDef* op_def_; // not owned + const NodeDef node_def_; + const DataTypeVector input_types_; + const DataTypeVector output_types_; + + private: + // Destructor invoked when last reference goes away via Unref() + virtual ~Properties(); + TF_DISALLOW_COPY_AND_ASSIGN(Properties); + }; + + Properties* properties() const { return props_; } + + // Initialize() adopts a reference to props, and so is suitable if props was + // just allocated or you call props->Ref() to increment the reference + // count for a props being held by another Node. + void Initialize(int id, int cost_id, Properties* props); + // Releases memory from props_, in addition to restoring *this to its + // uninitialized state. + void Clear(); + + int id_; // -1 until Initialize() is called + int cost_id_; // -1 if there is no corresponding cost accounting node + + EdgeSet in_edges_; + EdgeSet out_edges_; + + Properties* props_; + + // Name of device assigned to perform this computation. + string assigned_device_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(Node); +}; + +class Edge { + public: + Node* src() const { return src_; } + Node* dst() const { return dst_; } + int id() const { return id_; } + + // Return the number of the source output that produces the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int src_output() const { return src_output_; } + + // Return the number of the destination input that consumes the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int dst_input() const { return dst_input_; } + + // Return true iff this is an edge that indicates a control-flow + // (as opposed to a data-flow) dependency. + bool IsControlEdge() const; + + private: + Edge() {} + + friend class EdgeSetTest; + friend class Graph; + Node* src_; + Node* dst_; + int id_; + int src_output_; + int dst_input_; +}; + +// Thread compatible but not thread safe. +class Graph { + public: + // Constructs a graph with a single SOURCE (always id kSourceId) and a + // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. + // + // The graph can hold ops found in registry. + explicit Graph(const OpRegistryInterface* registry); + ~Graph(); + + static const int kControlSlot = -1; + + // Adds a new node to this graph, and returns it. Infers the Op and + // input/output types for the node. *this owns the returned instance. + // Returns nullptr and sets *status on error. + Node* AddNode(const NodeDef& node_def, Status* status); + + // Copies *node, which may belong to another graph, to a new node, + // which is returned. Does not copy any edges. *this owns the + // returned instance. + Node* CopyNode(Node* node); + + // Remove a node from this graph, including all edges from or to it. + // *node should not be accessed after calling this function. + // REQUIRES: node->IsOp() + void RemoveNode(Node* node); + + // Add an edge that connects the xth output of "source" to the yth input + // of "dest". + const Edge* AddEdge(Node* source, int x, Node* dest, int y); + + // Add a control-edge (no data flows along this edge) that + // connects "source" to "dest". + const Edge* AddControlEdge(Node* source, Node* dest) { + return AddEdge(source, kControlSlot, dest, kControlSlot); + } + + // Removes edge from the graph. + // REQUIRES: The edge must exist. + void RemoveEdge(const Edge* edge); + + // Returns one more than the maximum id assigned to any node. + int num_node_ids() const { return nodes_.size(); } + + // Serialize to a GraphDef. + void ToGraphDef(GraphDef* graph_def) const; + + // Generate new node name with the specified prefix that is unique + // across this graph. + string NewName(StringPiece prefix); + + // Access to the list of all nodes. Example usage: + // for (Node* node : graph.nodes()) { ... } + gtl::iterator_range<NodeIter> nodes() const; + + // Returns the node associated with an id, or nullptr if no node + // with that id (the node with that id was removed and the id has + // not yet been re-used). *this owns the returned instance. + // REQUIRES: 0 <= id < num_node_ids(). + Node* FindNodeId(int id) const { return nodes_[id]; } + + // Returns one more than the maximum id assigned to any edge. + int num_edge_ids() const { return edges_.size(); } + + // Returns the Edge associated with an id, or nullptr if no edge + // with that id (the node with that id was removed and the id has + // not yet been re-used). *this owns the returned instance. + // REQUIRES: 0 <= id < num_node_ids(). + const Edge* FindEdgeId(int id) const { return edges_[id]; } + + // Access to the set of all edges. Example usage: + // for (const Edge* e : graph.edges()) { ... } + const EdgeSet& edges() const { return edge_set_; } + + // The pre-defined nodes. + enum { kSourceId = 0, kSinkId = 1 }; + Node* source_node() const { return FindNodeId(kSourceId); } + Node* sink_node() const { return FindNodeId(kSinkId); } + + const OpRegistryInterface* op_registry() const { return ops_; } + + // TODO(josh11b): uint64 hash() const; + + private: + bool IsValidNode(Node* node) const; + // If cost_node is non-null, then cost accounting (in CostModel) + // will be associated with that node rather than the new one being + // created. + Node* AllocateNode(Node::Properties* props, const Node* cost_node); + void ReleaseNode(Node* node); + + // Registry of all known ops. Not owned. + const OpRegistryInterface* const ops_; + + // Allocator which will give us good locality. + core::Arena arena_; + + // Map from node ids to allocated nodes. nodes_[id] may be nullptr if + // the node with that id was removed from the graph. + std::vector<Node*> nodes_; + + // Map from edge ids to allocated edges. edges_[id] may be nullptr if + // the edge with that id was removed from the graph. + std::vector<Edge*> edges_; + + // For ease of iteration, we currently just keep a set of all live + // edges. May want to optimize by removing this copy. + EdgeSet edge_set_; + + // Allocated but free nodes and edges. + std::vector<Node*> free_nodes_; + std::vector<Edge*> free_edges_; + + // For generating unique names. + int name_counter_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(Graph); +}; + +// TODO(josh11b): We may want to support keeping an index on various +// node/edge attributes in a graph, particularly node names. + +// Helper routines + +inline bool IsSwitch(const Node* node) { + return node->type_string() == "Switch" || node->type_string() == "RefSwitch"; +} + +inline bool IsMerge(const Node* node) { return node->type_string() == "Merge"; } + +inline bool IsEnter(const Node* node) { + return node->type_string() == "Enter" || node->type_string() == "RefEnter"; +} + +inline bool IsExit(const Node* node) { return node->type_string() == "Exit"; } + +inline bool IsNextIteration(const Node* node) { + return node->type_string() == "NextIteration"; +} + +inline bool IsLoopCond(const Node* node) { + return node->type_string() == "LoopCond"; +} + +inline bool IsControlTrigger(const Node* node) { + return node->type_string() == "ControlTrigger"; +} + +inline bool IsSend(const Node* node) { + return node->type_string() == "_Send" || node->type_string() == "_HostSend"; +} + +inline bool IsRecv(const Node* node) { + return node->type_string() == "_Recv" || node->type_string() == "_HostRecv"; +} + +// True for Nodes that mediate the transfer of values between processes. +inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); } + +inline bool IsConstant(const Node* node) { + return node->type_string() == "Const" || node->type_string() == "HostConst"; +} + +inline bool IsVariable(const Node* node) { + return node->type_string() == "Variable"; +} + +inline bool IsIdentity(const Node* node) { + return (node->type_string() == "Identity" || + node->type_string() == "RefIdentity"); +} + +// Returns true iff 'n' is a control flow node. +inline bool IsControlFlow(const Node* n) { + return IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n) || + IsNextIteration(n); +} + +inline bool IsHostMemoryPreserving(const Node* node) { + return IsIdentity(node) || IsControlFlow(node); +} + +// Iterator for stepping through the nodes of a graph. +class NodeIter { + public: + NodeIter(const Graph* graph, int id); + bool operator==(const NodeIter& rhs); + bool operator!=(const NodeIter& rhs); + void operator++(); + Node* operator*(); + Node* operator->(); + + private: + // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr + const Graph* graph_; + int id_; +}; + +// Iterator for stepping through the neighbors of a node. +class NeighborIter { + public: + NeighborIter(EdgeSet::const_iterator iter, bool incoming); + bool operator==(const NeighborIter& rhs); + bool operator!=(const NeighborIter& rhs); + void operator++(); + Node* operator*(); + Node* operator->(); + + private: + EdgeSet::const_iterator iter_; + bool incoming_; +}; + +// IMPLEMENTATION DETAILS, PLEASE IGNORE + +inline NodeIter::NodeIter(const Graph* graph, int id) + : graph_(graph), id_(id) {} + +inline bool NodeIter::operator==(const NodeIter& rhs) { + DCHECK(graph_ == rhs.graph_); + return id_ == rhs.id_; +} + +inline bool NodeIter::operator!=(const NodeIter& rhs) { + return !(*this == rhs); +} + +inline void NodeIter::operator++() { + while (1) { + DCHECK_LE(id_, graph_->num_node_ids()); + ++id_; + if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { + return; + } + } +} + +inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); } + +inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); } + +inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming) + : iter_(iter), incoming_(incoming) {} + +inline bool NeighborIter::operator==(const NeighborIter& rhs) { + return iter_ == rhs.iter_ && incoming_ == rhs.incoming_; +} + +inline bool NeighborIter::operator!=(const NeighborIter& rhs) { + return !(*this == rhs); +} + +inline void NeighborIter::operator++() { ++iter_; } + +inline Node* NeighborIter::operator*() { + const Edge* e = *iter_; + return incoming_ ? e->src() : e->dst(); +} + +inline Node* NeighborIter::operator->() { + const Edge* e = *iter_; + return incoming_ ? e->src() : e->dst(); +} + +inline bool Edge::IsControlEdge() const { + // Note that if either src_output_ or dst_input_ is kControlSlot, + // so is the other one (AddEdge checks this). + return src_output_ == Graph::kControlSlot; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_GRAPH_GRAPH_H_ |