aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-10-30 08:07:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-30 08:10:56 -0700
commitce0238198052358d102ca7786ad9be60a5e76d28 (patch)
treeb1694c3fe23b4933b7967f9494cb7337e673b07e /tensorflow/c/c_api.cc
parentef4490f637e17f3ce599f55522e63d06f470e540 (diff)
Add ability to fetch return nodes and unused input mappings from C API GraphDef import
This change introduces yet another ImportGraphDef function to the C API (TF_GraphImportGraphDefWithResults), but this one has extensible return values so we shouldn't have to add more in the future. This change also modifies the ImportGraphDef C interface to manage all string data for the user. PiperOrigin-RevId: 173894710
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc227
1 files changed, 159 insertions, 68 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index cd98393e0a..b43d202f4e 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -86,6 +86,7 @@ using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
using tensorflow::mutex_lock;
+using tensorflow::string;
using tensorflow::strings::StrCat;
extern "C" {
@@ -366,7 +367,7 @@ namespace {
// Reset helper for converting character arrays to string vectors.
void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
int ncontainers, TF_Status* status) {
- std::vector<tensorflow::string> container_names(ncontainers);
+ std::vector<string> container_names(ncontainers);
for (int i = 0; i < ncontainers; ++i) {
container_names[i] = containers[i];
}
@@ -482,7 +483,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
const char* limit = input + src_size;
*dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
- auto dstarray = dst->flat<tensorflow::string>();
+ auto dstarray = dst->flat<string>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
reinterpret_cast<const tensorflow::uint64*>(input)[i];
@@ -556,9 +557,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
// Compute bytes needed for encoding.
size_t size = 0;
- const auto& srcarray = src.flat<tensorflow::string>();
+ const auto& srcarray = src.flat<string>();
for (int i = 0; i < srcarray.size(); ++i) {
- const tensorflow::string& s = srcarray(i);
+ const string& s = srcarray(i);
// uint64 starting_offset, TF_StringEncode-d string.
size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
}
@@ -572,7 +573,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
for (int i = 0; i < srcarray.size(); ++i) {
*offsets = (dst - data_start);
offsets++;
- const tensorflow::string& s = srcarray(i);
+ const string& s = srcarray(i);
size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
if (!status->status.ok()) {
status->status = InvalidArgument(
@@ -637,10 +638,9 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
}
}
-static bool TF_Run_Inputs(
- TF_Tensor* const* c_inputs,
- std::vector<std::pair<tensorflow::string, Tensor>>* input_pairs,
- TF_Status* status) {
+static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
+ std::vector<std::pair<string, Tensor>>* input_pairs,
+ TF_Status* status) {
const int ninputs = input_pairs->size();
for (int i = 0; i < ninputs; ++i) {
status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
@@ -652,13 +652,12 @@ static bool TF_Run_Inputs(
static void TF_Run_Helper(
Session* session, const char* handle, const TF_Buffer* run_options,
// Input tensors
- const std::vector<std::pair<tensorflow::string, Tensor>>& input_pairs,
+ const std::vector<std::pair<string, Tensor>>& input_pairs,
// Output tensors
- const std::vector<tensorflow::string>& output_tensor_names,
- TF_Tensor** c_outputs,
+ const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
// Target nodes
- const std::vector<tensorflow::string>& target_oper_names,
- TF_Buffer* run_metadata, TF_Status* status) {
+ const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
+ TF_Status* status) {
const int noutputs = output_tensor_names.size();
std::vector<Tensor> outputs(noutputs);
Status result;
@@ -718,16 +717,16 @@ void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
const char** c_target_oper_names, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) {
TF_Run_Setup(noutputs, c_outputs, status);
- std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
+ std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = c_input_names[i];
}
- std::vector<tensorflow::string> output_names(noutputs);
+ std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = c_output_names[i];
}
- std::vector<tensorflow::string> target_oper_names(ntargets);
+ std::vector<string> target_oper_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_oper_names[i] = c_target_oper_names[i];
}
@@ -745,9 +744,9 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
const char** handle, TF_Status* status) {
*handle = nullptr;
- std::vector<tensorflow::string> input_names(ninputs);
- std::vector<tensorflow::string> output_names(noutputs);
- std::vector<tensorflow::string> target_oper_names(ntargets);
+ std::vector<string> input_names(ninputs);
+ std::vector<string> output_names(noutputs);
+ std::vector<string> target_oper_names(ntargets);
for (int i = 0; i < ninputs; ++i) {
input_names[i] = c_input_names[i];
}
@@ -757,7 +756,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s,
for (int i = 0; i < ntargets; ++i) {
target_oper_names[i] = c_target_oper_names[i];
}
- tensorflow::string new_handle;
+ string new_handle;
status->status = s->session->PRunSetup(input_names, output_names,
target_oper_names, &new_handle);
if (status->status.ok()) {
@@ -776,17 +775,17 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
const char** c_target_oper_names, int ntargets,
TF_Status* status) {
TF_Run_Setup(noutputs, c_outputs, status);
- std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
+ std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = c_input_names[i];
}
- std::vector<tensorflow::string> output_names(noutputs);
+ std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = c_output_names[i];
}
- std::vector<tensorflow::string> target_oper_names(ntargets);
+ std::vector<string> target_oper_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_oper_names[i] = c_target_oper_names[i];
}
@@ -881,7 +880,7 @@ TF_Operation* ToOperation(Node* node) {
return static_cast<TF_Operation*>(static_cast<void*>(node));
}
-tensorflow::string OutputName(const TF_Output& output) {
+string OutputName(const TF_Output& output) {
return StrCat(output.oper->node.name(), ":", output.index);
}
@@ -1254,7 +1253,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
return;
}
desc->colocation_constraints.clear();
- for (const tensorflow::string& location : attr_value.list().s()) {
+ for (const string& location : attr_value.list().s()) {
desc->colocation_constraints.insert(location);
}
} else {
@@ -1276,8 +1275,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
if (!desc->colocation_constraints.empty()) {
desc->node_builder.Attr(
tensorflow::kColocationAttrName,
- std::vector<tensorflow::string>(desc->colocation_constraints.begin(),
- desc->colocation_constraints.end()));
+ std::vector<string>(desc->colocation_constraints.begin(),
+ desc->colocation_constraints.end()));
}
status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
@@ -1500,7 +1499,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
const auto& a = oper->node.op_def().attr(i);
if (a.name().compare(attr_name) != 0) continue;
- const tensorflow::string& typestr = a.type();
+ const string& typestr = a.type();
if (typestr == "list(string)") {
metadata.type = TF_ATTR_STRING;
} else if (typestr == "list(int)") {
@@ -1580,7 +1579,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
const auto len = std::min(max_values, attr->list().s_size());
char* p = static_cast<char*>(storage);
for (int i = 0; i < len; ++i) {
- const tensorflow::string& s = attr->list().s(i);
+ const string& s = attr->list().s(i);
values[i] = p;
lengths[i] = s.size();
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
@@ -1824,7 +1823,11 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
const char* src_name,
int src_index, TF_Output dst) {
- opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst);
+ opts->tensor_id_data.push_back(src_name);
+ const string& src_name_str = opts->tensor_id_data.back();
+ // We don't need to store dst's name in tensor_id_data, since `dst` must
+ // outlive the ImportGraphDef call.
+ opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
}
void TF_ImportGraphDefOptionsRemapControlDependency(
@@ -1840,7 +1843,9 @@ extern void TF_ImportGraphDefOptionsAddControlDependency(
void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
const char* oper_name, int index) {
- opts->opts.return_tensors.push_back({oper_name, index});
+ opts->tensor_id_data.push_back(oper_name);
+ const string& oper_name_str = opts->tensor_id_data.back();
+ opts->opts.return_tensors.emplace_back(oper_name_str, index);
}
int TF_ImportGraphDefOptionsNumReturnOutputs(
@@ -1848,57 +1853,142 @@ int TF_ImportGraphDefOptionsNumReturnOutputs(
return opts->opts.return_tensors.size();
}
+void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
+ const char* oper_name) {
+ opts->opts.return_nodes.push_back(oper_name);
+}
+
+int TF_ImportGraphDefOptionsNumReturnOperations(
+ const TF_ImportGraphDefOptions* opts) {
+ return opts->opts.return_nodes.size();
+}
+
+void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
+ int* num_outputs,
+ TF_Output** outputs) {
+ *num_outputs = results->return_tensors.size();
+ *outputs = results->return_tensors.data();
+}
+
+void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
+ int* num_opers,
+ TF_Operation*** opers) {
+ *num_opers = results->return_nodes.size();
+ *opers = results->return_nodes.data();
+}
+
+void TF_ImportGraphDefResultsUnusedInputMappings(
+ TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
+ const char*** src_names, int** src_indexes) {
+ *num_unused_input_mappings = results->unused_key_names.size();
+ *src_names = results->unused_key_names.data();
+ *src_indexes = results->unused_key_indexes.data();
+}
+
+void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
+ delete results;
+}
+
static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
const TF_ImportGraphDefOptions* opts,
- TF_Output* return_outputs,
- int num_return_outputs, TF_Status* status)
+ TF_ImportGraphDefResults* tf_results,
+ TF_Status* status)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
- if (num_return_outputs != opts->opts.return_tensors.size()) {
- status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
- opts->opts.return_tensors.size(), ", got ",
- num_return_outputs);
- return;
- }
- if (num_return_outputs > 0 && return_outputs == nullptr) {
- status->status = InvalidArgument(
- "'return_outputs' must be preallocated to length ", num_return_outputs);
- return;
- }
const int last_node_id = graph->graph.num_node_ids();
tensorflow::ImportGraphDefResults results;
status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
&graph->refiner, &results);
if (!status->status.ok()) return;
+
+ // Add new nodes to name_map
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
auto* node = graph->graph.FindNodeId(i);
if (node != nullptr) graph->name_map[node->name()] = node;
}
- DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
- for (int i = 0; i < num_return_outputs; ++i) {
- return_outputs[i].oper = ToOperation(results.return_tensors[i].first);
- return_outputs[i].index = results.return_tensors[i].second;
+
+ // Populate return_tensors
+ DCHECK(tf_results->return_tensors.empty());
+ tf_results->return_tensors.resize(results.return_tensors.size());
+ for (int i = 0; i < results.return_tensors.size(); ++i) {
+ tf_results->return_tensors[i].oper =
+ ToOperation(results.return_tensors[i].first);
+ tf_results->return_tensors[i].index = results.return_tensors[i].second;
+ }
+
+ // Populate return_nodes
+ DCHECK(tf_results->return_nodes.empty());
+ tf_results->return_nodes.resize(results.return_nodes.size());
+ for (int i = 0; i < results.return_nodes.size(); ++i) {
+ tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
+ }
+
+ // Populate unused map keys
+ DCHECK(tf_results->unused_key_names.empty());
+ DCHECK(tf_results->unused_key_indexes.empty());
+ DCHECK(tf_results->unused_key_names_data.empty());
+ tf_results->unused_key_names.resize(results.unused_input_map_keys.size());
+ tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size());
+ for (int i = 0; i < results.unused_input_map_keys.size(); ++i) {
+ TensorId id = results.unused_input_map_keys[i];
+ tf_results->unused_key_names_data.push_back(id.first.ToString());
+ tf_results->unused_key_names[i] =
+ tf_results->unused_key_names_data.back().c_str();
+ tf_results->unused_key_indexes[i] = id.second;
+ }
+}
+
+TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
+ TF_Graph* graph, const TF_Buffer* graph_def,
+ const TF_ImportGraphDefOptions* options, TF_Status* status) {
+ GraphDef def;
+ if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
+ status->status = InvalidArgument("Invalid GraphDef");
+ return nullptr;
}
+ auto results = new TF_ImportGraphDefResults();
+ mutex_lock l(graph->mu);
+ GraphImportGraphDefLocked(graph, def, options, results, status);
+ if (!status->status.ok()) {
+ delete results;
+ return nullptr;
+ }
+ return results;
}
void TF_GraphImportGraphDefWithReturnOutputs(
TF_Graph* graph, const TF_Buffer* graph_def,
- const TF_ImportGraphDefOptions* opts, TF_Output* return_outputs,
+ const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
int num_return_outputs, TF_Status* status) {
+ if (num_return_outputs != options->opts.return_tensors.size()) {
+ status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
+ options->opts.return_tensors.size(),
+ ", got ", num_return_outputs);
+ return;
+ }
+ if (num_return_outputs > 0 && return_outputs == nullptr) {
+ status->status = InvalidArgument(
+ "'return_outputs' must be preallocated to length ", num_return_outputs);
+ return;
+ }
GraphDef def;
if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
+ TF_ImportGraphDefResults results;
mutex_lock l(graph->mu);
- GraphImportGraphDefLocked(graph, def, opts, return_outputs,
- num_return_outputs, status);
+ GraphImportGraphDefLocked(graph, def, options, &results, status);
+ DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
+ memcpy(return_outputs, results.return_tensors.data(),
+ num_return_outputs * sizeof(TF_Output));
}
void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options,
TF_Status* status) {
- TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, options, nullptr, 0,
- status);
+ TF_ImportGraphDefResults* results =
+ TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
+ TF_DeleteImportGraphDefResults(results);
}
// While loop functions -------------------------------------------------------
@@ -1930,7 +2020,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
tensorflow::ShapeRefiner* dst_refiner,
const TF_Output* src_inputs,
const std::vector<tensorflow::Output>& dst_inputs,
- const tensorflow::string& prefix,
+ const string& prefix,
const std::vector<tensorflow::Operation>& control_deps,
const TF_Output* nodes_to_return, int nreturn_nodes,
std::vector<tensorflow::Output>* return_nodes) {
@@ -2257,9 +2347,9 @@ TF_Session* TF_LoadSessionFromSavedModel(
return nullptr;
}
- std::unordered_set<tensorflow::string> tag_set;
+ std::unordered_set<string> tag_set;
for (int i = 0; i < tags_len; i++) {
- tag_set.insert(tensorflow::string(tags[i]));
+ tag_set.insert(string(tags[i]));
}
tensorflow::SavedModelBundle bundle;
@@ -2275,8 +2365,9 @@ TF_Session* TF_LoadSessionFromSavedModel(
// TODO(jhseu): When Session is modified to take Graphs instead of
// GraphDefs, return the Graph generated in LoadSavedModel().
TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
+ TF_ImportGraphDefResults results;
GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
- import_opts, nullptr, 0, status);
+ import_opts, &results, status);
TF_DeleteImportGraphDefOptions(import_opts);
if (TF_GetCode(status) != TF_OK) return nullptr;
@@ -2372,20 +2463,20 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
TF_Run_Setup(noutputs, output_values, status);
// Convert from TF_Output and TF_Tensor to a string and Tensor.
- std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
+ std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = OutputName(inputs[i]);
}
// Convert from TF_Output to string names.
- std::vector<tensorflow::string> output_names(noutputs);
+ std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = OutputName(outputs[i]);
}
// Convert from TF_Operation* to string names.
- std::vector<tensorflow::string> target_names(ntargets);
+ std::vector<string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_opers[i]->node.name();
}
@@ -2406,22 +2497,22 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
return;
}
- std::vector<tensorflow::string> input_names(ninputs);
+ std::vector<string> input_names(ninputs);
for (int i = 0; i < ninputs; ++i) {
input_names[i] = OutputName(inputs[i]);
}
- std::vector<tensorflow::string> output_names(noutputs);
+ std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = OutputName(outputs[i]);
}
- std::vector<tensorflow::string> target_names(ntargets);
+ std::vector<string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_opers[i]->node.name();
}
- tensorflow::string new_handle;
+ string new_handle;
status->status = session->session->PRunSetup(input_names, output_names,
target_names, &new_handle);
if (status->status.ok()) {
@@ -2452,20 +2543,20 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
TF_Run_Setup(noutputs, output_values, status);
// Convert from TF_Output and TF_Tensor to a string and Tensor.
- std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
+ std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = OutputName(inputs[i]);
}
// Convert from TF_Output to string names.
- std::vector<tensorflow::string> output_names(noutputs);
+ std::vector<string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names[i] = OutputName(outputs[i]);
}
// Convert from TF_Operation* to string names.
- std::vector<tensorflow::string> target_names(ntargets);
+ std::vector<string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_opers[i]->node.name();
}