diff options
author | Olivia Nordquist <nolivia@google.com> | 2017-09-26 19:56:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-26 20:13:58 -0700 |
commit | c65b9f87d91f51a233cb649f4d1a5b5f63a4d5e1 (patch) | |
tree | c3c40e0fc0a11857151c1f00c1dd648684d28e50 /tensorflow/core/graph/graph.cc | |
parent | 035a9be3cce366ceb57e3bb8d7a436135501061b (diff) |
implementing _update_input for the C API
PiperOrigin-RevId: 170147211
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 45ab38c395..2ad0081e1f 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -261,7 +261,6 @@ Status Node::input_node(int idx, const Node** const_n) const { return Status::OK(); } - // Graph Graph::Graph(const OpRegistryInterface* ops) @@ -420,6 +419,34 @@ void Graph::RemoveEdge(const Edge* e) { --num_edges_; } +Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, + int dst_index) { + TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index)); + TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index)); + const Edge* e = FindEdge(dst, dst_index); + if (e == nullptr) { + return errors::InvalidArgument("Couldn't find edge to ", + dst->DebugString()); + } + RemoveEdge(e); + AddEdge(new_src, new_src_index, dst, dst_index); + dst->MaybeCopyOnWrite(); + (*dst->props_->node_def.mutable_input())[dst_index] = + strings::StrCat(new_src->name(), ":", new_src_index); + return Status::OK(); +} + +const Edge* Graph::FindEdge(const Node* dst, int index) { + for (const Edge* e : edges_) { + // edges_ will contain null edges if RemoveEdge() was called. + if (e == nullptr) continue; + if (e->dst() == dst && e->dst_input() == index) { + return e; + } + } + return nullptr; +} + Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { return ops_.AddLibrary(fdef_lib); } @@ -528,10 +555,10 @@ Status Graph::IsValidNode(const Node* node) const { Status Graph::IsValidOutputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); if (idx >= node->num_outputs()) { - return errors::InvalidArgument("Node '", node->name(), "' (type: '", - node->op_def().name(), - "', num of outputs: ", node->num_outputs(), - ") does not have ", "output ", idx); + return errors::OutOfRange("Node '", node->name(), "' (type: '", + node->op_def().name(), + "', num of outputs: ", node->num_outputs(), + ") does not have ", "output ", idx); } return Status::OK(); } @@ -539,10 +566,10 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const { Status Graph::IsValidInputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); if (idx >= node->num_inputs()) { - return errors::InvalidArgument("Node '", node->name(), "' (type: '", - node->op_def().name(), - "', num of inputs: ", node->num_inputs(), - ") does not have ", "input ", idx); + return errors::OutOfRange("Node '", node->name(), "' (type: '", + node->op_def().name(), + "', num of inputs: ", node->num_inputs(), + ") does not have ", "input ", idx); } return Status::OK(); } |