aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-08-29 08:24:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-29 08:28:50 -0700
commit0fd2a74120b86972441378f79fb5d03e86fed856 (patch)
tree89c1a3a0fba7638e287a1462802240fdf90c4727 /tensorflow/cc/ops
parent3c6a603d8ae9035830626df0e261442a59f2b990 (diff)
Introduce C++ API while loop builder method
This change adds a new function, BuildWhileLoop(), that constructs a while loop. BuildWhileLoop() takes functions that build the cond and body graphs, similar to the Python while_loop function. It also switches the C API to use this new function in order to reduce code duplication. This is in preparation for while loop gradients, which are also implemented in the C++ API (along with the other gradient code). I didn't write unit tests for BuildWhileLoop, instead relying on the current C API while loop tests. This change also disables while loop creation on Android to avoid pulling in extra C++ dependencies. PiperOrigin-RevId: 166849829
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r--tensorflow/cc/ops/while_loop.cc223
-rw-r--r--tensorflow/cc/ops/while_loop.h64
2 files changed, 287 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc
new file mode 100644
index 0000000000..27da77bbe0
--- /dev/null
+++ b/tensorflow/cc/ops/while_loop.cc
@@ -0,0 +1,223 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/while_loop.h"
+
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+namespace ops {
+
+namespace {
+
+// Utility function for converting to internal C++ datatypes.
+OutputTensor ToOutputTensor(const Output& output) {
+ return OutputTensor(output.node(), output.index());
+}
+
+// Utility function for converting to internal C++ datatypes.
+std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) {
+ std::vector<OutputTensor> result(outputs.size());
+ for (int i = 0; i < outputs.size(); ++i) {
+ result[i] = ToOutputTensor(outputs[i]);
+ }
+ return result;
+}
+
+// Utility function for converting to internal C++ datatypes.
+std::vector<Node*> ToNodes(const std::vector<Output>& outputs) {
+ std::vector<Node*> result(outputs.size());
+ for (int i = 0; i < outputs.size(); ++i) {
+ result[i] = outputs[i].node();
+ }
+ return result;
+}
+
+// Manually generates the name of the `loop_var_idx`-th NextIteration node of a
+// loop being constructed with `scope`. This is used to define the backedge
+// before the NextIteration node is created.
+string NextIterationName(const Scope& scope, int loop_var_idx) {
+ string result;
+ const string& prefix = scope.impl()->name();
+ if (!prefix.empty()) strings::StrAppend(&result, prefix, "/");
+ strings::StrAppend(&result, "NextIteration");
+ if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx);
+ return result;
+}
+
+// Creates the `loop_var_idx`-th Merge node of a loop being constructed with
+// `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output.
+Status CreateMerge(const Scope& scope, int loop_var_idx,
+ const Output& enter_output, Output* merge_output) {
+ // The merge nodes accept the while loop's back edges as an input (i.e. the
+ // not-yet-created next iteration nodes). Use the underlying NodeBuilder API
+ // directly to create the back edge.
+ NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index());
+
+ const int next_output_index = 0;
+ DataType dtype = enter_output.node()->output_type(0);
+ NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx),
+ next_output_index, dtype);
+
+ std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input});
+ const string unique_name = scope.GetUniqueNameForOp("Merge");
+ NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list);
+ scope.UpdateBuilder(&builder);
+
+ Node* merge_node;
+ TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node));
+ TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node));
+ *merge_output = Output(merge_node, 0);
+ return Status::OK();
+}
+
+// Creates the condition subgraph defined by `cond`.
+Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond,
+ const std::vector<Output>& inputs, Output* output) {
+ // The control dependency is for constants in the cond graph, and other ops
+ // that do not depend on the loop variables. This ensures that these ops are
+ // in the while loop frame (since they will indirectly depend on an Enter node
+ // defining the frame) and that they are executed once per loop iteration.
+ //
+ // TODO(skyewm): the control dep will be added to all nodes in the cond graph.
+ // This is at best unnecessary, and at worst may prevent different parts of
+ // different loop iterations from executing in parallel.
+ Scope cond_scope =
+ scope.NewSubScope("cond").WithControlDependencies(inputs[0]);
+ Output raw_cond_out;
+ TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out));
+ if (raw_cond_out.type() != DT_BOOL) {
+ return errors::InvalidArgument(
+ "BuildWhileLoop: 'cond' argument must return a boolean output, got ",
+ DataTypeString(raw_cond_out.type()));
+ }
+ *output = LoopCond(scope, raw_cond_out).output;
+ return Status::OK();
+}
+
+// Create the bdoy subgraph defined by `body`. `outputs` must be non-null and
+// empty.
+Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ DCHECK(outputs != nullptr);
+ DCHECK(outputs->empty());
+
+ // The control dependency is analogous to that in CreateCond().
+ Scope body_scope =
+ scope.NewSubScope("body").WithControlDependencies(inputs[0]);
+ TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs));
+ const size_t num_loop_vars = inputs.size();
+ if (outputs->size() != num_loop_vars) {
+ return errors::InvalidArgument(
+ "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars,
+ "outputs, got ", outputs->size());
+ }
+ // TODO(skyewm): check output types/shapes
+ return Status::OK();
+}
+
+} // namespace
+
+// A while loop with a single loop variable looks like this:
+//
+// (output)
+// ^ +---------------+
+// | | body subgraph +-------------+
+// Exit +---------------+ |
+// ^ ^ |
+// | | |
+// Switch<--------+ v
+// ^ | NextIteration
+// | +------+--------+ |
+// +---->| cond subgraph | |
+// | +---------------+ |
+// Merge<---------------------------+
+// ^
+// |
+// Enter
+// ^
+// |
+// (input)
+//
+// If there are multiple loop variables, each of the control flow ops is
+// duplicated for each loop variable.
+// TODO(skyewm): link to public version of design doc
+Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
+ const CondGraphBuilderFn& cond,
+ const BodyGraphBuilderFn& body, const string& frame_name,
+ OutputList* outputs) {
+ DCHECK(!inputs.empty());
+ DCHECK(outputs != nullptr);
+ DCHECK(outputs->empty());
+
+ TF_RETURN_IF_ERROR(scope.status());
+ const size_t num_loop_vars = inputs.size();
+
+ std::vector<Output> enter_outputs(num_loop_vars);
+ for (int i = 0; i < num_loop_vars; ++i) {
+ enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name);
+ }
+ TF_RETURN_IF_ERROR(scope.status());
+
+ std::vector<Output> merge_outputs(num_loop_vars);
+ for (int i = 0; i < num_loop_vars; ++i) {
+ TF_RETURN_IF_ERROR(
+ CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i]));
+ }
+
+ Output cond_out;
+ TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
+
+ std::vector<Output> switch_trues(num_loop_vars);
+ std::vector<Output> switch_falses(num_loop_vars);
+ for (int i = 0; i < num_loop_vars; ++i) {
+ auto switch_i = Switch(scope, merge_outputs[i], cond_out);
+ switch_trues[i] = switch_i.output_true;
+ switch_falses[i] = switch_i.output_false;
+ }
+ TF_RETURN_IF_ERROR(scope.status());
+
+ std::vector<Output> body_outputs;
+ TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs));
+
+ std::vector<Output> next_outputs(num_loop_vars);
+ for (int i = 0; i < num_loop_vars; ++i) {
+ next_outputs[i] = NextIteration(scope, body_outputs[i]);
+ DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i));
+ }
+ TF_RETURN_IF_ERROR(scope.status());
+
+ // Create the backedges from the NextIteration nodes to the Merge nodes.
+ for (int i = 0; i < num_loop_vars; ++i) {
+ const int merge_backedge_output_index = 1;
+ scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(),
+ merge_outputs[i].node(),
+ merge_backedge_output_index);
+ }
+
+ outputs->resize(num_loop_vars);
+ for (int i = 0; i < num_loop_vars; ++i) {
+ (*outputs)[i] = internal::Exit(scope, switch_falses[i]);
+ }
+ return scope.status();
+}
+
+} // namespace ops
+} // namespace tensorflow
diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h
new file mode 100644
index 0000000000..253d5d8935
--- /dev/null
+++ b/tensorflow/cc/ops/while_loop.h
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+#define THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+namespace ops {
+
+// Function that takes cond graph inputs and returns cond graph boolean output.
+// 'output' need not be set if an error is returned.
+typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
+ Output* output)>
+ CondGraphBuilderFn;
+
+// Function that takes body graph inputs and returns body graph outputs.
+// 'outputs' need not be populated if an error is returned.
+typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs)>
+ BodyGraphBuilderFn;
+
+// Constructs a while loop.
+//
+// Arguments:
+// * scope: used to construct the while loop.
+// * inputs: the initial values of the loop variables. Must be non-empty.
+// * cond: a function that builds the condition graph of the loop. Takes the
+// current loop variables as inputs and returns a scalar boolean Output
+// indicating whether the loop should continue.
+// * body: a function that builds the body graph of the loop. Takes the current
+// loop variables as inputs and returns the updated loop variables.
+// * frame_name: the frame name to use for this while loop. This should be a
+// unique name. This will be used as a prefix for created operations.
+// * outputs: output param that returns final loop variable outputs in non-error
+// case. Must be non-null and empty.
+//
+// Returns an error if the while loop could not be fully constructed.
+//
+// TODO(skyewm): clean up partially-constructed loop in error case
+// TODO(skyewm): create public interface to this method
+Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
+ const CondGraphBuilderFn& cond,
+ const BodyGraphBuilderFn& body, const string& frame_name,
+ OutputList* outputs);
+
+} // namespace ops
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_