aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-10-16 15:29:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-16 15:34:07 -0700
commitdc442f4ce2d3b11b56721337fe2b9e2282be93be (patch)
treeee2d7796823a1430bc4c7a9f2dd577204aa28321 /tensorflow/c/c_api.cc
parent7b6eec7e1175624458a48945bba3f6400e754d33 (diff)
Add return_nodes option to ImportGraphDef
The is similar to the return_tensors option. return_tensors cannot be used to fetch nodes with no outputs, so return_nodes is necessary. In addition, this change also refactors the ImportGraphDef signature to return all optional return values in a single struct. This is to keep the ImportGraphDef signature from getting too long, and also makes the call sites simpler. PiperOrigin-RevId: 172388270
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc18
1 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 334f867e47..79fbd8c90c 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1854,18 +1854,18 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
return;
}
const int last_node_id = graph->graph.num_node_ids();
- std::vector<std::pair<Node*, int>> return_outputs_vec;
- status->status = tensorflow::ImportGraphDef(
- opts->opts, def, &graph->graph, &graph->refiner, &return_outputs_vec);
+ tensorflow::ImportGraphDefResults results;
+ status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
+ &graph->refiner, &results);
if (!status->status.ok()) return;
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
auto* node = graph->graph.FindNodeId(i);
if (node != nullptr) graph->name_map[node->name()] = node;
}
- DCHECK_EQ(return_outputs_vec.size(), num_return_outputs);
+ DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
for (int i = 0; i < num_return_outputs; ++i) {
- return_outputs[i].oper = ToOperation(return_outputs_vec[i].first);
- return_outputs[i].index = return_outputs_vec[i].second;
+ return_outputs[i].oper = ToOperation(results.return_tensors[i].first);
+ return_outputs[i].index = results.return_tensors[i].second;
}
}
@@ -1945,11 +1945,11 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
}
// TOOD(skyewm): change to OutputTensor
- std::vector<std::pair<Node*, int>> return_tensors;
+ tensorflow::ImportGraphDefResults results;
TF_RETURN_IF_ERROR(
- ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors));
+ ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
- for (const auto& pair : return_tensors) {
+ for (const auto& pair : results.return_tensors) {
return_nodes->emplace_back(pair.first, pair.second);
}
return Status::OK();