diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-08-29 08:24:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-29 08:28:50 -0700 |
commit | 0fd2a74120b86972441378f79fb5d03e86fed856 (patch) | |
tree | 89c1a3a0fba7638e287a1462802240fdf90c4727 /tensorflow/cc/ops | |
parent | 3c6a603d8ae9035830626df0e261442a59f2b990 (diff) |
Introduce C++ API while loop builder method
This change adds a new function, BuildWhileLoop(), that constructs a
while loop. BuildWhileLoop() takes functions that build the cond and
body graphs, similar to the Python while_loop function. It also
switches the C API to use this new function in order to reduce code
duplication. This is in preparation for while loop gradients, which
are also implemented in the C++ API (along with the other gradient
code).
I didn't write unit tests for BuildWhileLoop, instead relying on the
current C API while loop tests.
This change also disables while loop creation on Android to avoid
pulling in extra C++ dependencies.
PiperOrigin-RevId: 166849829
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r-- | tensorflow/cc/ops/while_loop.cc | 223 | ||||
-rw-r--r-- | tensorflow/cc/ops/while_loop.h | 64 |
2 files changed, 287 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc new file mode 100644 index 0000000000..27da77bbe0 --- /dev/null +++ b/tensorflow/cc/ops/while_loop.cc @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/while_loop.h" + +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace ops { + +namespace { + +// Utility function for converting to internal C++ datatypes. +OutputTensor ToOutputTensor(const Output& output) { + return OutputTensor(output.node(), output.index()); +} + +// Utility function for converting to internal C++ datatypes. +std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) { + std::vector<OutputTensor> result(outputs.size()); + for (int i = 0; i < outputs.size(); ++i) { + result[i] = ToOutputTensor(outputs[i]); + } + return result; +} + +// Utility function for converting to internal C++ datatypes. +std::vector<Node*> ToNodes(const std::vector<Output>& outputs) { + std::vector<Node*> result(outputs.size()); + for (int i = 0; i < outputs.size(); ++i) { + result[i] = outputs[i].node(); + } + return result; +} + +// Manually generates the name of the `loop_var_idx`-th NextIteration node of a +// loop being constructed with `scope`. This is used to define the backedge +// before the NextIteration node is created. +string NextIterationName(const Scope& scope, int loop_var_idx) { + string result; + const string& prefix = scope.impl()->name(); + if (!prefix.empty()) strings::StrAppend(&result, prefix, "/"); + strings::StrAppend(&result, "NextIteration"); + if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx); + return result; +} + +// Creates the `loop_var_idx`-th Merge node of a loop being constructed with +// `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output. +Status CreateMerge(const Scope& scope, int loop_var_idx, + const Output& enter_output, Output* merge_output) { + // The merge nodes accept the while loop's back edges as an input (i.e. the + // not-yet-created next iteration nodes). Use the underlying NodeBuilder API + // directly to create the back edge. + NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index()); + + const int next_output_index = 0; + DataType dtype = enter_output.node()->output_type(0); + NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx), + next_output_index, dtype); + + std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input}); + const string unique_name = scope.GetUniqueNameForOp("Merge"); + NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list); + scope.UpdateBuilder(&builder); + + Node* merge_node; + TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node)); + TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node)); + *merge_output = Output(merge_node, 0); + return Status::OK(); +} + +// Creates the condition subgraph defined by `cond`. +Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, + const std::vector<Output>& inputs, Output* output) { + // The control dependency is for constants in the cond graph, and other ops + // that do not depend on the loop variables. This ensures that these ops are + // in the while loop frame (since they will indirectly depend on an Enter node + // defining the frame) and that they are executed once per loop iteration. + // + // TODO(skyewm): the control dep will be added to all nodes in the cond graph. + // This is at best unnecessary, and at worst may prevent different parts of + // different loop iterations from executing in parallel. + Scope cond_scope = + scope.NewSubScope("cond").WithControlDependencies(inputs[0]); + Output raw_cond_out; + TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out)); + if (raw_cond_out.type() != DT_BOOL) { + return errors::InvalidArgument( + "BuildWhileLoop: 'cond' argument must return a boolean output, got ", + DataTypeString(raw_cond_out.type())); + } + *output = LoopCond(scope, raw_cond_out).output; + return Status::OK(); +} + +// Create the bdoy subgraph defined by `body`. `outputs` must be non-null and +// empty. +Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, + const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + DCHECK(outputs != nullptr); + DCHECK(outputs->empty()); + + // The control dependency is analogous to that in CreateCond(). + Scope body_scope = + scope.NewSubScope("body").WithControlDependencies(inputs[0]); + TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs)); + const size_t num_loop_vars = inputs.size(); + if (outputs->size() != num_loop_vars) { + return errors::InvalidArgument( + "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars, + "outputs, got ", outputs->size()); + } + // TODO(skyewm): check output types/shapes + return Status::OK(); +} + +} // namespace + +// A while loop with a single loop variable looks like this: +// +// (output) +// ^ +---------------+ +// | | body subgraph +-------------+ +// Exit +---------------+ | +// ^ ^ | +// | | | +// Switch<--------+ v +// ^ | NextIteration +// | +------+--------+ | +// +---->| cond subgraph | | +// | +---------------+ | +// Merge<---------------------------+ +// ^ +// | +// Enter +// ^ +// | +// (input) +// +// If there are multiple loop variables, each of the control flow ops is +// duplicated for each loop variable. +// TODO(skyewm): link to public version of design doc +Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, + const CondGraphBuilderFn& cond, + const BodyGraphBuilderFn& body, const string& frame_name, + OutputList* outputs) { + DCHECK(!inputs.empty()); + DCHECK(outputs != nullptr); + DCHECK(outputs->empty()); + + TF_RETURN_IF_ERROR(scope.status()); + const size_t num_loop_vars = inputs.size(); + + std::vector<Output> enter_outputs(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name); + } + TF_RETURN_IF_ERROR(scope.status()); + + std::vector<Output> merge_outputs(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + TF_RETURN_IF_ERROR( + CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i])); + } + + Output cond_out; + TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out)); + + std::vector<Output> switch_trues(num_loop_vars); + std::vector<Output> switch_falses(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + auto switch_i = Switch(scope, merge_outputs[i], cond_out); + switch_trues[i] = switch_i.output_true; + switch_falses[i] = switch_i.output_false; + } + TF_RETURN_IF_ERROR(scope.status()); + + std::vector<Output> body_outputs; + TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs)); + + std::vector<Output> next_outputs(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + next_outputs[i] = NextIteration(scope, body_outputs[i]); + DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i)); + } + TF_RETURN_IF_ERROR(scope.status()); + + // Create the backedges from the NextIteration nodes to the Merge nodes. + for (int i = 0; i < num_loop_vars; ++i) { + const int merge_backedge_output_index = 1; + scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(), + merge_outputs[i].node(), + merge_backedge_output_index); + } + + outputs->resize(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + (*outputs)[i] = internal::Exit(scope, switch_falses[i]); + } + return scope.status(); +} + +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h new file mode 100644 index 0000000000..253d5d8935 --- /dev/null +++ b/tensorflow/cc/ops/while_loop.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#define THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { +namespace ops { + +// Function that takes cond graph inputs and returns cond graph boolean output. +// 'output' need not be set if an error is returned. +typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, + Output* output)> + CondGraphBuilderFn; + +// Function that takes body graph inputs and returns body graph outputs. +// 'outputs' need not be populated if an error is returned. +typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, + std::vector<Output>* outputs)> + BodyGraphBuilderFn; + +// Constructs a while loop. +// +// Arguments: +// * scope: used to construct the while loop. +// * inputs: the initial values of the loop variables. Must be non-empty. +// * cond: a function that builds the condition graph of the loop. Takes the +// current loop variables as inputs and returns a scalar boolean Output +// indicating whether the loop should continue. +// * body: a function that builds the body graph of the loop. Takes the current +// loop variables as inputs and returns the updated loop variables. +// * frame_name: the frame name to use for this while loop. This should be a +// unique name. This will be used as a prefix for created operations. +// * outputs: output param that returns final loop variable outputs in non-error +// case. Must be non-null and empty. +// +// Returns an error if the while loop could not be fully constructed. +// +// TODO(skyewm): clean up partially-constructed loop in error case +// TODO(skyewm): create public interface to this method +Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, + const CondGraphBuilderFn& cond, + const BodyGraphBuilderFn& body, const string& frame_name, + OutputList* outputs); + +} // namespace ops +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ |