/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph.h"

#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const int Graph::kControlSlot = -1;

struct NodeProperties {
 public:
  NodeProperties(const OpDef* op_def, const NodeDef& node_def,
                 const DataTypeSlice inputs, const DataTypeSlice outputs)
      : op_def(op_def),
        node_def(node_def),
        input_types(inputs.begin(), inputs.end()),
        output_types(outputs.begin(), outputs.end()) {}

  const OpDef* op_def;  // not owned
  NodeDef node_def;
  const DataTypeVector input_types;
  const DataTypeVector output_types;
};

// Node

#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }

const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
    *new std::unordered_map<string, Node::NodeClass>({
        // Keep in same order as NodeClass values
        REF_CLASS("Switch", NC_SWITCH),
        REF_CLASS("Merge", NC_MERGE),
        REF_CLASS("Enter", NC_ENTER),
        REF_CLASS("Exit", NC_EXIT),
        REF_CLASS("NextIteration", NC_NEXT_ITERATION),
        {"LoopCond", NC_LOOP_COND},
        {"ControlTrigger", NC_CONTROL_TRIGGER},
        {"_Send", NC_SEND},
        {"_HostSend", NC_HOST_SEND},
        {"_Recv", NC_RECV},
        {"_HostRecv", NC_HOST_RECV},
        {"Const", NC_CONSTANT},
        {"HostConst", NC_CONSTANT},
        {"Variable", NC_VARIABLE},
        {"VariableV2", NC_VARIABLE},
        REF_CLASS("Identity", NC_IDENTITY),
        {"GetSessionHandle", NC_GET_SESSION_HANDLE},
        {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
        {"GetSessionTensor", NC_GET_SESSION_TENSOR},
        {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
        {"Size", NC_METADATA},
        {"Shape", NC_METADATA},
        {"Rank", NC_METADATA},
        {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
        {"CollectiveReduce", NC_COLLECTIVE},
        {"CollectiveBcastSend", NC_COLLECTIVE},
        {"CollectiveBcastRecv", NC_COLLECTIVE},
    });

#undef REF_CLASS

Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
  auto it = kNodeClassTable.find(ts);
  if (it != kNodeClassTable.end()) {
    return it->second;
  } else {
    return NC_OTHER;
  }
}
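// Example (illustrative sketch; hypothetical calls for exposition only): the
// table above keys on an op's type string, and REF_CLASS registers the "Ref"
// variant of an op under the same class, so:
//
//   GetNodeClassForOp("Switch")     // -> NC_SWITCH
//   GetNodeClassForOp("RefSwitch")  // -> NC_SWITCH, added by REF_CLASS
//   GetNodeClassForOp("MatMul")     // -> NC_OTHER, any op not in the table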
string Node::DebugString() const {
  string ret = strings::StrCat("{name:'", name(), "' id:", id_);
  if (IsSource()) {
    strings::StrAppend(&ret, " source}");
  } else if (IsSink()) {
    strings::StrAppend(&ret, " sink}");
  } else {
    strings::StrAppend(&ret, " op device:");
    strings::StrAppend(&ret, "{", assigned_device_name(), "}");
    strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
  }
  return ret;
}

Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}

void Node::Initialize(int id, int cost_id,
                      std::shared_ptr<NodeProperties> props) {
  DCHECK_EQ(id_, -1);
  DCHECK(in_edges_.empty());
  DCHECK(out_edges_.empty());
  id_ = id;
  cost_id_ = cost_id;
  props_ = std::move(props);
  // Initialize the class_ based on the type string
  class_ = GetNodeClassForOp(props_->node_def.op());
}

void Node::Clear() {
  in_edges_.clear();
  out_edges_.clear();
  id_ = -1;
  cost_id_ = -1;
  class_ = NC_UNINITIALIZED;
  props_.reset();
  assigned_device_name_index_ = 0;
}

void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed while updating node: " << status;
    return;
  }
  props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
                                            inputs, outputs);
}

const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
  return def().input();
}

const string& Node::requested_device() const { return def().device(); }

gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}

void Node::MaybeCopyOnWrite() {
  // NodeProperties may be shared between Nodes. Make a copy if so.
  if (!props_.unique()) {
    props_ = std::make_shared<NodeProperties>(*props_);
  }
}

AttrValue* Node::AddAttrHelper(const string& name) {
  MaybeCopyOnWrite();
  return &((*props_->node_def.mutable_attr())[name]);
}

void Node::ClearAttr(const string& name) {
  MaybeCopyOnWrite();
  (*props_->node_def.mutable_attr()).erase(name);
}

void Node::set_name(string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

void Node::set_requested_device(const string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}
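// Example (illustrative sketch; assumes 'g' is a Graph* and 'node' one of its
// op nodes): Graph::CopyNode() below shares one NodeProperties between the
// original and the copy until either is mutated.
//
//   Node* copy = g->CopyNode(node);  // 'copy' initially shares props_
//   copy->set_name("node/copy");     // MaybeCopyOnWrite() forks props_ first
//   // 'node' keeps its original name; only 'copy' observes the new one.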
Status Node::input_edge(int idx, const Edge** e) const {
  if (idx < 0 || idx >= num_inputs()) {
    return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
                                   name(), " only has ", num_inputs(),
                                   " inputs.");
  }

  // This does a linear search over the edges.  In the common case,
  // the number of elements is small enough that this search isn't
  // expensive.  Should it become a bottleneck, one can make an
  // optimization where, if the number of edges is small, we use
  // linear iteration, and if the number of edges is large, we perform
  // an indexing step during construction that keeps an array of Edges
  // indexed by pointer.  This would keep the size of each Node small
  // in the common case but make this function faster when the number
  // of edges is large.
  for (const Edge* edge : in_edges()) {
    if (edge->dst_input() == idx) {
      *e = edge;
      return Status::OK();
    }
  }

  return errors::NotFound("Could not find input edge ", idx, " for ", name());
}

// Returns a vector of the non-control input edges to a node, indexed by ID.
Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
  input_edges->clear();
  input_edges->resize(num_inputs(), nullptr);

  for (const Edge* edge : in_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
      return errors::Internal("Invalid edge input number ", edge->dst_input());
    }
    if ((*input_edges)[edge->dst_input()] != nullptr) {
      return errors::Internal("Duplicate edge input number: ",
                              edge->dst_input());
    }
    (*input_edges)[edge->dst_input()] = edge;
  }

  for (int i = 0; i < num_inputs(); ++i) {
    if ((*input_edges)[i] == nullptr) {
      return errors::InvalidArgument("Missing edge input number: ", i);
    }
  }
  return Status::OK();
}

Status Node::input_node(int idx, Node** n) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  if (e == nullptr) {
    *n = nullptr;
  } else {
    *n = e->src();
  }
  return Status::OK();
}

Status Node::input_node(int idx, const Node** const_n) const {
  Node* n;
  TF_RETURN_IF_ERROR(input_node(idx, &n));
  *const_n = n;
  return Status::OK();
}

// InputTensor

bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// Graph

Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}

Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.ToProto().function_size() > 0 &&
      versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}

Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new.
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges have no destructor, and we arena-allocated them, so no need to
  // destroy them.
}
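// Example (illustrative sketch): a freshly constructed Graph already holds
// the two pseudo-nodes and the control edge added by the constructor above.
//
//   Graph g(OpRegistry::Global());
//   CHECK_EQ(g.num_nodes(), 2);  // _SOURCE and _SINK
//   CHECK_EQ(g.num_edges(), 1);  // the _SOURCE -> _SINK control edge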
const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }

Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
  const OpDef* op_def;
  status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs),
      nullptr);
  return node;
}

Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }

  return copy;
}

void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node.
  while (!node->in_edges_.empty()) {
    RemoveEdge(*node->in_edges_.begin());
  }
  while (!node->out_edges_.empty()) {
    RemoveEdge(*node->out_edges_.begin());
  }
  ReleaseNode(node);
}

const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;
  return e;
}

void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;

  Edge* del = const_cast<Edge*>(e);
  del->src_ = nullptr;
  del->dst_ = nullptr;
  del->id_ = -1;
  del->src_output_ = kControlSlot - 1;
  del->dst_input_ = kControlSlot - 1;
  free_edges_.push_back(del);
  --num_edges_;
}
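// Example (illustrative sketch; 'def_a' and 'def_b' are assumed to be valid
// NodeDefs for registered ops whose slot-0 types are compatible):
//
//   Status s;
//   Node* a = g.AddNode(def_a, &s);
//   TF_CHECK_OK(s);
//   Node* b = g.AddNode(def_b, &s);
//   TF_CHECK_OK(s);
//   const Edge* e = g.AddEdge(a, 0, b, 0);  // a's output 0 feeds b's input 0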
const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const string new_input = strings::StrCat("^", source->name());
    bool input_exists = false;
    for (const string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}

void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    string e_src_name = strings::StrCat("^", e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}
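// Example (illustrative sketch; 'a' and 'b' are assumed to be op nodes in
// 'g'): AddControlEdge() also records "^a" in b's NodeDef, and refuses a
// duplicate unless allow_duplicates is set.
//
//   const Edge* c = g.AddControlEdge(a, b);
//   CHECK(g.AddControlEdge(a, b) == nullptr);  // duplicate is refused
//   g.RemoveControlEdge(c);                    // also removes "^a" again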
namespace {

const Edge* FindEdge(const Node* dst, int index) {
  for (const Edge* e : dst->in_edges()) {
    if (e->dst_input() == index) return e;
  }
  return nullptr;
}

}  // namespace

Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to ",
                                   dst->DebugString());
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":", new_src_index);
  return Status::OK();
}

Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  return ops_.AddLibrary(fdef_lib);
}

namespace {

void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

}  // namespace

void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}

GraphDef Graph::ToGraphDefDebug() const {
  GraphDef ret;
  ToGraphDef(&ret);
  return ret;
}

void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  // Construct this outside the loop for speed.
  std::vector<const Edge*> inputs;

  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node.  We make sure control inputs are
    // after data inputs, as required by GraphDef.
    inputs.clear();
    inputs.resize(node->num_inputs(), nullptr);
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->DebugString() << ":"
            << edge->dst()->DebugString() << " with dst_input "
            << edge->dst_input() << " already had a pre-existing input edge "
            << inputs[edge->dst_input()]->src()->DebugString() << ":"
            << inputs[edge->dst_input()]->dst()->DebugString();
        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("");
        }
      } else {
        const Node* src = edge->src();
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}
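// Example (illustrative sketch): serializing the graph. Only op nodes are
// emitted (_SOURCE/_SINK are skipped), and each node's control inputs
// ("^name") follow its data inputs, as GraphDef requires.
//
//   GraphDef gdef;
//   g.ToGraphDef(&gdef);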
string Graph::NewName(StringPiece prefix) {
  return strings::StrCat(prefix, "/_", name_counter_++);
}

Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null");
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id ", id, " is less than zero");
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument(
        "node id ", id, " is >= the number of nodes in the graph ",
        nodes_.size());
  }
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id ", id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?");
  }
  return Status::OK();
}

Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of outputs: ", node->num_outputs(),
                              ") does not have ", "output ", idx);
  }
  return Status::OK();
}

Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of inputs: ", node->num_inputs(),
                              ") does not have ", "input ", idx);
  }
  return Status::OK();
}

Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props));
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}

// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
int Graph::InternDeviceName(const string& device_name) {
  // Special case, very common.  Also, this allows us to use a single map
  // lookup below, instead of two.  The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}

Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
      string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '", frame_name,
                                   "' already exists");
  }
  *result = &pair.first->second;
  return Status::OK();
}

string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}

}  // namespace tensorflow