diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-08-25 08:55:38 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-25 10:02:59 -0700 |
commit | 6bca9eab85082cf3bc988f9588ddddfb3ed60cb1 (patch) | |
tree | 4a018a8bd7f287ef098e99e789f86e986f41a984 /tensorflow/core/graph/graph.cc | |
parent | a85ed2e4601a8111d25f30b14421cac7a94ad372 (diff) |
Add to the C++ Node class the ability to fetch input nodes and edges
by index. There are various locations in code where users currently
use iteration to find the edge by its already known index, and these
functions would be useful to accomplish.
In addition, this implements the equivalent functionality
of 'op.inputs[i]' in our python Operation class.
Given the new functionality, it exposed a weird use of NoOp for nodes
that actually had multiple inputs. Modified the test to use custom op
definitions to be more correct.
Currently this iterates over the edge list, which in the common case
will be fast and introduces no additional state to Node. In the future
we may want to revisit this.
Change: 131299794
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 95f1c06a2c..648f439607 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -144,6 +144,43 @@ void Node::ClearAttr(const string& name) { (*props_->node_def_.mutable_attr()).erase(name); } +Status Node::input_edge(int idx, const Edge** e) const { + if (idx < 0 || idx >= num_inputs()) { + return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ", + name(), " only has ", num_inputs(), + " inputs."); + } + + // This does a linear search over the edges. In the common case, + // the number of elements is small enough that this search isn't + // expensive. Should it become a bottleneck, one can make an + // optimization where, if the number of edges is small, we use + // linear iteration, and if the number of edges is large, we perform + // an indexing step during construction that keeps an array of Edges + // indexed by pointer. This would keep the size of each Node small + // in the common case but make this function faster when the number + // of edges is large. + for (const Edge* edge : in_edges()) { + if (edge->dst_input() == idx) { + *e = edge; + return Status::OK(); + } + } + + return errors::NotFound("Could not find input edge ", idx, " for ", name()); +} + +Status Node::input_node(int idx, const Node** n) const { + const Edge* e; + TF_RETURN_IF_ERROR(input_edge(idx, &e)); + if (e == nullptr) { + *n = nullptr; + } else { + *n = e->src(); + } + return Status::OK(); +} + // Node::Properties Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def, |