aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2017-06-23 12:55:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-23 12:59:28 -0700
commit6ada43366663210beb0159b8c1a67b26ebfe6cb7 (patch)
tree1f41bb5c9d2000cb4dd47645f57c181aef3fae3e /tensorflow/core/graph/graph.cc
parent0eff699d3087171cf35671d9d0bd6f8e79441ab3 (diff)
Prepare to not include node_def.proto.h in node_def_util.h
The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports. This CL makes a bunch of .cc files either include node_def.proto.h themselves or not need the definition of NodeDef; a second CL will make node_def_util.h not include node_def.proto.h. RELNOTES: n/a PiperOrigin-RevId: 159982117
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc76
1 files changed, 57 insertions, 19 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index dcb8520cf7..5d60e41d26 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include <vector>
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -28,6 +29,26 @@ namespace tensorflow {
const int Graph::kControlSlot = -1;
+class NodeProperties : public core::RefCounted {
+ public:
+ NodeProperties(const OpDef* op_def, const NodeDef& node_def,
+ const DataTypeSlice inputs, const DataTypeSlice outputs)
+ : op_def_(op_def),
+ node_def_(node_def),
+ input_types_(inputs.begin(), inputs.end()),
+ output_types_(outputs.begin(), outputs.end()) {}
+
+ const OpDef* op_def_; // not owned
+ NodeDef node_def_;
+ const DataTypeVector input_types_;
+ const DataTypeVector output_types_;
+
+ private:
+ // Destructor invoked when last reference goes away via Unref()
+ ~NodeProperties() override {}
+ TF_DISALLOW_COPY_AND_ASSIGN(NodeProperties);
+};
+
// Node
#define REF_CLASS(key, value) \
@@ -99,7 +120,7 @@ Node::~Node() {
}
}
-void Node::Initialize(int id, int cost_id, Properties* props) {
+void Node::Initialize(int id, int cost_id, NodeProperties* props) {
DCHECK_EQ(id_, -1);
DCHECK(in_edges_.empty());
DCHECK(out_edges_.empty());
@@ -130,6 +151,29 @@ void Node::Clear() {
assigned_device_name_index_ = 0;
}
+const string& Node::name() const { return props_->node_def_.name(); }
+const string& Node::type_string() const { return props_->node_def_.op(); }
+const NodeDef& Node::def() const { return props_->node_def_; }
+const OpDef& Node::op_def() const { return *props_->op_def_; }
+
+int32 Node::num_inputs() const { return props_->input_types_.size(); }
+DataType Node::input_type(int32 i) const { return props_->input_types_[i]; }
+const DataTypeVector& Node::input_types() const { return props_->input_types_; }
+
+int32 Node::num_outputs() const { return props_->output_types_.size(); }
+DataType Node::output_type(int32 o) const { return props_->output_types_[o]; }
+const DataTypeVector& Node::output_types() const {
+ return props_->output_types_;
+}
+
+AttrSlice Node::attrs() const { return AttrSlice(def()); }
+
+const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
+ return def().input();
+}
+
+const string& Node::requested_device() const { return def().device(); }
+
gtl::iterator_range<NeighborIter> Node::out_nodes() const {
return gtl::make_range(NeighborIter(out_edges_.begin(), false),
NeighborIter(out_edges_.end(), false));
@@ -141,16 +185,21 @@ gtl::iterator_range<NeighborIter> Node::in_nodes() const {
}
void Node::MaybeCopyOnWrite() {
- // Properties may be shared between Nodes. Make a copy if so.
+ // NodeProperties may be shared between Nodes. Make a copy if so.
if (!props_->RefCountIsOne()) {
- Properties* new_props =
- new Properties(props_->op_def_, props_->node_def_, props_->input_types_,
- props_->output_types_);
+ NodeProperties* new_props =
+ new NodeProperties(props_->op_def_, props_->node_def_,
+ props_->input_types_, props_->output_types_);
props_->Unref();
props_ = new_props;
}
}
+AttrValue* Node::AddAttrHelper(const string& name) {
+ MaybeCopyOnWrite();
+ return &((*props_->node_def_.mutable_attr())[name]);
+}
+
void Node::ClearAttr(const string& name) {
MaybeCopyOnWrite();
(*props_->node_def_.mutable_attr()).erase(name);
@@ -225,17 +274,6 @@ Status Node::input_node(int idx, const Node** const_n) const {
return Status::OK();
}
-// Node::Properties
-
-Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def,
- const DataTypeSlice inputs,
- const DataTypeSlice outputs)
- : op_def_(op_def),
- node_def_(node_def),
- input_types_(inputs.begin(), inputs.end()),
- output_types_(outputs.begin(), outputs.end()) {}
-
-Node::Properties::~Properties() {}
// Graph
@@ -300,14 +338,14 @@ Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
}
Node* node = AllocateNode(
- new Node::Properties(op_def, node_def, inputs, outputs), nullptr);
+ new NodeProperties(op_def, node_def, inputs, outputs), nullptr);
return node;
}
Node* Graph::CopyNode(Node* node) {
DCHECK(!node->IsSource());
DCHECK(!node->IsSink());
- Node::Properties* props = node->properties();
+ NodeProperties* props = node->properties();
props->Ref();
Node* copy = AllocateNode(props, node);
copy->set_assigned_device_name(node->assigned_device_name());
@@ -502,7 +540,7 @@ bool Graph::IsValidNode(Node* node) const {
return nodes_[id] == node;
}
-Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) {
+Node* Graph::AllocateNode(NodeProperties* props, const Node* cost_node) {
Node* node = nullptr;
if (free_nodes_.empty()) {
node = new (arena_.Alloc(sizeof(Node))) Node; // placement new