aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework/while_gradients.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/framework/while_gradients.cc')
-rw-r--r--tensorflow/cc/framework/while_gradients.cc197
1 file changed, 197 insertions, 0 deletions
diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc
new file mode 100644
index 0000000000..8234d5bea4
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/while_gradients.h"
+
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+
+namespace tensorflow {
+namespace {
+
+using ops::BodyGraphBuilderFn;
+using ops::BuildWhileLoop;
+using ops::CondGraphBuilderFn;
+
+// Wraps an OutputTensor (node + output index) as the public Output type.
+// OutputTensor stores a const Node* while Output requires a mutable Node*,
+// hence the const_cast.
+Output ToOutput(OutputTensor output_tensor) {
+  Node* node = const_cast<Node*>(output_tensor.node);
+  return Output(node, output_tensor.index);
+}
+
+// Converts a vector of OutputTensors to the equivalent vector of Outputs.
+//
+// Args:
+//   output_tensors: the OutputTensors to convert; may be empty.
+// Returns an Output per input element, in the same order.
+std::vector<Output> ToOutputVector(
+    const std::vector<OutputTensor>& output_tensors) {
+  std::vector<Output> result;
+  result.reserve(output_tensors.size());
+  // Range-for avoids the original `int i < size_t n` signed/unsigned loop
+  // comparison (and the index truncation it implies for very large vectors).
+  for (const OutputTensor& output_tensor : output_tensors) {
+    result.push_back(ToOutput(output_tensor));
+  }
+  return result;
+}
+
+// Returns the name of the execution frame in which the backprop loops run.
+// The backprop loop counter and the main backprop loop share this frame; it
+// is distinct from the frame shared by the forward loop and its iteration
+// counter. The name is derived from the forward frame's name.
+// TODO(skyewm): make sure this is unique among existing frame names
+string BackPropFrameName(const string& forward_frame_name) {
+  const char* const kBackPropSuffix = "_backprop";
+  return strings::StrCat(forward_frame_name, kBackPropSuffix);
+}
+
+// Creates a loop that counts the number of iterations performed by the
+// while loop associated with `while_ctx`. The returned output yields the
+// iteration count, which the backprop loop counter later counts down from.
+//
+// Args:
+//   while_ctx: metadata for the forward while loop whose iterations we count.
+//   scope: scope in which the counter loop's ops are created.
+//   count: output; yields the total number of forward-loop iterations.
+Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
+ Output* count) {
+ // Create while loop:
+ // i = 0
+ // while forward loop predicate is true:
+ // ++i
+
+ Output zero = ops::Const(scope, 0, {});
+
+ // Condition function that returns condition output from original while loop.
+ // Note it ignores `inputs`: the counter advances exactly as long as the
+ // forward loop's own predicate is true, keeping the two loops in lockstep.
+ CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ *output = ToOutput(while_ctx->cond_output());
+ return Status::OK();
+ };
+
+ // Body function that adds one to input.
+ BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ DCHECK_EQ(inputs.size(), 1);
+ outputs->emplace_back(ops::Add(scope, inputs[0], 1));
+ return scope.status();
+ };
+
+ // Note that this loop runs in the same execution frame as the forward loop
+ // (it reuses while_ctx->frame_name()), which is what lets cond_fn reference
+ // the forward loop's condition output directly.
+ std::vector<Output> outputs;
+ TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
+ while_ctx->frame_name(), &outputs,
+ /* create_while_ctx */ false));
+ *count = outputs[0];
+ return Status::OK();
+}
+
+// Creates a loop that executes `loop_count` times. The returned output is the
+// boolean predicate indicating if the loop is still executing. This is used to
+// drive the gradient computation for the while loop associated with
+// `while_ctx`.
+//
+// Args:
+//   while_ctx: metadata for the forward loop being differentiated; used here
+//     only to derive the backprop frame name.
+//   loop_count: scalar iteration count produced by AddForwardLoopCounter().
+//   scope: scope in which the countdown loop's ops are created.
+//   backprop_execution_pred: output; the loop's "still running" predicate.
+Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
+ const Scope& scope,
+ Output* backprop_execution_pred) {
+ // Create while loop:
+ // n = loop_count
+ // while n > 0:
+ // --n
+
+ // Condition function that returns input > 0.
+ CondGraphBuilderFn cond_fn = [](const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ DCHECK_EQ(inputs.size(), 1);
+ *output = ops::Greater(scope, inputs[0], 0);
+ return scope.status();
+ };
+
+ // Body function that subtracts one from input.
+ BodyGraphBuilderFn body_fn = [](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ DCHECK_EQ(inputs.size(), 1);
+ outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
+ return scope.status();
+ };
+
+ // This loop runs in the backprop frame, not the forward frame.
+ string frame_name = BackPropFrameName(while_ctx->frame_name());
+ std::vector<Output> outputs; // unused
+ // The loop's data outputs are discarded; the caller only needs the loop
+ // predicate, which BuildWhileLoop returns via its final parameter.
+ TF_RETURN_IF_ERROR(BuildWhileLoop(
+ scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
+ /* create_while_ctx */ false, backprop_execution_pred));
+ return Status::OK();
+}
+
+// Creates the main backprop loop that computes the gradient of the loop
+// associated with `while_ctx`. `grad_inputs` are the partial derivatives
+// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
+// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
+// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
+// returned in `grad_outputs`.
+Status AddWhileGradientLoop(WhileContext* while_ctx,
+ const std::vector<Output>& grad_inputs,
+ const Output& backprop_execution_pred,
+ const Scope& parent_scope,
+ std::vector<Output>* grad_outputs) {
+ // One incoming gradient per loop variable, and the body is a function from
+ // loop vars to loop vars, so all three sizes must agree.
+ DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
+ DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
+
+ Scope scope = parent_scope.NewSubScope("while");
+
+ // Create while loop:
+ // while backprop_execution_pred:
+ // forward loop body gradient
+
+ // Condition function that returns 'backprop_execution_pred'. The iteration
+ // count is owned by the backprop loop counter, so this loop's own inputs are
+ // ignored by the condition.
+ CondGraphBuilderFn cond_fn = [backprop_execution_pred](
+ const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ *output = backprop_execution_pred;
+ return Status::OK();
+ };
+
+ // Body function that builds while body gradient subgraph. Each iteration
+ // backprops `inputs` (the running gradients) through the forward loop body,
+ // from its outputs back to its inputs.
+ BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ std::vector<Output> body_outputs =
+ ToOutputVector(while_ctx->body_outputs());
+ std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
+ return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
+ outputs);
+ };
+
+ // Runs in the same backprop frame as the backprop loop counter (see
+ // BackPropFrameName()), so the shared predicate is visible to both.
+ string frame_name = BackPropFrameName(while_ctx->frame_name());
+ TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
+ frame_name, grad_outputs,
+ /* create_while_ctx */ false));
+ return Status::OK();
+}
+
+} // namespace
+
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ Output forward_loop_count;
+ TF_RETURN_IF_ERROR(AddForwardLoopCounter(
+ while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
+
+ // TODO(skyewm): can we combine the backprop loop counter and main gradient
+ // loop into a single loop? The original Python code doesn't combine the
+ // loops, but I'm not sure why.
+ Output backprop_counter_cond;
+ TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
+ while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
+ &backprop_counter_cond));
+
+ return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
+ scope, grad_outputs);
+}
+
+} // namespace tensorflow