aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-06-23 15:24:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-23 15:32:27 -0700
commita94e547402f67566b0ddc361203382f76d282d67 (patch)
treea54470f9141659d5e49a9b9136fd7dce7f6a5e3f /tensorflow/core/graph/graph.cc
parente0944c4784d9d6bd43384bcd67bb7fe28f1d11ab (diff)
Use std::shared_ptr instead of core::RefCounted for Node::Properties
Also changes Node::Properties to a struct and removes underscores from public member variables. This change should make it easier to work with Properties moving forward as the refcount will be automatically updated. PiperOrigin-RevId: 160003281
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc97
1 files changed, 37 insertions, 60 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 5d60e41d26..e06e479264 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -29,24 +29,19 @@ namespace tensorflow {
const int Graph::kControlSlot = -1;
-class NodeProperties : public core::RefCounted {
+struct NodeProperties {
public:
NodeProperties(const OpDef* op_def, const NodeDef& node_def,
const DataTypeSlice inputs, const DataTypeSlice outputs)
- : op_def_(op_def),
- node_def_(node_def),
- input_types_(inputs.begin(), inputs.end()),
- output_types_(outputs.begin(), outputs.end()) {}
-
- const OpDef* op_def_; // not owned
- NodeDef node_def_;
- const DataTypeVector input_types_;
- const DataTypeVector output_types_;
-
- private:
- // Destructor invoked when last reference goes away via Unref()
- ~NodeProperties() override {}
- TF_DISALLOW_COPY_AND_ASSIGN(NodeProperties);
+ : op_def(op_def),
+ node_def(node_def),
+ input_types(inputs.begin(), inputs.end()),
+ output_types(outputs.begin(), outputs.end()) {}
+
+ const OpDef* op_def; // not owned
+ NodeDef node_def;
+ const DataTypeVector input_types;
+ const DataTypeVector output_types;
};
// Node
@@ -114,26 +109,17 @@ Node::Node()
props_(nullptr),
assigned_device_name_index_(0) {}
-Node::~Node() {
- if (props_) {
- props_->Unref();
- }
-}
-
-void Node::Initialize(int id, int cost_id, NodeProperties* props) {
+void Node::Initialize(int id, int cost_id,
+ std::shared_ptr<NodeProperties> props) {
DCHECK_EQ(id_, -1);
DCHECK(in_edges_.empty());
DCHECK(out_edges_.empty());
id_ = id;
cost_id_ = cost_id;
- // Unref the old, assign the new properties.
- if (props_) {
- props_->Unref();
- }
- props_ = props;
+ props_ = std::move(props);
// Initialize the class_ based on the type string
- class_ = GetNodeClassForOp(props->node_def_.op());
+ class_ = GetNodeClassForOp(props_->node_def.op());
}
void Node::Clear() {
@@ -142,28 +128,23 @@ void Node::Clear() {
id_ = -1;
cost_id_ = -1;
class_ = NC_UNINITIALIZED;
-
- if (props_) {
- props_->Unref();
- props_ = nullptr;
- }
-
+ props_.reset();
assigned_device_name_index_ = 0;
}
-const string& Node::name() const { return props_->node_def_.name(); }
-const string& Node::type_string() const { return props_->node_def_.op(); }
-const NodeDef& Node::def() const { return props_->node_def_; }
-const OpDef& Node::op_def() const { return *props_->op_def_; }
+const string& Node::name() const { return props_->node_def.name(); }
+const string& Node::type_string() const { return props_->node_def.op(); }
+const NodeDef& Node::def() const { return props_->node_def; }
+const OpDef& Node::op_def() const { return *props_->op_def; }
-int32 Node::num_inputs() const { return props_->input_types_.size(); }
-DataType Node::input_type(int32 i) const { return props_->input_types_[i]; }
-const DataTypeVector& Node::input_types() const { return props_->input_types_; }
+int32 Node::num_inputs() const { return props_->input_types.size(); }
+DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
+const DataTypeVector& Node::input_types() const { return props_->input_types; }
-int32 Node::num_outputs() const { return props_->output_types_.size(); }
-DataType Node::output_type(int32 o) const { return props_->output_types_[o]; }
+int32 Node::num_outputs() const { return props_->output_types.size(); }
+DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
- return props_->output_types_;
+ return props_->output_types;
}
AttrSlice Node::attrs() const { return AttrSlice(def()); }
@@ -186,23 +167,19 @@ gtl::iterator_range<NeighborIter> Node::in_nodes() const {
void Node::MaybeCopyOnWrite() {
// NodeProperties may be shared between Nodes. Make a copy if so.
- if (!props_->RefCountIsOne()) {
- NodeProperties* new_props =
- new NodeProperties(props_->op_def_, props_->node_def_,
- props_->input_types_, props_->output_types_);
- props_->Unref();
- props_ = new_props;
+ if (!props_.unique()) {
+ props_ = std::make_shared<NodeProperties>(*props_);
}
}
AttrValue* Node::AddAttrHelper(const string& name) {
MaybeCopyOnWrite();
- return &((*props_->node_def_.mutable_attr())[name]);
+ return &((*props_->node_def.mutable_attr())[name]);
}
void Node::ClearAttr(const string& name) {
MaybeCopyOnWrite();
- (*props_->node_def_.mutable_attr()).erase(name);
+ (*props_->node_def.mutable_attr()).erase(name);
}
Status Node::input_edge(int idx, const Edge** e) const {
@@ -338,16 +315,15 @@ Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
}
Node* node = AllocateNode(
- new NodeProperties(op_def, node_def, inputs, outputs), nullptr);
+ std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs),
+ nullptr);
return node;
}
Node* Graph::CopyNode(Node* node) {
DCHECK(!node->IsSource());
DCHECK(!node->IsSink());
- NodeProperties* props = node->properties();
- props->Ref();
- Node* copy = AllocateNode(props, node);
+ Node* copy = AllocateNode(node->props_, node);
copy->set_assigned_device_name(node->assigned_device_name());
// Since the OpDef of a function may be owned by the Graph that owns 'node',
@@ -355,9 +331,9 @@ Node* Graph::CopyNode(Node* node) {
// node properties with the updated OpDef.
const OpDef* op_def;
TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
- if (op_def != props->op_def_) {
+ if (op_def != node->props_->op_def) {
copy->MaybeCopyOnWrite();
- copy->props_->op_def_ = op_def;
+ copy->props_->op_def = op_def;
}
return copy;
@@ -540,7 +516,8 @@ bool Graph::IsValidNode(Node* node) const {
return nodes_[id] == node;
}
-Node* Graph::AllocateNode(NodeProperties* props, const Node* cost_node) {
+Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
+ const Node* cost_node) {
Node* node = nullptr;
if (free_nodes_.empty()) {
node = new (arena_.Alloc(sizeof(Node))) Node; // placement new
@@ -551,7 +528,7 @@ Node* Graph::AllocateNode(NodeProperties* props, const Node* cost_node) {
node->graph_ = this;
const int id = nodes_.size();
int cost_id = cost_node ? cost_node->cost_id() : id;
- node->Initialize(id, cost_id, props);
+ node->Initialize(id, cost_id, std::move(props));
nodes_.push_back(node);
++num_nodes_;
return node;