diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2017-06-23 12:55:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-23 12:59:28 -0700 |
commit | 6ada43366663210beb0159b8c1a67b26ebfe6cb7 (patch) | |
tree | 1f41bb5c9d2000cb4dd47645f57c181aef3fae3e /tensorflow/core/graph/graph.cc | |
parent | 0eff699d3087171cf35671d9d0bd6f8e79441ab3 (diff) |
Prepare to not include node_def.proto.h in node_def_util.h
The goal is to make kernels mostly independent of proto headers, which will let
us lock down our .so imports. This CL makes a bunch of .cc files
either include node_def.proto.h themselves or not need the definition of
NodeDef; a second CL will make node_def_util.h not include node_def.proto.h.
RELNOTES: n/a
PiperOrigin-RevId: 159982117
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 76 |
1 files changed, 57 insertions, 19 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index dcb8520cf7..5d60e41d26 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include <vector> +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,6 +29,26 @@ namespace tensorflow { const int Graph::kControlSlot = -1; +class NodeProperties : public core::RefCounted { + public: + NodeProperties(const OpDef* op_def, const NodeDef& node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs) + : op_def_(op_def), + node_def_(node_def), + input_types_(inputs.begin(), inputs.end()), + output_types_(outputs.begin(), outputs.end()) {} + + const OpDef* op_def_; // not owned + NodeDef node_def_; + const DataTypeVector input_types_; + const DataTypeVector output_types_; + + private: + // Destructor invoked when last reference goes away via Unref() + ~NodeProperties() override {} + TF_DISALLOW_COPY_AND_ASSIGN(NodeProperties); +}; + // Node #define REF_CLASS(key, value) \ @@ -99,7 +120,7 @@ Node::~Node() { } } -void Node::Initialize(int id, int cost_id, Properties* props) { +void Node::Initialize(int id, int cost_id, NodeProperties* props) { DCHECK_EQ(id_, -1); DCHECK(in_edges_.empty()); DCHECK(out_edges_.empty()); @@ -130,6 +151,29 @@ void Node::Clear() { assigned_device_name_index_ = 0; } +const string& Node::name() const { return props_->node_def_.name(); } +const string& Node::type_string() const { return props_->node_def_.op(); } +const NodeDef& Node::def() const { return props_->node_def_; } +const OpDef& Node::op_def() const { return *props_->op_def_; } + +int32 Node::num_inputs() const { return props_->input_types_.size(); } +DataType Node::input_type(int32 i) const { return props_->input_types_[i]; } +const DataTypeVector& Node::input_types() const { return props_->input_types_; } + +int32 Node::num_outputs() const { return props_->output_types_.size(); } +DataType Node::output_type(int32 o) const { return props_->output_types_[o]; } +const DataTypeVector& Node::output_types() const { + return props_->output_types_; +} + +AttrSlice Node::attrs() const { return AttrSlice(def()); } + +const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const { + return def().input(); +} + +const string& Node::requested_device() const { return def().device(); } + gtl::iterator_range<NeighborIter> Node::out_nodes() const { return gtl::make_range(NeighborIter(out_edges_.begin(), false), NeighborIter(out_edges_.end(), false)); @@ -141,16 +185,21 @@ gtl::iterator_range<NeighborIter> Node::in_nodes() const { } void Node::MaybeCopyOnWrite() { - // Properties may be shared between Nodes. Make a copy if so. + // NodeProperties may be shared between Nodes. Make a copy if so. if (!props_->RefCountIsOne()) { - Properties* new_props = - new Properties(props_->op_def_, props_->node_def_, props_->input_types_, - props_->output_types_); + NodeProperties* new_props = + new NodeProperties(props_->op_def_, props_->node_def_, + props_->input_types_, props_->output_types_); props_->Unref(); props_ = new_props; } } +AttrValue* Node::AddAttrHelper(const string& name) { + MaybeCopyOnWrite(); + return &((*props_->node_def_.mutable_attr())[name]); +} + void Node::ClearAttr(const string& name) { MaybeCopyOnWrite(); (*props_->node_def_.mutable_attr()).erase(name); @@ -225,17 +274,6 @@ Status Node::input_node(int idx, const Node** const_n) const { return Status::OK(); } -// Node::Properties - -Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def, - const DataTypeSlice inputs, - const DataTypeSlice outputs) - : op_def_(op_def), - node_def_(node_def), - input_types_(inputs.begin(), inputs.end()), - output_types_(outputs.begin(), outputs.end()) {} - -Node::Properties::~Properties() {} // Graph @@ -300,14 +338,14 @@ Node* Graph::AddNode(const NodeDef& node_def, Status* status) { } Node* node = AllocateNode( - new Node::Properties(op_def, node_def, inputs, outputs), nullptr); + new NodeProperties(op_def, node_def, inputs, outputs), nullptr); return node; } Node* Graph::CopyNode(Node* node) { DCHECK(!node->IsSource()); DCHECK(!node->IsSink()); - Node::Properties* props = node->properties(); + NodeProperties* props = node->properties(); props->Ref(); Node* copy = AllocateNode(props, node); copy->set_assigned_device_name(node->assigned_device_name()); @@ -502,7 +540,7 @@ bool Graph::IsValidNode(Node* node) const { return nodes_[id] == node; } -Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) { +Node* Graph::AllocateNode(NodeProperties* props, const Node* cost_node) { Node* node = nullptr; if (free_nodes_.empty()) { node = new (arena_.Alloc(sizeof(Node))) Node; // placement new |