diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-10-30 08:07:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-30 08:10:56 -0700 |
commit | ce0238198052358d102ca7786ad9be60a5e76d28 (patch) | |
tree | b1694c3fe23b4933b7967f9494cb7337e673b07e /tensorflow/c/c_api.cc | |
parent | ef4490f637e17f3ce599f55522e63d06f470e540 (diff) |
Add ability to fetch return nodes and unused input mappings from C API GraphDef import
This change introduces yet another ImportGraphDef function to the C
API (TF_GraphImportGraphDefWithResults), but this one has extensible
return values so we shouldn't have to add more in the future.
This change also modifies the ImportGraphDef C interface to manage all
string data for the user.
PiperOrigin-RevId: 173894710
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 227 |
1 files changed, 159 insertions, 68 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index cd98393e0a..b43d202f4e 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -86,6 +86,7 @@ using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; using tensorflow::mutex_lock; +using tensorflow::string; using tensorflow::strings::StrCat; extern "C" { @@ -366,7 +367,7 @@ namespace { // Reset helper for converting character arrays to string vectors. void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, int ncontainers, TF_Status* status) { - std::vector<tensorflow::string> container_names(ncontainers); + std::vector<string> container_names(ncontainers); for (int i = 0; i < ncontainers; ++i) { container_names[i] = containers[i]; } @@ -482,7 +483,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { const char* limit = input + src_size; *dst = Tensor(static_cast<DataType>(src->dtype), src->shape); - auto dstarray = dst->flat<tensorflow::string>(); + auto dstarray = dst->flat<string>(); for (tensorflow::int64 i = 0; i < num_elements; ++i) { tensorflow::uint64 offset = reinterpret_cast<const tensorflow::uint64*>(input)[i]; @@ -556,9 +557,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, // Compute bytes needed for encoding. size_t size = 0; - const auto& srcarray = src.flat<tensorflow::string>(); + const auto& srcarray = src.flat<string>(); for (int i = 0; i < srcarray.size(); ++i) { - const tensorflow::string& s = srcarray(i); + const string& s = srcarray(i); // uint64 starting_offset, TF_StringEncode-d string. size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size()); } @@ -572,7 +573,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, for (int i = 0; i < srcarray.size(); ++i) { *offsets = (dst - data_start); offsets++; - const tensorflow::string& s = srcarray(i); + const string& s = srcarray(i); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); if (!status->status.ok()) { status->status = InvalidArgument( @@ -637,10 +638,9 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, } } -static bool TF_Run_Inputs( - TF_Tensor* const* c_inputs, - std::vector<std::pair<tensorflow::string, Tensor>>* input_pairs, - TF_Status* status) { +static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, + std::vector<std::pair<string, Tensor>>* input_pairs, + TF_Status* status) { const int ninputs = input_pairs->size(); for (int i = 0; i < ninputs; ++i) { status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); @@ -652,13 +652,12 @@ static bool TF_Run_Inputs( static void TF_Run_Helper( Session* session, const char* handle, const TF_Buffer* run_options, // Input tensors - const std::vector<std::pair<tensorflow::string, Tensor>>& input_pairs, + const std::vector<std::pair<string, Tensor>>& input_pairs, // Output tensors - const std::vector<tensorflow::string>& output_tensor_names, - TF_Tensor** c_outputs, + const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs, // Target nodes - const std::vector<tensorflow::string>& target_oper_names, - TF_Buffer* run_metadata, TF_Status* status) { + const std::vector<string>& target_oper_names, TF_Buffer* run_metadata, + TF_Status* status) { const int noutputs = output_tensor_names.size(); std::vector<Tensor> outputs(noutputs); Status result; @@ -718,16 +717,16 @@ void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, const char** c_target_oper_names, int ntargets, TF_Buffer* run_metadata, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); - std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); + std::vector<std::pair<string, Tensor>> input_pairs(ninputs); if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } - std::vector<tensorflow::string> output_names(noutputs); + std::vector<string> output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = c_output_names[i]; } - std::vector<tensorflow::string> target_oper_names(ntargets); + std::vector<string> target_oper_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } @@ -745,9 +744,9 @@ void TF_PRunSetup(TF_DeprecatedSession* s, const char** handle, TF_Status* status) { *handle = nullptr; - std::vector<tensorflow::string> input_names(ninputs); - std::vector<tensorflow::string> output_names(noutputs); - std::vector<tensorflow::string> target_oper_names(ntargets); + std::vector<string> input_names(ninputs); + std::vector<string> output_names(noutputs); + std::vector<string> target_oper_names(ntargets); for (int i = 0; i < ninputs; ++i) { input_names[i] = c_input_names[i]; } @@ -757,7 +756,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } - tensorflow::string new_handle; + string new_handle; status->status = s->session->PRunSetup(input_names, output_names, target_oper_names, &new_handle); if (status->status.ok()) { @@ -776,17 +775,17 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, const char** c_target_oper_names, int ntargets, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); - std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); + std::vector<std::pair<string, Tensor>> input_pairs(ninputs); if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } - std::vector<tensorflow::string> output_names(noutputs); + std::vector<string> output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = c_output_names[i]; } - std::vector<tensorflow::string> target_oper_names(ntargets); + std::vector<string> target_oper_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } @@ -881,7 +880,7 @@ TF_Operation* ToOperation(Node* node) { return static_cast<TF_Operation*>(static_cast<void*>(node)); } -tensorflow::string OutputName(const TF_Output& output) { +string OutputName(const TF_Output& output) { return StrCat(output.oper->node.name(), ":", output.index); } @@ -1254,7 +1253,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, return; } desc->colocation_constraints.clear(); - for (const tensorflow::string& location : attr_value.list().s()) { + for (const string& location : attr_value.list().s()) { desc->colocation_constraints.insert(location); } } else { @@ -1276,8 +1275,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, if (!desc->colocation_constraints.empty()) { desc->node_builder.Attr( tensorflow::kColocationAttrName, - std::vector<tensorflow::string>(desc->colocation_constraints.begin(), - desc->colocation_constraints.end())); + std::vector<string>(desc->colocation_constraints.begin(), + desc->colocation_constraints.end())); } status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); @@ -1500,7 +1499,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { const auto& a = oper->node.op_def().attr(i); if (a.name().compare(attr_name) != 0) continue; - const tensorflow::string& typestr = a.type(); + const string& typestr = a.type(); if (typestr == "list(string)") { metadata.type = TF_ATTR_STRING; } else if (typestr == "list(int)") { @@ -1580,7 +1579,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, const auto len = std::min(max_values, attr->list().s_size()); char* p = static_cast<char*>(storage); for (int i = 0; i < len; ++i) { - const tensorflow::string& s = attr->list().s(i); + const string& s = attr->list().s(i); values[i] = p; lengths[i] = s.size(); if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) { @@ -1824,7 +1823,11 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { - opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst); + opts->tensor_id_data.push_back(src_name); + const string& src_name_str = opts->tensor_id_data.back(); + // We don't need to store dst's name in tensor_id_data, since `dst` must + // outlive the ImportGraphDef call. + opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); } void TF_ImportGraphDefOptionsRemapControlDependency( @@ -1840,7 +1843,9 @@ extern void TF_ImportGraphDefOptionsAddControlDependency( void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, const char* oper_name, int index) { - opts->opts.return_tensors.push_back({oper_name, index}); + opts->tensor_id_data.push_back(oper_name); + const string& oper_name_str = opts->tensor_id_data.back(); + opts->opts.return_tensors.emplace_back(oper_name_str, index); } int TF_ImportGraphDefOptionsNumReturnOutputs( @@ -1848,57 +1853,142 @@ int TF_ImportGraphDefOptionsNumReturnOutputs( return opts->opts.return_tensors.size(); } +void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, + const char* oper_name) { + opts->opts.return_nodes.push_back(oper_name); +} + +int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts) { + return opts->opts.return_nodes.size(); +} + +void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, + int* num_outputs, + TF_Output** outputs) { + *num_outputs = results->return_tensors.size(); + *outputs = results->return_tensors.data(); +} + +void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, + int* num_opers, + TF_Operation*** opers) { + *num_opers = results->return_nodes.size(); + *opers = results->return_nodes.data(); +} + +void TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_unused_input_mappings, + const char*** src_names, int** src_indexes) { + *num_unused_input_mappings = results->unused_key_names.size(); + *src_names = results->unused_key_names.data(); + *src_indexes = results->unused_key_indexes.data(); +} + +void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { + delete results; +} + static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, const TF_ImportGraphDefOptions* opts, - TF_Output* return_outputs, - int num_return_outputs, TF_Status* status) + TF_ImportGraphDefResults* tf_results, + TF_Status* status) EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - if (num_return_outputs != opts->opts.return_tensors.size()) { - status->status = InvalidArgument("Expected 'num_return_outputs' to be ", - opts->opts.return_tensors.size(), ", got ", - num_return_outputs); - return; - } - if (num_return_outputs > 0 && return_outputs == nullptr) { - status->status = InvalidArgument( - "'return_outputs' must be preallocated to length ", num_return_outputs); - return; - } const int last_node_id = graph->graph.num_node_ids(); tensorflow::ImportGraphDefResults results; status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, &graph->refiner, &results); if (!status->status.ok()) return; + + // Add new nodes to name_map for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { auto* node = graph->graph.FindNodeId(i); if (node != nullptr) graph->name_map[node->name()] = node; } - DCHECK_EQ(results.return_tensors.size(), num_return_outputs); - for (int i = 0; i < num_return_outputs; ++i) { - return_outputs[i].oper = ToOperation(results.return_tensors[i].first); - return_outputs[i].index = results.return_tensors[i].second; + + // Populate return_tensors + DCHECK(tf_results->return_tensors.empty()); + tf_results->return_tensors.resize(results.return_tensors.size()); + for (int i = 0; i < results.return_tensors.size(); ++i) { + tf_results->return_tensors[i].oper = + ToOperation(results.return_tensors[i].first); + tf_results->return_tensors[i].index = results.return_tensors[i].second; + } + + // Populate return_nodes + DCHECK(tf_results->return_nodes.empty()); + tf_results->return_nodes.resize(results.return_nodes.size()); + for (int i = 0; i < results.return_nodes.size(); ++i) { + tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); + } + + // Populate unused map keys + DCHECK(tf_results->unused_key_names.empty()); + DCHECK(tf_results->unused_key_indexes.empty()); + DCHECK(tf_results->unused_key_names_data.empty()); + tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); + tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); + for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { + TensorId id = results.unused_input_map_keys[i]; + tf_results->unused_key_names_data.push_back(id.first.ToString()); + tf_results->unused_key_names[i] = + tf_results->unused_key_names_data.back().c_str(); + tf_results->unused_key_indexes[i] = id.second; + } +} + +TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status) { + GraphDef def; + if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + status->status = InvalidArgument("Invalid GraphDef"); + return nullptr; } + auto results = new TF_ImportGraphDefResults(); + mutex_lock l(graph->mu); + GraphImportGraphDefLocked(graph, def, options, results, status); + if (!status->status.ok()) { + delete results; + return nullptr; + } + return results; } void TF_GraphImportGraphDefWithReturnOutputs( TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* opts, TF_Output* return_outputs, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, int num_return_outputs, TF_Status* status) { + if (num_return_outputs != options->opts.return_tensors.size()) { + status->status = InvalidArgument("Expected 'num_return_outputs' to be ", + options->opts.return_tensors.size(), + ", got ", num_return_outputs); + return; + } + if (num_return_outputs > 0 && return_outputs == nullptr) { + status->status = InvalidArgument( + "'return_outputs' must be preallocated to length ", num_return_outputs); + return; + } GraphDef def; if (!def.ParseFromArray(graph_def->data, graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } + TF_ImportGraphDefResults results; mutex_lock l(graph->mu); - GraphImportGraphDefLocked(graph, def, opts, return_outputs, - num_return_outputs, status); + GraphImportGraphDefLocked(graph, def, options, &results, status); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); + memcpy(return_outputs, results.return_tensors.data(), + num_return_outputs * sizeof(TF_Output)); } void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { - TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, options, nullptr, 0, - status); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); + TF_DeleteImportGraphDefResults(results); } // While loop functions ------------------------------------------------------- @@ -1930,7 +2020,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, tensorflow::ShapeRefiner* dst_refiner, const TF_Output* src_inputs, const std::vector<tensorflow::Output>& dst_inputs, - const tensorflow::string& prefix, + const string& prefix, const std::vector<tensorflow::Operation>& control_deps, const TF_Output* nodes_to_return, int nreturn_nodes, std::vector<tensorflow::Output>* return_nodes) { @@ -2257,9 +2347,9 @@ TF_Session* TF_LoadSessionFromSavedModel( return nullptr; } - std::unordered_set<tensorflow::string> tag_set; + std::unordered_set<string> tag_set; for (int i = 0; i < tags_len; i++) { - tag_set.insert(tensorflow::string(tags[i])); + tag_set.insert(string(tags[i])); } tensorflow::SavedModelBundle bundle; @@ -2275,8 +2365,9 @@ TF_Session* TF_LoadSessionFromSavedModel( // TODO(jhseu): When Session is modified to take Graphs instead of // GraphDefs, return the Graph generated in LoadSavedModel(). TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefResults results; GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), - import_opts, nullptr, 0, status); + import_opts, &results, status); TF_DeleteImportGraphDefOptions(import_opts); if (TF_GetCode(status) != TF_OK) return nullptr; @@ -2372,20 +2463,20 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, TF_Run_Setup(noutputs, output_values, status); // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); + std::vector<std::pair<string, Tensor>> input_pairs(ninputs); if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = OutputName(inputs[i]); } // Convert from TF_Output to string names. - std::vector<tensorflow::string> output_names(noutputs); + std::vector<string> output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } // Convert from TF_Operation* to string names. - std::vector<tensorflow::string> target_names(ntargets); + std::vector<string> target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } @@ -2406,22 +2497,22 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, return; } - std::vector<tensorflow::string> input_names(ninputs); + std::vector<string> input_names(ninputs); for (int i = 0; i < ninputs; ++i) { input_names[i] = OutputName(inputs[i]); } - std::vector<tensorflow::string> output_names(noutputs); + std::vector<string> output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } - std::vector<tensorflow::string> target_names(ntargets); + std::vector<string> target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } - tensorflow::string new_handle; + string new_handle; status->status = session->session->PRunSetup(input_names, output_names, target_names, &new_handle); if (status->status.ok()) { @@ -2452,20 +2543,20 @@ void TF_SessionPRun(TF_Session* session, const char* handle, TF_Run_Setup(noutputs, output_values, status); // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); + std::vector<std::pair<string, Tensor>> input_pairs(ninputs); if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = OutputName(inputs[i]); } // Convert from TF_Output to string names. - std::vector<tensorflow::string> output_names(noutputs); + std::vector<string> output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } // Convert from TF_Operation* to string names. - std::vector<tensorflow::string> target_names(ntargets); + std::vector<string> target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } |