diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-10-16 15:29:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-16 15:34:07 -0700 |
commit | dc442f4ce2d3b11b56721337fe2b9e2282be93be (patch) | |
tree | ee2d7796823a1430bc4c7a9f2dd577204aa28321 /tensorflow/c/c_api.cc | |
parent | 7b6eec7e1175624458a48945bba3f6400e754d33 (diff) |
Add return_nodes option to ImportGraphDef
The is similar to the return_tensors option. return_tensors cannot be
used to fetch nodes with no outputs, so return_nodes is necessary.
In addition, this change also refactors the ImportGraphDef signature
to return all optional return values in a single struct. This is to
keep the ImportGraphDef signature from getting too long, and also
makes the call sites simpler.
PiperOrigin-RevId: 172388270
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 334f867e47..79fbd8c90c 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1854,18 +1854,18 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, return; } const int last_node_id = graph->graph.num_node_ids(); - std::vector<std::pair<Node*, int>> return_outputs_vec; - status->status = tensorflow::ImportGraphDef( - opts->opts, def, &graph->graph, &graph->refiner, &return_outputs_vec); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); if (!status->status.ok()) return; for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { auto* node = graph->graph.FindNodeId(i); if (node != nullptr) graph->name_map[node->name()] = node; } - DCHECK_EQ(return_outputs_vec.size(), num_return_outputs); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); for (int i = 0; i < num_return_outputs; ++i) { - return_outputs[i].oper = ToOperation(return_outputs_vec[i].first); - return_outputs[i].index = return_outputs_vec[i].second; + return_outputs[i].oper = ToOperation(results.return_tensors[i].first); + return_outputs[i].index = results.return_tensors[i].second; } } @@ -1945,11 +1945,11 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, } // TOOD(skyewm): change to OutputTensor - std::vector<std::pair<Node*, int>> return_tensors; + tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( - ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors)); + ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); - for (const auto& pair : return_tensors) { + for (const auto& pair : results.return_tensors) { return_nodes->emplace_back(pair.first, pair.second); } return Status::OK(); |