From a94e547402f67566b0ddc361203382f76d282d67 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Fri, 23 Jun 2017 15:24:03 -0700 Subject: Use std::shared_ptr instead of core::RefCounted for Node::Properties Also changes Node::Properties to a struct and removes underscores from public member variables. This change should make it easier to work with Properties moving forward as the refcount will be automatically updated. PiperOrigin-RevId: 160003281 --- tensorflow/core/graph/graph.cc | 97 ++++++++++++++++-------------------------- tensorflow/core/graph/graph.h | 18 ++++---- 2 files changed, 47 insertions(+), 68 deletions(-) diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 5d60e41d26..e06e479264 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -29,24 +29,19 @@ namespace tensorflow { const int Graph::kControlSlot = -1; -class NodeProperties : public core::RefCounted { +struct NodeProperties { public: NodeProperties(const OpDef* op_def, const NodeDef& node_def, const DataTypeSlice inputs, const DataTypeSlice outputs) - : op_def_(op_def), - node_def_(node_def), - input_types_(inputs.begin(), inputs.end()), - output_types_(outputs.begin(), outputs.end()) {} - - const OpDef* op_def_; // not owned - NodeDef node_def_; - const DataTypeVector input_types_; - const DataTypeVector output_types_; - - private: - // Destructor invoked when last reference goes away via Unref() - ~NodeProperties() override {} - TF_DISALLOW_COPY_AND_ASSIGN(NodeProperties); + : op_def(op_def), + node_def(node_def), + input_types(inputs.begin(), inputs.end()), + output_types(outputs.begin(), outputs.end()) {} + + const OpDef* op_def; // not owned + NodeDef node_def; + const DataTypeVector input_types; + const DataTypeVector output_types; }; // Node @@ -114,26 +109,17 @@ Node::Node() props_(nullptr), assigned_device_name_index_(0) {} -Node::~Node() { - if (props_) { - props_->Unref(); - } -} - -void Node::Initialize(int id, int cost_id, NodeProperties* props) { +void Node::Initialize(int id, int cost_id, + std::shared_ptr props) { DCHECK_EQ(id_, -1); DCHECK(in_edges_.empty()); DCHECK(out_edges_.empty()); id_ = id; cost_id_ = cost_id; - // Unref the old, assign the new properties. - if (props_) { - props_->Unref(); - } - props_ = props; + props_ = std::move(props); // Initialize the class_ based on the type string - class_ = GetNodeClassForOp(props->node_def_.op()); + class_ = GetNodeClassForOp(props_->node_def.op()); } void Node::Clear() { @@ -142,28 +128,23 @@ void Node::Clear() { id_ = -1; cost_id_ = -1; class_ = NC_UNINITIALIZED; - - if (props_) { - props_->Unref(); - props_ = nullptr; - } - + props_.reset(); assigned_device_name_index_ = 0; } -const string& Node::name() const { return props_->node_def_.name(); } -const string& Node::type_string() const { return props_->node_def_.op(); } -const NodeDef& Node::def() const { return props_->node_def_; } -const OpDef& Node::op_def() const { return *props_->op_def_; } +const string& Node::name() const { return props_->node_def.name(); } +const string& Node::type_string() const { return props_->node_def.op(); } +const NodeDef& Node::def() const { return props_->node_def; } +const OpDef& Node::op_def() const { return *props_->op_def; } -int32 Node::num_inputs() const { return props_->input_types_.size(); } -DataType Node::input_type(int32 i) const { return props_->input_types_[i]; } -const DataTypeVector& Node::input_types() const { return props_->input_types_; } +int32 Node::num_inputs() const { return props_->input_types.size(); } +DataType Node::input_type(int32 i) const { return props_->input_types[i]; } +const DataTypeVector& Node::input_types() const { return props_->input_types; } -int32 Node::num_outputs() const { return props_->output_types_.size(); } -DataType Node::output_type(int32 o) const { return props_->output_types_[o]; } +int32 Node::num_outputs() const { return props_->output_types.size(); } +DataType Node::output_type(int32 o) const { return props_->output_types[o]; } const DataTypeVector& Node::output_types() const { - return props_->output_types_; + return props_->output_types; } AttrSlice Node::attrs() const { return AttrSlice(def()); } @@ -186,23 +167,19 @@ gtl::iterator_range Node::in_nodes() const { void Node::MaybeCopyOnWrite() { // NodeProperties may be shared between Nodes. Make a copy if so. - if (!props_->RefCountIsOne()) { - NodeProperties* new_props = - new NodeProperties(props_->op_def_, props_->node_def_, - props_->input_types_, props_->output_types_); - props_->Unref(); - props_ = new_props; + if (!props_.unique()) { + props_ = std::make_shared(*props_); } } AttrValue* Node::AddAttrHelper(const string& name) { MaybeCopyOnWrite(); - return &((*props_->node_def_.mutable_attr())[name]); + return &((*props_->node_def.mutable_attr())[name]); } void Node::ClearAttr(const string& name) { MaybeCopyOnWrite(); - (*props_->node_def_.mutable_attr()).erase(name); + (*props_->node_def.mutable_attr()).erase(name); } Status Node::input_edge(int idx, const Edge** e) const { @@ -338,16 +315,15 @@ Node* Graph::AddNode(const NodeDef& node_def, Status* status) { } Node* node = AllocateNode( - new NodeProperties(op_def, node_def, inputs, outputs), nullptr); + std::make_shared(op_def, node_def, inputs, outputs), + nullptr); return node; } Node* Graph::CopyNode(Node* node) { DCHECK(!node->IsSource()); DCHECK(!node->IsSink()); - NodeProperties* props = node->properties(); - props->Ref(); - Node* copy = AllocateNode(props, node); + Node* copy = AllocateNode(node->props_, node); copy->set_assigned_device_name(node->assigned_device_name()); // Since the OpDef of a function may be owned by the Graph that owns 'node', @@ -355,9 +331,9 @@ Node* Graph::CopyNode(Node* node) { // node properties with the updated OpDef. const OpDef* op_def; TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def)); - if (op_def != props->op_def_) { + if (op_def != node->props_->op_def) { copy->MaybeCopyOnWrite(); - copy->props_->op_def_ = op_def; + copy->props_->op_def = op_def; } return copy; @@ -540,7 +516,8 @@ bool Graph::IsValidNode(Node* node) const { return nodes_[id] == node; } -Node* Graph::AllocateNode(NodeProperties* props, const Node* cost_node) { +Node* Graph::AllocateNode(std::shared_ptr props, + const Node* cost_node) { Node* node = nullptr; if (free_nodes_.empty()) { node = new (arena_.Alloc(sizeof(Node))) Node; // placement new @@ -551,7 +528,7 @@ Node* Graph::AllocateNode(NodeProperties* props, const Node* cost_node) { node->graph_ = this; const int id = nodes_.size(); int cost_id = cost_node ? cost_node->cost_id() : id; - node->Initialize(id, cost_id, props); + node->Initialize(id, cost_id, std::move(props)); nodes_.push_back(node); ++num_nodes_; return node; diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 08e2838d3c..e19e0b727d 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -181,14 +181,10 @@ class Node { private: friend class Graph; Node(); - ~Node(); - NodeProperties* properties() const { return props_; } + NodeProperties* properties() const { return props_.get(); } - // Initialize() adopts a reference to props, and so is suitable if props was - // just allocated or you call props->Ref() to increment the reference - // count for a props being held by another Node. - void Initialize(int id, int cost_id, NodeProperties* props); + void Initialize(int id, int cost_id, std::shared_ptr props); // Releases memory from props_, in addition to restoring *this to its // uninitialized state. @@ -238,7 +234,10 @@ class Node { EdgeSet in_edges_; EdgeSet out_edges_; - NodeProperties* props_; + // NOTE(skyewm): inheriting from core::RefCounted may have a slight + // performance benefit over using shared_ptr, at the cost of manual ref + // counting + std::shared_ptr props_; // Index within Graph::device_names_ of the name of device assigned // to perform this computation. @@ -505,7 +504,10 @@ class Graph { // If cost_node is non-null, then cost accounting (in CostModel) // will be associated with that node rather than the new one being // created. - Node* AllocateNode(NodeProperties* props, const Node* cost_node); + // + // Ownership of the returned Node is not transferred to caller. + Node* AllocateNode(std::shared_ptr props, + const Node* cost_node); void ReleaseNode(Node* node); // Registry of all known ops, including functions. -- cgit v1.2.3