aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-08-25 08:55:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-25 10:02:59 -0700
commit6bca9eab85082cf3bc988f9588ddddfb3ed60cb1 (patch)
tree4a018a8bd7f287ef098e99e789f86e986f41a984 /tensorflow/core/graph/graph.cc
parenta85ed2e4601a8111d25f30b14421cac7a94ad372 (diff)
Add to the C++ Node class the ability to fetch input nodes and edges
by index. There are various locations in code where users currently use iteration to find the edge by its already known index, and these functions would be useful to accomplish. In addition, this implements the equivalent functionality of 'op.inputs[i]' in our python Operation class. Given the new functionality, it exposed a weird use of NoOp for nodes that actually had multiple inputs. Modified the test to use custom op definitions to be more correct. Currently this iterates over the edge list, which in the common case will be fast and introduces no additional state to Node. In the future we may want to revisit this. Change: 131299794
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 95f1c06a2c..648f439607 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -144,6 +144,43 @@ void Node::ClearAttr(const string& name) {
(*props_->node_def_.mutable_attr()).erase(name);
}
+Status Node::input_edge(int idx, const Edge** e) const {
+ if (idx < 0 || idx >= num_inputs()) {
+ return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
+ name(), " only has ", num_inputs(),
+ " inputs.");
+ }
+
+ // This does a linear search over the edges. In the common case,
+ // the number of elements is small enough that this search isn't
+ // expensive. Should it become a bottleneck, one can make an
+ // optimization where, if the number of edges is small, we use
+ // linear iteration, and if the number of edges is large, we perform
+ // an indexing step during construction that keeps an array of Edges
+ // indexed by pointer. This would keep the size of each Node small
+ // in the common case but make this function faster when the number
+ // of edges is large.
+ for (const Edge* edge : in_edges()) {
+ if (edge->dst_input() == idx) {
+ *e = edge;
+ return Status::OK();
+ }
+ }
+
+ return errors::NotFound("Could not find input edge ", idx, " for ", name());
+}
+
+Status Node::input_node(int idx, const Node** n) const {
+ const Edge* e;
+ TF_RETURN_IF_ERROR(input_edge(idx, &e));
+ if (e == nullptr) {
+ *n = nullptr;
+ } else {
+ *n = e->src();
+ }
+ return Status::OK();
+}
+
// Node::Properties
Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def,