author | Skye Wanderman-Milne <skyewm@google.com> | 2017-08-29 08:24:23 -0700
---|---|---
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-29 08:28:50 -0700
commit | 0fd2a74120b86972441378f79fb5d03e86fed856 (patch) |
tree | 89c1a3a0fba7638e287a1462802240fdf90c4727 /tensorflow/c/c_api.cc |
parent | 3c6a603d8ae9035830626df0e261442a59f2b990 (diff) |
Introduce C++ API while loop builder method
This change adds a new function, BuildWhileLoop(), that constructs a
while loop. BuildWhileLoop() takes functions that build the cond and
body graphs, similar to the Python while_loop function. It also
switches the C API to use this new function in order to reduce code
duplication. This is in preparation for while loop gradients, which
are also implemented in the C++ API (along with the other gradient
code).
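For illustration, a minimal sketch of how the new builder is meant to be called. Only the BuildWhileLoop() signature and the shapes of the cond/body builder functions are taken from this change; the loop variables, ops, and names below are invented.

```cpp
#include <vector>

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"

using tensorflow::Output;
using tensorflow::Scope;
using tensorflow::Status;

int main() {
  Scope scope = Scope::NewRootScope();

  // One loop variable, initialized to 0.
  std::vector<Output> inputs = {tensorflow::ops::Const(scope, 0)};

  // Cond builder: receives the current loop variables and must produce a
  // single boolean output (here: loop while the variable is < 10).
  auto cond = [](const Scope& s, const std::vector<Output>& loop_vars,
                 Output* output) {
    *output = tensorflow::ops::Less(s, loop_vars[0], 10);
    return s.status();
  };

  // Body builder: produces the next-iteration value of each loop variable.
  auto body = [](const Scope& s, const std::vector<Output>& loop_vars,
                 std::vector<Output>* next_vars) {
    next_vars->push_back(tensorflow::ops::Add(s, loop_vars[0], 1));
    return s.status();
  };

  // `outputs` receives the loop's Exit tensors, one per input.
  std::vector<Output> outputs;
  Status status = tensorflow::ops::BuildWhileLoop(scope, inputs, cond, body,
                                                  "example_loop", &outputs);
  return status.ok() ? 0 : 1;
}
```

Internally, BuildWhileLoop() creates the same Enter/Merge/Switch/NextIteration/Exit plumbing that the C API helpers below used to build by hand, which is the duplication this change removes.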
I didn't write unit tests for BuildWhileLoop, instead relying on the
current C API while loop tests.
This change also disables while loop creation on Android to avoid
pulling in extra C++ dependencies.
PiperOrigin-RevId: 166849829
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 363 |
1 file changed, 145 insertions, 218 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 1ea70f0598..07c8277a6f 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/cc/framework/gradients.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/while_loop.h"
 #include "tensorflow/cc/saved_model/loader.h"
 #endif
 #include "tensorflow/c/c_api_internal.h"
@@ -831,6 +832,30 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
   return attr;
 }
 
+TensorId ToTensorId(const TF_Output& output) {
+  return TensorId(output.oper->node.name(), output.index);
+}
+
+#ifndef __ANDROID__
+std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
+                                                     int n) {
+  std::vector<tensorflow::Output> outputs(n);
+  for (int i = 0; i < n; ++i) {
+    outputs[i] =
+        tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
+  }
+  return outputs;
+}
+
+void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
+                          TF_Output* tf_outputs) {
+  for (int i = 0; i < outputs.size(); i++) {
+    tf_outputs[i].oper = ToOperation(outputs[i].node());
+    tf_outputs[i].index = outputs[i].index();
+  }
+}
+#endif  // __ANDROID__
+
 }  // namespace
 
 // Shape functions -----------------------------------------------------------
@@ -1721,14 +1746,6 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
   opts->opts.prefix = prefix;
 }
 
-namespace {
-
-TensorId ToTensorId(const TF_Output& output) {
-  return TensorId(output.oper->node.name(), output.index);
-}
-
-}  // namespace
-
 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
                                              const char* src_name,
                                              int src_index, TF_Output dst) {
@@ -1812,6 +1829,11 @@ void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
 // While loop functions -------------------------------------------------------
 
 namespace {
+
+#ifndef __ANDROID__
+
+// Creates a placeholder representing an input to the cond or body graph.
+// TODO(skyewm): remove these from final graph
 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
                  TF_Output* input, TF_Status* status) {
   TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
@@ -1823,130 +1845,50 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
   return true;
 }
 
-bool CreateEnter(TF_Graph* g, const char* node_name, const char* frame_name,
-                 const TF_Output& input, TF_Output* enter, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Enter", node_name);
-  TF_AddInput(desc, input);
-  TF_SetAttrString(desc, "frame_name", frame_name, strlen(frame_name));
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *enter = {oper, 0};
-  return true;
-}
-
-bool CreateMerge(TF_Graph* g, const char* name, const TF_Output& input,
-                 const char* backedge_name, int backedge_index,
-                 TF_Output* merge, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Merge", name);
-
-  // The merge nodes accept the while loop's back edges as an input. Use the
-  // underlying NodeBuilder API directly to create an input to the
-  // not-yet-created back edge.
-  std::vector<NodeBuilder::NodeOut> input_list;
-  input_list.push_back(NodeBuilder::NodeOut(&input.oper->node, input.index));
-  // All merge inputs must have same type
-  DataType type = input.oper->node.output_type(input.index);
-  input_list.push_back(
-      NodeBuilder::NodeOut(backedge_name, backedge_index, type));
-
-  desc->node_builder.Input(input_list);
-
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *merge = {oper, 0};
-  return true;
-}
-
-bool CreateSwitch(TF_Graph* g, const char* name, const TF_Output& input,
-                  const TF_Output& predicate, TF_Output* switch_true,
-                  TF_Output* switch_false, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Switch", name);
-  TF_AddInput(desc, input);
-  TF_AddInput(desc, predicate);
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *switch_false = {oper, 0};
-  *switch_true = {oper, 1};
-  return true;
-}
-
-bool CreateNext(TF_Graph* g, const char* name, const TF_Output& input,
-                TF_Output* next, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc =
-      TF_NewOperationLocked(g, "NextIteration", name);
-  TF_AddInput(desc, input);
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *next = {oper, 0};
-  return true;
-}
-
-bool CreateExit(TF_Graph* g, const char* name, const TF_Output& input,
-                TF_Output* exit, TF_Status* status)
-    EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
-  TF_OperationDescription* desc = TF_NewOperationLocked(g, "Exit", name);
-  TF_AddInput(desc, input);
-  TF_Operation* oper = TF_FinishOperationLocked(desc, status);
-  if (!status->status.ok()) return false;
-  *exit = {oper, 0};
-  return true;
-}
-
-class ScopedImportGraphDefOptions {
- public:
-  ScopedImportGraphDefOptions() { opts_ = TF_NewImportGraphDefOptions(); }
-  ~ScopedImportGraphDefOptions() { TF_DeleteImportGraphDefOptions(opts_); }
-
-  TF_ImportGraphDefOptions* get() const { return opts_; }
-
- private:
-  TF_ImportGraphDefOptions* opts_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(ScopedImportGraphDefOptions);
-};
-
 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
-// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.
-// `prefix` will be prepended to copied node names. `return_nodes` are nodes
-// in `src_graph`, and the new corresponding nodes in `dst_graph` will be
-// returned. `return_nodes` should be preallocated to size `nreturn_nodes`.
-bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph,
-               const TF_Output* src_inputs,
-               const std::vector<TF_Output>& dst_inputs, const char* prefix,
-               const TF_Output* nodes_to_return, int nreturn_nodes,
-               TF_Output* return_nodes, TF_Status* s)
-    EXCLUSIVE_LOCKS_REQUIRED(dst_graph->mu) {
+// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
+// will be prepended to copied node names. `control_deps` are nodes in
+// `dst_graph` that the copied `src_graph` nodes will have control dependencies
+// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
+// in `dst_graph` will be returned. `return_nodes` must be non-null.
+Status CopyGraph(Graph* src_graph, Graph* dst_graph,
+                 tensorflow::ShapeRefiner* dst_refiner,
+                 const TF_Output* src_inputs,
+                 const std::vector<tensorflow::Output>& dst_inputs,
+                 const tensorflow::string& prefix,
+                 const std::vector<tensorflow::Operation>& control_deps,
+                 const TF_Output* nodes_to_return, int nreturn_nodes,
+                 std::vector<tensorflow::Output>* return_nodes) {
+  DCHECK(return_nodes != nullptr);
   GraphDef gdef;
-  src_graph->graph.ToGraphDef(&gdef);
+  src_graph->ToGraphDef(&gdef);
 
-  ScopedImportGraphDefOptions opts;
-  TF_ImportGraphDefOptionsSetPrefix(opts.get(), prefix);
+  tensorflow::ImportGraphDefOptions opts;
+  opts.prefix = prefix;
 
   for (int i = 0; i < dst_inputs.size(); ++i) {
-    TensorId src = ToTensorId(src_inputs[i]);
-    TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(),
-                                            src.second, dst_inputs[i]);
+    opts.input_map[ToTensorId(src_inputs[i])] =
+        TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
   }
-  opts.get()->opts.skip_mapped_nodes = true;
+  opts.skip_mapped_nodes = true;
 
-  // We use the pivot node to control constants in `src_graph`
-  TF_Operation* pivot = dst_inputs[0].oper;
-  TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot);
+  for (const tensorflow::Operation& op : control_deps) {
+    opts.control_dependencies.push_back(op.node()->name());
+  }
 
   for (int i = 0; i < nreturn_nodes; ++i) {
-    TF_ImportGraphDefOptionsAddReturnOutput(
-        opts.get(), nodes_to_return[i].oper->node.name().c_str(),
-        nodes_to_return[i].index);
+    opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
  }
 
-  GraphImportGraphDefLocked(dst_graph, gdef, opts.get(), return_nodes,
-                            nreturn_nodes, s);
-  if (TF_GetCode(s) != TF_OK) return false;
-  return true;
+  // TOOD(skyewm): change to OutputTensor
+  std::vector<std::pair<Node*, int>> return_tensors;
+  TF_RETURN_IF_ERROR(
+      ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors));
+
+  for (const auto& pair : return_tensors) {
+    return_nodes->emplace_back(pair.first, pair.second);
+  }
+  return Status::OK();
 }
 
 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
@@ -1982,6 +1924,8 @@ bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
   return true;
 }
 
+#endif  // __ANDROID__
+
 void FreeWhileResources(const TF_WhileParams* params) {
   TF_DeleteGraph(params->cond_graph);
   TF_DeleteGraph(params->body_graph);
@@ -1999,6 +1943,13 @@ TF_WhileParams EmptyWhileParams() {
 
 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
                            TF_Status* status) {
+#ifdef __ANDROID__
+  status->status = tensorflow::errors::Unimplemented(
+      "Creating while loops is not supported in Android. File a bug at "
+      "https://github.com/tensorflow/tensorflow/issues if this feature is "
+      "important to you");
+  return EmptyWhileParams();
+#else
   if (ninputs == 0) {
     status->status =
         InvalidArgument("TF_NewWhile() must be passed at least one input");
@@ -2039,8 +1990,10 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
     return EmptyWhileParams();
   }
   return params;
+#endif  // __ANDROID__
 }
 
+#ifndef __ANDROID__
 namespace {
 
 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
@@ -2050,113 +2003,90 @@ void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
   TF_Graph* parent = params->cond_graph->parent;
   TF_Output* parent_inputs = params->cond_graph->parent_inputs;
-  int n = params->ninputs;
+  int num_loop_vars = params->ninputs;
 
   mutex_lock l(parent->mu);
 
-  // Create Enter nodes
-  std::vector<TF_Output> enter_nodes(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateEnter(parent, StrCat(params->name, "/enter", i).c_str(),
-                     params->name, parent_inputs[i], &enter_nodes[i], status)) {
-      return;
-    }
-  }
-
-  // Create Merge nodes
-  std::vector<TF_Output> merge_nodes(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateMerge(parent, StrCat(params->name, "/merge", i).c_str(),
-                     enter_nodes[i], StrCat(params->name, "/next", i).c_str(),
-                     0, &merge_nodes[i], status)) {
-      return;
-    }
-  }
-
-  // Copy cond_graph to parent and replace input placeholders with merge node
-  // outputs, and get handle to new cond output
-  tensorflow::string cond_prefix = StrCat(params->name, "/cond");
-  TF_Output cond_output;
-  if (!CopyGraph(params->cond_graph, parent, params->cond_inputs, merge_nodes,
-                 cond_prefix.c_str(), &params->cond_output, 1, &cond_output,
-                 status)) {
-    return;
-  }
-
-  // Create Switch nodes
-  std::vector<TF_Output> switch_trues(n);
-  std::vector<TF_Output> switch_falses(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateSwitch(parent, StrCat(params->name, "/switch", i).c_str(),
-                      merge_nodes[i], cond_output, &switch_trues[i],
-                      &switch_falses[i], status)) {
-      return;
-    }
-  }
-
-  // Copy body_graph to parent, replace input placeholders with switch node
-  // true outputs, and get handles to new body outputs
-  tensorflow::string body_prefix = StrCat(params->name, "/body");
-  std::vector<TF_Output> body_outputs(n);
-  if (!CopyGraph(params->body_graph, parent, params->body_inputs, switch_trues,
-                 body_prefix.c_str(), params->body_outputs, n,
-                 body_outputs.data(), status)) {
-    return;
-  }
-
-  // Create Next nodes
-  std::vector<TF_Output> next_nodes(n);
-  for (int i = 0; i < n; ++i) {
-    if (!CreateNext(parent, StrCat(params->name, "/next", i).c_str(),
-                    body_outputs[i], &next_nodes[i], status)) {
-      return;
-    }
-  }
-
-  // Create Exit nodes (which are the outputs of the while loop)
-  for (int i = 0; i < n; ++i) {
-    if (!CreateExit(parent, StrCat(params->name, "/exit", i).c_str(),
-                    switch_falses[i], &outputs[i], status)) {
-      return;
-    }
+  // 'cond_fn' copies the cond graph into the parent graph.
+  tensorflow::ops::CondGraphBuilderFn cond_fn =
+      [params, parent](const tensorflow::Scope& scope,
+                       const std::vector<tensorflow::Output>& inputs,
+                       tensorflow::Output* output) {
+        DCHECK_EQ(scope.graph(), &parent->graph);
+        std::vector<tensorflow::Output> cond_output;
+        TF_RETURN_IF_ERROR(CopyGraph(
+            &params->cond_graph->graph, &parent->graph, &parent->refiner,
+            params->cond_inputs, inputs, scope.impl()->name(),
+            scope.impl()->control_deps(), &params->cond_output,
+            /* nreturn_nodes */ 1, &cond_output));
+        *output = cond_output[0];
+        return Status::OK();
+      };
+
+  // 'body_fn' copies the body graph into the parent graph.
+  tensorflow::ops::BodyGraphBuilderFn body_fn =
+      [params, parent, num_loop_vars](
+          const tensorflow::Scope& scope,
+          const std::vector<tensorflow::Output>& inputs,
+          std::vector<tensorflow::Output>* outputs) {
+        DCHECK_EQ(scope.graph(), &parent->graph);
+        TF_RETURN_IF_ERROR(
+            CopyGraph(&params->body_graph->graph, &parent->graph,
+                      &parent->refiner, params->body_inputs, inputs,
+                      scope.impl()->name(), scope.impl()->control_deps(),
+                      params->body_outputs, num_loop_vars, outputs));
+        return Status::OK();
+      };
+
+  // Create the while loop using an internal scope.
+  tensorflow::Scope scope =
+      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
+          .NewSubScope(params->name);
+
+  const int first_new_node_id = parent->graph.num_node_ids();
+
+  tensorflow::OutputList loop_outputs;
+  status->status = tensorflow::ops::BuildWhileLoop(
+      scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
+      body_fn, params->name, &loop_outputs);
+
+  // Update name_map with newly-created ops.
+  // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
+  // a bad status. Once we fix this, we may want to return early instead of
+  // executing the following code.
+  for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
+    Node* new_node = parent->graph.FindNodeId(i);
+    if (new_node == nullptr) continue;
+    parent->name_map[new_node->name()] = new_node;
+  }
+
+  // Populate 'outputs'.
+  DCHECK_LE(loop_outputs.size(), num_loop_vars);
+  for (int i = 0; i < loop_outputs.size(); ++i) {
+    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
   }
 }
 
 }  // namespace
+#endif  // __ANDROID__
 
 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
                     TF_Output* outputs) {
+#ifdef __ANDROID__
+  status->status = tensorflow::errors::Unimplemented(
+      "Creating while loops is not supported in Android. File a bug at "
+      "https://github.com/tensorflow/tensorflow/issues if this feature is "
+      "important to you");
+#else
   // If it appears the caller created or modified `params`, don't free resources
   if (!ValidateConstWhileParams(*params, status)) return;
   TF_FinishWhileHelper(params, status, outputs);
   FreeWhileResources(params);
+#endif  // __ANDROID__
 }
 
 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
 
-#ifndef __ANDROID__
-namespace {
-
-void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status,
-                          std::vector<tensorflow::Output>* outputs) {
-  outputs->resize(n);
-  for (int i = 0; i < n; i++) {
-    const TF_Output& tf_output = tf_outputs[i];
-    (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index);
-  }
-}
-
-void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
-                          TF_Output* tf_outputs) {
-  for (int i = 0; i < outputs.size(); i++) {
-    tf_outputs[i].oper = ToOperation(outputs[i].node());
-    tf_outputs[i].index = outputs[i].index();
-  }
-}
-
-}  // namespace
-#endif  // __ANDROID__
-
 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                      TF_Output* dx, TF_Status* status, TF_Output* dy) {
 #ifdef __ANDROID__
@@ -2165,25 +2095,22 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
       "https://github.com/tensorflow/tensorflow/issues if this feature is "
       "important to you");
 #else
-  std::vector<tensorflow::Output> y_arg;
-  std::vector<tensorflow::Output> x_arg;
+  std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
+  std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
   std::vector<tensorflow::Output> dy_arg;
-  OutputsFromTFOutputs(y, ny, status, &y_arg);
-  OutputsFromTFOutputs(x, nx, status, &x_arg);
 
   {
     // We need to hold on to the lock while we have a scope that uses TF_Graph.
     mutex_lock graph_lock(g->mu);
 
-    const int max_node_id_before = g->graph.num_node_ids();
+    const int first_new_node_id = g->graph.num_node_ids();
 
     tensorflow::Scope scope =
         NewInternalScope(&g->graph, &status->status, &g->refiner)
            .NewSubScope("gradients");
 
     if (dx != nullptr) {
-      std::vector<tensorflow::Output> dx_arg;
-      OutputsFromTFOutputs(dx, ny, status, &dx_arg);
+      std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
      status->status =
          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
     } else {
@@ -2192,7 +2119,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
 
     // Update g->name_map with the name_map from the scope, which will contain
     // the new gradient ops.
-    for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) {
+    for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
      Node* n = g->graph.FindNodeId(i);
      if (n == nullptr) continue;
      g->name_map[n->name()] = n;
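The C API surface itself is unchanged by this commit: TF_NewWhile() still hands the caller temporary cond/body graphs, the caller populates them, and TF_FinishWhile() stitches them into the parent graph (now via BuildWhileLoop()). A rough sketch of that flow, under the assumption of int32 loop variables; the "Less"/"Add" subgraphs and all names are invented:

```cpp
#include <cstdint>

#include "tensorflow/c/c_api.h"

// Helper: adds an int32 scalar Const op to `g` with standard C API calls.
static TF_Operation* ScalarIntConst(TF_Graph* g, const char* name, int32_t v,
                                    TF_Status* s) {
  TF_Tensor* t =
      TF_AllocateTensor(TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, sizeof(v));
  *static_cast<int32_t*>(TF_TensorData(t)) = v;
  TF_OperationDescription* desc = TF_NewOperation(g, "Const", name);
  TF_SetAttrTensor(desc, "value", t, s);
  TF_SetAttrType(desc, "dtype", TF_INT32);
  TF_Operation* op = TF_FinishOperation(desc, s);
  TF_DeleteTensor(t);
  return op;
}

// Builds `while (v < 10) v = v + 1;` starting from `init`, an existing
// int32 output in `graph`. Returns the loop's output.
TF_Output LoopToTen(TF_Graph* graph, TF_Output init, TF_Status* s) {
  TF_WhileParams params = TF_NewWhile(graph, &init, /*ninputs=*/1, s);
  if (TF_GetCode(s) != TF_OK) return init;

  // Condition subgraph: cond_inputs[0] < 10.
  TF_Operation* ten = ScalarIntConst(params.cond_graph, "ten", 10, s);
  TF_OperationDescription* less =
      TF_NewOperation(params.cond_graph, "Less", "cond_less");
  TF_AddInput(less, params.cond_inputs[0]);
  TF_Output ten_out = {ten, 0};
  TF_AddInput(less, ten_out);
  params.cond_output = {TF_FinishOperation(less, s), 0};

  // Body subgraph: body_outputs[0] = body_inputs[0] + 1.
  TF_Operation* one = ScalarIntConst(params.body_graph, "one", 1, s);
  TF_OperationDescription* add =
      TF_NewOperation(params.body_graph, "Add", "body_add");
  TF_AddInput(add, params.body_inputs[0]);
  TF_Output one_out = {one, 0};
  TF_AddInput(add, one_out);
  params.body_outputs[0] = {TF_FinishOperation(add, s), 0};

  params.name = "test_loop";

  // Copies cond/body into `graph` and frees the temporary graphs.
  TF_Output result;
  TF_FinishWhile(&params, s, &result);
  return result;
}
```

If building either subgraph fails partway, the caller should invoke TF_AbortWhile(&params) instead of TF_FinishWhile(), which releases the temporary graphs without modifying the parent.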