Diffstat (limited to 'tensorflow/cc/framework/while_gradients.cc')
-rw-r--r-- | tensorflow/cc/framework/while_gradients.cc | 197
1 file changed, 197 insertions, 0 deletions
diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc
new file mode 100644
index 0000000000..8234d5bea4
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/while_gradients.h"
+
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+
+namespace tensorflow {
+namespace {
+
+using ops::BodyGraphBuilderFn;
+using ops::BuildWhileLoop;
+using ops::CondGraphBuilderFn;
+
+Output ToOutput(OutputTensor output_tensor) {
+  return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
+}
+
+std::vector<Output> ToOutputVector(
+    const std::vector<OutputTensor>& output_tensors) {
+  size_t n = output_tensors.size();
+  std::vector<Output> result(n);
+  for (int i = 0; i < n; ++i) result[i] = ToOutput(output_tensors[i]);
+  return result;
+}
+
+// The backprop loop counter and main backprop loop run in their own execution
+// frame (conceptually, the main forward loop and forward loop counter run
+// together in a frame, then the backprop loop counter and backprop loop run
+// together in a different frame). This returns the frame name to use for the
+// backprop while loops.
+// TODO(skyewm): make sure this is unique among existing frame names
+string BackPropFrameName(const string& forward_frame_name) {
+  return strings::StrCat(forward_frame_name, "_backprop");
+}
+
+// Creates a loop that counts the number of iterations performed by the
+// while loop associated with `while_ctx`. The returned output yields the
+// iteration count.
+Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
+                             Output* count) {
+  // Create while loop:
+  //   i = 0
+  //   while forward loop predicate is true:
+  //     ++i
+
+  Output zero = ops::Const(scope, 0, {});
+
+  // Condition function that returns condition output from original while loop.
+  CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           Output* output) {
+    *output = ToOutput(while_ctx->cond_output());
+    return Status::OK();
+  };
+
+  // Body function that adds one to input.
+  BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           std::vector<Output>* outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    outputs->emplace_back(ops::Add(scope, inputs[0], 1));
+    return scope.status();
+  };
+
+  // Note that this loop runs in the same execution frame as the forward loop.
+  std::vector<Output> outputs;
+  TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
+                                    while_ctx->frame_name(), &outputs,
+                                    /* create_while_ctx */ false));
+  *count = outputs[0];
+  return Status::OK();
+}
+
+// Creates a loop that executes `loop_count` times. The returned output is the
+// boolean predicate indicating if the loop is still executing. This is used to
+// drive the gradient computation for the while loop associated with
+// `while_ctx`.
+Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
+                              const Scope& scope,
+                              Output* backprop_execution_pred) {
+  // Create while loop:
+  //   n = loop_count
+  //   while n > 0:
+  //     --n
+
+  // Condition function that returns input > 0.
+  CondGraphBuilderFn cond_fn = [](const Scope& scope,
+                                  const std::vector<Output>& inputs,
+                                  Output* output) {
+    DCHECK_EQ(inputs.size(), 1);
+    *output = ops::Greater(scope, inputs[0], 0);
+    return scope.status();
+  };
+
+  // Body function that subtracts one from input.
+  BodyGraphBuilderFn body_fn = [](const Scope& scope,
+                                  const std::vector<Output>& inputs,
+                                  std::vector<Output>* outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
+    return scope.status();
+  };
+
+  string frame_name = BackPropFrameName(while_ctx->frame_name());
+  std::vector<Output> outputs;  // unused
+  TF_RETURN_IF_ERROR(BuildWhileLoop(
+      scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
+      /* create_while_ctx */ false, backprop_execution_pred));
+  return Status::OK();
+}
+
+// Creates the main backprop loop that computes the gradient of the loop
+// associated with `while_ctx`. `grad_inputs` are the partial derivatives
+// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
+// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
+// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars,
+// are returned in `grad_outputs`.
+Status AddWhileGradientLoop(WhileContext* while_ctx,
+                            const std::vector<Output>& grad_inputs,
+                            const Output& backprop_execution_pred,
+                            const Scope& parent_scope,
+                            std::vector<Output>* grad_outputs) {
+  DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
+  DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
+
+  Scope scope = parent_scope.NewSubScope("while");
+
+  // Create while loop:
+  //   while backprop_execution_pred:
+  //     forward loop body gradient
+
+  // Condition function that returns 'backprop_execution_pred'.
+  CondGraphBuilderFn cond_fn = [backprop_execution_pred](
+                                   const Scope& scope,
+                                   const std::vector<Output>& inputs,
+                                   Output* output) {
+    *output = backprop_execution_pred;
+    return Status::OK();
+  };
+
+  // Body function that builds while body gradient subgraph.
+  BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           std::vector<Output>* outputs) {
+    std::vector<Output> body_outputs =
+        ToOutputVector(while_ctx->body_outputs());
+    std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
+    return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
+                                outputs);
+  };
+
+  string frame_name = BackPropFrameName(while_ctx->frame_name());
+  TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
                                     frame_name, grad_outputs,
+                                    /* create_while_ctx */ false));
+  return Status::OK();
+}
+
+}  // namespace
+
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+                            const std::vector<Output>& grad_inputs,
+                            std::vector<Output>* grad_outputs) {
+  Output forward_loop_count;
+  TF_RETURN_IF_ERROR(AddForwardLoopCounter(
+      while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
+
+  // TODO(skyewm): can we combine the backprop loop counter and main gradient
+  // loop into a single loop? The original Python code doesn't combine the
+  // loops, but I'm not sure why.
+  Output backprop_counter_cond;
+  TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
+      while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
+      &backprop_counter_cond));
+
+  return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
+                              scope, grad_outputs);
+}
+
+}  // namespace tensorflow
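
For context, `AddWhileLoopGradient` is not called directly by user code; it is reached through `AddSymbolicGradients` when gradients are requested across a while loop built with `ops::BuildWhileLoop` (which, with its default `create_while_ctx = true`, records the `WhileContext` this file consumes). Below is a minimal sketch of that path, not part of this commit; the loop body and the frame name `"example_frame"` are illustrative.

```cpp
#include <vector>

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"

using namespace tensorflow;

int main() {
  Scope scope = Scope::NewRootScope();
  Output init = ops::Const(scope, 1.0f);

  // Forward loop: while (x < 10) x = x * 2. The default create_while_ctx
  // records the WhileContext that AddWhileLoopGradient later uses.
  std::vector<Output> outputs;
  TF_CHECK_OK(ops::BuildWhileLoop(
      scope, {init},
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10.0f);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        outputs->push_back(ops::Multiply(s, inputs[0], 2.0f));
        return s.status();
      },
      "example_frame", &outputs));

  // Requesting d(output)/d(init) routes through AddWhileLoopGradient, which
  // builds the three loops added in this diff: the forward loop counter, the
  // backprop loop counter, and the main gradient loop.
  std::vector<Output> grad_outputs;
  TF_CHECK_OK(
      AddSymbolicGradients(scope, {outputs[0]}, {init}, &grad_outputs));
  return 0;
}
```

Note the division of labor this implies: the counters exist only because the gradient loop must run exactly as many iterations as the forward loop did, and the count is known only at runtime, when the forward loop's predicate finally yields false.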