aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Olivia Nordquist <nolivia@google.com>2017-09-26 19:56:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-26 20:13:58 -0700
commitc65b9f87d91f51a233cb649f4d1a5b5f63a4d5e1 (patch)
treec3c40e0fc0a11857151c1f00c1dd648684d28e50 /tensorflow/core/graph/graph.cc
parent035a9be3cce366ceb57e3bb8d7a436135501061b (diff)
implementing _update_input for the C API
PiperOrigin-RevId: 170147211
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc45
1 files changed, 36 insertions, 9 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 45ab38c395..2ad0081e1f 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -261,7 +261,6 @@ Status Node::input_node(int idx, const Node** const_n) const {
return Status::OK();
}
-
// Graph
Graph::Graph(const OpRegistryInterface* ops)
@@ -420,6 +419,34 @@ void Graph::RemoveEdge(const Edge* e) {
--num_edges_;
}
+Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
+ int dst_index) {
+ TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
+ TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
+ const Edge* e = FindEdge(dst, dst_index);
+ if (e == nullptr) {
+ return errors::InvalidArgument("Couldn't find edge to ",
+ dst->DebugString());
+ }
+ RemoveEdge(e);
+ AddEdge(new_src, new_src_index, dst, dst_index);
+ dst->MaybeCopyOnWrite();
+ (*dst->props_->node_def.mutable_input())[dst_index] =
+ strings::StrCat(new_src->name(), ":", new_src_index);
+ return Status::OK();
+}
+
+const Edge* Graph::FindEdge(const Node* dst, int index) {
+ for (const Edge* e : edges_) {
+ // edges_ will contain null edges if RemoveEdge() was called.
+ if (e == nullptr) continue;
+ if (e->dst() == dst && e->dst_input() == index) {
+ return e;
+ }
+ }
+ return nullptr;
+}
+
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
return ops_.AddLibrary(fdef_lib);
}
@@ -528,10 +555,10 @@ Status Graph::IsValidNode(const Node* node) const {
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
if (idx >= node->num_outputs()) {
- return errors::InvalidArgument("Node '", node->name(), "' (type: '",
- node->op_def().name(),
- "', num of outputs: ", node->num_outputs(),
- ") does not have ", "output ", idx);
+ return errors::OutOfRange("Node '", node->name(), "' (type: '",
+ node->op_def().name(),
+ "', num of outputs: ", node->num_outputs(),
+ ") does not have ", "output ", idx);
}
return Status::OK();
}
@@ -539,10 +566,10 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
if (idx >= node->num_inputs()) {
- return errors::InvalidArgument("Node '", node->name(), "' (type: '",
- node->op_def().name(),
- "', num of inputs: ", node->num_inputs(),
- ") does not have ", "input ", idx);
+ return errors::OutOfRange("Node '", node->name(), "' (type: '",
+ node->op_def().name(),
+ "', num of inputs: ", node->num_inputs(),
+ ") does not have ", "input ", idx);
}
return Status::OK();
}