aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-08-29 08:24:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-29 08:28:50 -0700
commit0fd2a74120b86972441378f79fb5d03e86fed856 (patch)
tree89c1a3a0fba7638e287a1462802240fdf90c4727 /tensorflow/c/c_api.cc
parent3c6a603d8ae9035830626df0e261442a59f2b990 (diff)
Introduce C++ API while loop builder method
This change adds a new function, BuildWhileLoop(), that constructs a while loop. BuildWhileLoop() takes functions that build the cond and body graphs, similar to the Python while_loop function. It also switches the C API to use this new function in order to reduce code duplication. This is in preparation for while loop gradients, which are also implemented in the C++ API (along with the other gradient code). I didn't write unit tests for BuildWhileLoop, instead relying on the current C API while loop tests. This change also disables while loop creation on Android to avoid pulling in extra C++ dependencies. PiperOrigin-RevId: 166849829
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc363
1 files changed, 145 insertions, 218 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 1ea70f0598..07c8277a6f 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#endif
#include "tensorflow/c/c_api_internal.h"
@@ -831,6 +832,30 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
return attr;
}
+TensorId ToTensorId(const TF_Output& output) {
+ return TensorId(output.oper->node.name(), output.index);
+}
+
+#ifndef __ANDROID__
+std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
+ int n) {
+ std::vector<tensorflow::Output> outputs(n);
+ for (int i = 0; i < n; ++i) {
+ outputs[i] =
+ tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
+ }
+ return outputs;
+}
+
+void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
+ TF_Output* tf_outputs) {
+ for (int i = 0; i < outputs.size(); i++) {
+ tf_outputs[i].oper = ToOperation(outputs[i].node());
+ tf_outputs[i].index = outputs[i].index();
+ }
+}
+#endif // __ANDROID__
+
} // namespace
// Shape functions -----------------------------------------------------------
@@ -1721,14 +1746,6 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
opts->opts.prefix = prefix;
}
-namespace {
-
-TensorId ToTensorId(const TF_Output& output) {
- return TensorId(output.oper->node.name(), output.index);
-}
-
-} // namespace
-
void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
const char* src_name,
int src_index, TF_Output dst) {
@@ -1812,6 +1829,11 @@ void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
// While loop functions -------------------------------------------------------
namespace {
+
+#ifndef __ANDROID__
+
+// Creates a placeholder representing an input to the cond or body graph.
+// TODO(skyewm): remove these from final graph
bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
TF_Output* input, TF_Status* status) {
TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
@@ -1823,130 +1845,50 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
return true;
}
-bool CreateEnter(TF_Graph* g, const char* node_name, const char* frame_name,
- const TF_Output& input, TF_Output* enter, TF_Status* status)
- EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
- TF_OperationDescription* desc = TF_NewOperationLocked(g, "Enter", node_name);
- TF_AddInput(desc, input);
- TF_SetAttrString(desc, "frame_name", frame_name, strlen(frame_name));
- TF_Operation* oper = TF_FinishOperationLocked(desc, status);
- if (!status->status.ok()) return false;
- *enter = {oper, 0};
- return true;
-}
-
-bool CreateMerge(TF_Graph* g, const char* name, const TF_Output& input,
- const char* backedge_name, int backedge_index,
- TF_Output* merge, TF_Status* status)
- EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
- TF_OperationDescription* desc = TF_NewOperationLocked(g, "Merge", name);
-
- // The merge nodes accept the while loop's back edges as an input. Use the
- // underlying NodeBuilder API directly to create an input to the
- // not-yet-created back edge.
- std::vector<NodeBuilder::NodeOut> input_list;
- input_list.push_back(NodeBuilder::NodeOut(&input.oper->node, input.index));
- // All merge inputs must have same type
- DataType type = input.oper->node.output_type(input.index);
- input_list.push_back(
- NodeBuilder::NodeOut(backedge_name, backedge_index, type));
-
- desc->node_builder.Input(input_list);
-
- TF_Operation* oper = TF_FinishOperationLocked(desc, status);
- if (!status->status.ok()) return false;
- *merge = {oper, 0};
- return true;
-}
-
-bool CreateSwitch(TF_Graph* g, const char* name, const TF_Output& input,
- const TF_Output& predicate, TF_Output* switch_true,
- TF_Output* switch_false, TF_Status* status)
- EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
- TF_OperationDescription* desc = TF_NewOperationLocked(g, "Switch", name);
- TF_AddInput(desc, input);
- TF_AddInput(desc, predicate);
- TF_Operation* oper = TF_FinishOperationLocked(desc, status);
- if (!status->status.ok()) return false;
- *switch_false = {oper, 0};
- *switch_true = {oper, 1};
- return true;
-}
-
-bool CreateNext(TF_Graph* g, const char* name, const TF_Output& input,
- TF_Output* next, TF_Status* status)
- EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
- TF_OperationDescription* desc =
- TF_NewOperationLocked(g, "NextIteration", name);
- TF_AddInput(desc, input);
- TF_Operation* oper = TF_FinishOperationLocked(desc, status);
- if (!status->status.ok()) return false;
- *next = {oper, 0};
- return true;
-}
-
-bool CreateExit(TF_Graph* g, const char* name, const TF_Output& input,
- TF_Output* exit, TF_Status* status)
- EXCLUSIVE_LOCKS_REQUIRED(g->mu) {
- TF_OperationDescription* desc = TF_NewOperationLocked(g, "Exit", name);
- TF_AddInput(desc, input);
- TF_Operation* oper = TF_FinishOperationLocked(desc, status);
- if (!status->status.ok()) return false;
- *exit = {oper, 0};
- return true;
-}
-
-class ScopedImportGraphDefOptions {
- public:
- ScopedImportGraphDefOptions() { opts_ = TF_NewImportGraphDefOptions(); }
- ~ScopedImportGraphDefOptions() { TF_DeleteImportGraphDefOptions(opts_); }
-
- TF_ImportGraphDefOptions* get() const { return opts_; }
-
- private:
- TF_ImportGraphDefOptions* opts_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ScopedImportGraphDefOptions);
-};
-
// Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
-// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.
-// `prefix` will be prepended to copied node names. `return_nodes` are nodes
-// in `src_graph`, and the new corresponding nodes in `dst_graph` will be
-// returned. `return_nodes` should be preallocated to size `nreturn_nodes`.
-bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph,
- const TF_Output* src_inputs,
- const std::vector<TF_Output>& dst_inputs, const char* prefix,
- const TF_Output* nodes_to_return, int nreturn_nodes,
- TF_Output* return_nodes, TF_Status* s)
- EXCLUSIVE_LOCKS_REQUIRED(dst_graph->mu) {
+// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
+// will be prepended to copied node names. `control_deps` are nodes in
+// `dst_graph` that the copied `src_graph` nodes will have control dependencies
+// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
+// in `dst_graph` will be returned. `return_nodes` must be non-null.
+Status CopyGraph(Graph* src_graph, Graph* dst_graph,
+ tensorflow::ShapeRefiner* dst_refiner,
+ const TF_Output* src_inputs,
+ const std::vector<tensorflow::Output>& dst_inputs,
+ const tensorflow::string& prefix,
+ const std::vector<tensorflow::Operation>& control_deps,
+ const TF_Output* nodes_to_return, int nreturn_nodes,
+ std::vector<tensorflow::Output>* return_nodes) {
+ DCHECK(return_nodes != nullptr);
GraphDef gdef;
- src_graph->graph.ToGraphDef(&gdef);
+ src_graph->ToGraphDef(&gdef);
- ScopedImportGraphDefOptions opts;
- TF_ImportGraphDefOptionsSetPrefix(opts.get(), prefix);
+ tensorflow::ImportGraphDefOptions opts;
+ opts.prefix = prefix;
for (int i = 0; i < dst_inputs.size(); ++i) {
- TensorId src = ToTensorId(src_inputs[i]);
- TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(),
- src.second, dst_inputs[i]);
+ opts.input_map[ToTensorId(src_inputs[i])] =
+ TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
}
- opts.get()->opts.skip_mapped_nodes = true;
+ opts.skip_mapped_nodes = true;
- // We use the pivot node to control constants in `src_graph`
- TF_Operation* pivot = dst_inputs[0].oper;
- TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot);
+ for (const tensorflow::Operation& op : control_deps) {
+ opts.control_dependencies.push_back(op.node()->name());
+ }
for (int i = 0; i < nreturn_nodes; ++i) {
- TF_ImportGraphDefOptionsAddReturnOutput(
- opts.get(), nodes_to_return[i].oper->node.name().c_str(),
- nodes_to_return[i].index);
+ opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
}
- GraphImportGraphDefLocked(dst_graph, gdef, opts.get(), return_nodes,
- nreturn_nodes, s);
- if (TF_GetCode(s) != TF_OK) return false;
- return true;
+ // TOOD(skyewm): change to OutputTensor
+ std::vector<std::pair<Node*, int>> return_tensors;
+ TF_RETURN_IF_ERROR(
+ ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors));
+
+ for (const auto& pair : return_tensors) {
+ return_nodes->emplace_back(pair.first, pair.second);
+ }
+ return Status::OK();
}
bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
@@ -1982,6 +1924,8 @@ bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
return true;
}
+#endif // __ANDROID__
+
void FreeWhileResources(const TF_WhileParams* params) {
TF_DeleteGraph(params->cond_graph);
TF_DeleteGraph(params->body_graph);
@@ -1999,6 +1943,13 @@ TF_WhileParams EmptyWhileParams() {
TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
TF_Status* status) {
+#ifdef __ANDROID__
+ status->status = tensorflow::errors::Unimplemented(
+ "Creating while loops is not supported in Android. File a bug at "
+ "https://github.com/tensorflow/tensorflow/issues if this feature is "
+ "important to you");
+ return EmptyWhileParams();
+#else
if (ninputs == 0) {
status->status =
InvalidArgument("TF_NewWhile() must be passed at least one input");
@@ -2039,8 +1990,10 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
return EmptyWhileParams();
}
return params;
+#endif // __ANDROID__
}
+#ifndef __ANDROID__
namespace {
// TODO(skyewm): make nodes in while loop unfetchable like in Python version
@@ -2050,113 +2003,90 @@ void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
TF_Graph* parent = params->cond_graph->parent;
TF_Output* parent_inputs = params->cond_graph->parent_inputs;
- int n = params->ninputs;
+ int num_loop_vars = params->ninputs;
mutex_lock l(parent->mu);
- // Create Enter nodes
- std::vector<TF_Output> enter_nodes(n);
- for (int i = 0; i < n; ++i) {
- if (!CreateEnter(parent, StrCat(params->name, "/enter", i).c_str(),
- params->name, parent_inputs[i], &enter_nodes[i], status)) {
- return;
- }
- }
-
- // Create Merge nodes
- std::vector<TF_Output> merge_nodes(n);
- for (int i = 0; i < n; ++i) {
- if (!CreateMerge(parent, StrCat(params->name, "/merge", i).c_str(),
- enter_nodes[i], StrCat(params->name, "/next", i).c_str(),
- 0, &merge_nodes[i], status)) {
- return;
- }
- }
-
- // Copy cond_graph to parent and replace input placeholders with merge node
- // outputs, and get handle to new cond output
- tensorflow::string cond_prefix = StrCat(params->name, "/cond");
- TF_Output cond_output;
- if (!CopyGraph(params->cond_graph, parent, params->cond_inputs, merge_nodes,
- cond_prefix.c_str(), &params->cond_output, 1, &cond_output,
- status)) {
- return;
- }
-
- // Create Switch nodes
- std::vector<TF_Output> switch_trues(n);
- std::vector<TF_Output> switch_falses(n);
- for (int i = 0; i < n; ++i) {
- if (!CreateSwitch(parent, StrCat(params->name, "/switch", i).c_str(),
- merge_nodes[i], cond_output, &switch_trues[i],
- &switch_falses[i], status)) {
- return;
- }
- }
-
- // Copy body_graph to parent, replace input placeholders with switch node
- // true outputs, and get handles to new body outputs
- tensorflow::string body_prefix = StrCat(params->name, "/body");
- std::vector<TF_Output> body_outputs(n);
- if (!CopyGraph(params->body_graph, parent, params->body_inputs, switch_trues,
- body_prefix.c_str(), params->body_outputs, n,
- body_outputs.data(), status)) {
- return;
- }
-
- // Create Next nodes
- std::vector<TF_Output> next_nodes(n);
- for (int i = 0; i < n; ++i) {
- if (!CreateNext(parent, StrCat(params->name, "/next", i).c_str(),
- body_outputs[i], &next_nodes[i], status)) {
- return;
- }
- }
-
- // Create Exit nodes (which are the outputs of the while loop)
- for (int i = 0; i < n; ++i) {
- if (!CreateExit(parent, StrCat(params->name, "/exit", i).c_str(),
- switch_falses[i], &outputs[i], status)) {
- return;
- }
+ // 'cond_fn' copies the cond graph into the parent graph.
+ tensorflow::ops::CondGraphBuilderFn cond_fn =
+ [params, parent](const tensorflow::Scope& scope,
+ const std::vector<tensorflow::Output>& inputs,
+ tensorflow::Output* output) {
+ DCHECK_EQ(scope.graph(), &parent->graph);
+ std::vector<tensorflow::Output> cond_output;
+ TF_RETURN_IF_ERROR(CopyGraph(
+ &params->cond_graph->graph, &parent->graph, &parent->refiner,
+ params->cond_inputs, inputs, scope.impl()->name(),
+ scope.impl()->control_deps(), &params->cond_output,
+ /* nreturn_nodes */ 1, &cond_output));
+ *output = cond_output[0];
+ return Status::OK();
+ };
+
+ // 'body_fn' copies the body graph into the parent graph.
+ tensorflow::ops::BodyGraphBuilderFn body_fn =
+ [params, parent, num_loop_vars](
+ const tensorflow::Scope& scope,
+ const std::vector<tensorflow::Output>& inputs,
+ std::vector<tensorflow::Output>* outputs) {
+ DCHECK_EQ(scope.graph(), &parent->graph);
+ TF_RETURN_IF_ERROR(
+ CopyGraph(&params->body_graph->graph, &parent->graph,
+ &parent->refiner, params->body_inputs, inputs,
+ scope.impl()->name(), scope.impl()->control_deps(),
+ params->body_outputs, num_loop_vars, outputs));
+ return Status::OK();
+ };
+
+ // Create the while loop using an internal scope.
+ tensorflow::Scope scope =
+ NewInternalScope(&parent->graph, &status->status, &parent->refiner)
+ .NewSubScope(params->name);
+
+ const int first_new_node_id = parent->graph.num_node_ids();
+
+ tensorflow::OutputList loop_outputs;
+ status->status = tensorflow::ops::BuildWhileLoop(
+ scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
+ body_fn, params->name, &loop_outputs);
+
+ // Update name_map with newly-created ops.
+ // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
+ // a bad status. Once we fix this, we may want to return early instead of
+ // executing the following code.
+ for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
+ Node* new_node = parent->graph.FindNodeId(i);
+ if (new_node == nullptr) continue;
+ parent->name_map[new_node->name()] = new_node;
+ }
+
+ // Populate 'outputs'.
+ DCHECK_LE(loop_outputs.size(), num_loop_vars);
+ for (int i = 0; i < loop_outputs.size(); ++i) {
+ outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
}
}
} // namespace
+#endif // __ANDROID__
void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
TF_Output* outputs) {
+#ifdef __ANDROID__
+ status->status = tensorflow::errors::Unimplemented(
+ "Creating while loops is not supported in Android. File a bug at "
+ "https://github.com/tensorflow/tensorflow/issues if this feature is "
+ "important to you");
+#else
// If it appears the caller created or modified `params`, don't free resources
if (!ValidateConstWhileParams(*params, status)) return;
TF_FinishWhileHelper(params, status, outputs);
FreeWhileResources(params);
+#endif // __ANDROID__
}
void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
-#ifndef __ANDROID__
-namespace {
-
-void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status,
- std::vector<tensorflow::Output>* outputs) {
- outputs->resize(n);
- for (int i = 0; i < n; i++) {
- const TF_Output& tf_output = tf_outputs[i];
- (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index);
- }
-}
-
-void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
- TF_Output* tf_outputs) {
- for (int i = 0; i < outputs.size(); i++) {
- tf_outputs[i].oper = ToOperation(outputs[i].node());
- tf_outputs[i].index = outputs[i].index();
- }
-}
-
-} // namespace
-#endif // __ANDROID__
-
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
TF_Output* dx, TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
@@ -2165,25 +2095,22 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
"https://github.com/tensorflow/tensorflow/issues if this feature is "
"important to you");
#else
- std::vector<tensorflow::Output> y_arg;
- std::vector<tensorflow::Output> x_arg;
+ std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
+ std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
std::vector<tensorflow::Output> dy_arg;
- OutputsFromTFOutputs(y, ny, status, &y_arg);
- OutputsFromTFOutputs(x, nx, status, &x_arg);
{
// We need to hold on to the lock while we have a scope that uses TF_Graph.
mutex_lock graph_lock(g->mu);
- const int max_node_id_before = g->graph.num_node_ids();
+ const int first_new_node_id = g->graph.num_node_ids();
tensorflow::Scope scope =
NewInternalScope(&g->graph, &status->status, &g->refiner)
.NewSubScope("gradients");
if (dx != nullptr) {
- std::vector<tensorflow::Output> dx_arg;
- OutputsFromTFOutputs(dx, ny, status, &dx_arg);
+ std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
status->status =
AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
} else {
@@ -2192,7 +2119,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
// Update g->name_map with the name_map from the scope, which will contain
// the new gradient ops.
- for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) {
+ for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
g->name_map[n->name()] = n;