author     Skye Wanderman-Milne <skyewm@google.com>    2017-09-27 13:28:30 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-09-27 13:32:17 -0700
commit     301b14c240fe99249dc2225132a7ebe5cbecbdc4 (patch)
tree       8de10cf0180ff41211b294c05cae199208487c85 /tensorflow/cc/framework
parent     545e3572f7d8928eeb220e8b55c71ad33a9343c6 (diff)
Basic while loop gradient functionality in C++
This change introduces the basic framework to create the gradient graph of a while loop using the C++ API. This supports building the gradient graph as long as the body function of the while loop contains no ops whose gradient function requires a stack. In other words, it doesn't support gradient functions that use the input values to the op (e.g. add will work, but multiply will not). It also doesn't support nested while loops, and doesn't detect all error cases.

PiperOrigin-RevId: 170243281
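The new path is exercised end-to-end by while_gradients_test.cc below. As a rough illustration only (not part of this commit), a caller might combine ops::BuildWhileLoop() and AddSymbolicGradients() along the lines of the Basic test; the standalone main() and the "example_loop" frame name are hypothetical:

// Hypothetical usage sketch, modeled on WhileGradientsTest::Basic from
// while_gradients_test.cc in this change.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"

namespace tf = tensorflow;

int main() {
  tf::Scope scope = tf::Scope::NewRootScope();
  std::vector<tf::Output> inputs = {tf::ops::Placeholder(scope, tf::DT_INT32)};

  // Forward loop: while (i < 10) i += 1. AddN is used (rather than Add)
  // because its gradient doesn't need the op's input values, which this
  // change doesn't yet support (no stack support).
  std::vector<tf::Output> outputs;
  TF_CHECK_OK(tf::ops::BuildWhileLoop(
      scope, inputs,
      [](const tf::Scope& s, const std::vector<tf::Output>& in,
         tf::Output* out) {
        *out = tf::ops::Less(s, in[0], 10);
        return s.status();
      },
      [](const tf::Scope& s, const std::vector<tf::Output>& in,
         std::vector<tf::Output>* out) {
        out->push_back(tf::ops::AddN(s, {in[0], 1}));
        return s.status();
      },
      "example_loop", &outputs));

  // Build the gradient graph; the loop's Exit nodes are routed through the
  // new ProcessWhileLoop()/AddWhileLoopGradient() path.
  std::vector<tf::Output> grad_outputs;
  TF_CHECK_OK(tf::AddSymbolicGradients(scope, outputs, inputs, &grad_outputs));

  // d(output)/d(input) for this loop is 1.
  tf::ClientSession session(scope);
  std::vector<tf::Tensor> out_tensors;
  TF_CHECK_OK(session.Run({{inputs[0], 1}}, grad_outputs, &out_tensors));
  return 0;
}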
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r--  tensorflow/cc/framework/gradients.cc             |  82
-rw-r--r--  tensorflow/cc/framework/while_gradients.cc       | 197
-rw-r--r--  tensorflow/cc/framework/while_gradients.h        |  40
-rw-r--r--  tensorflow/cc/framework/while_gradients_test.cc  | 233
4 files changed, 549 insertions, 3 deletions
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index b665ce744d..9825b02586 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -16,8 +16,9 @@ limitations under the License.
#include <deque>
#include <vector>
-#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/while_gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
@@ -82,6 +84,13 @@ class SymbolicGradientBuilder {
// from outputs_. Keyed by node id.
std::vector<bool> GetReachableNodes();
+ // Creates the gradient subgraph for a while loop (or just stores
+ // `summed_grads` if not all incoming gradients are available yet). All exit
+ // nodes (which are the first nodes of a loop encountered in the backwards
+ // pass) are passed to this function rather than processed normally.
+ // `summed_grads` is the sum of `exit_node`'s gradients.
+ Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads);
+
const Scope& scope_;
const ops::GradOpRegistry* registry_;
const std::vector<Output>& outputs_;
@@ -89,8 +98,7 @@ class SymbolicGradientBuilder {
const std::vector<Output>& grad_inputs_;
std::vector<Output>* grad_outputs_;
- // A vector of output endpoints which represents backpropagated
- // gradients
+ // A vector of output endpoints which represents backpropagated gradients
typedef std::vector<Output> BackpropedGradients;
// backprops_ is a map from a node output to its accumulated
@@ -117,6 +125,12 @@ class SymbolicGradientBuilder {
// frontier. Maps from Output -> index into `grad_outputs_`.
std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_;
+ // For each while loop in the graph, collects the summed gradients for each of
+ // the loop's exit nodes. Note that unlike backprops_, this map contains the
+ // output of SumGradients(), not the input (i.e. each exit node may have
+ // multiple incoming gradients, but we only store the combined Output here).
+ std::map<WhileContext*, std::map<Node*, Output>> while_backprops_;
+
TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};
@@ -150,6 +164,7 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
std::deque<Node*> queue;
+ std::vector<bool> visited(scope_.graph()->num_node_ids(), false);
for (const Output& out : outputs_) {
if (!reachable_nodes[out.node()->id()]) {
queue.push_back(out.node());
@@ -162,8 +177,10 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
queue.pop_front();
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
+ if (visited[e->src()->id()]) continue;
queue.push_back(e->src());
reachable_nodes[e->src()->id()] = true;
+ visited[e->src()->id()] = true;
}
}
return reachable_nodes;
@@ -304,6 +321,53 @@ Status SymbolicGradientBuilder::CallGradFunction(
return Status::OK();
}
+Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node,
+ const Output& summed_grads) {
+ // TODO(skyewm): detect second-order gradient and return bad status
+ // TODO(skyewm): handle (or at least detect) nested while loops
+
+ // TODO(skyewm): handle NoGradient in while loop
+ if (summed_grads == NoGradient()) {
+ return errors::Unimplemented(
+ "Missing gradient into while loop not yet implemented");
+ }
+
+ DCHECK(exit_node->IsExit());
+ WhileContext* while_ctx = exit_node->while_ctx();
+ DCHECK(while_ctx != nullptr);
+
+ // Record 'summed_grads' as the backprop input associated with 'exit_node'
+ std::map<Node*, Output>& backprops = while_backprops_[while_ctx];
+ DCHECK(backprops.find(exit_node) == backprops.end());
+ backprops[exit_node] = summed_grads;
+
+ // Wait until we have all exit nodes' backprops collected before processing
+ // the while loop.
+ // TODO(skyewm): what if not all the exit nodes are reachable?
+ if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK();
+
+ // We've seen all the exit nodes for this loop and have collected all the
+ // backprops. Create the gradient graph for the while loop.
+ Scope while_scope =
+ scope_.NewSubScope(strings::StrCat(while_ctx->frame_name(), "_grad"));
+ std::vector<Output> dy;
+ for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]);
+ std::vector<Output> dx;
+ TF_RETURN_IF_ERROR(AddWhileLoopGradient(while_ctx, while_scope, dy, &dx));
+
+ // Backprop along the in edges to the while loop (i.e. the inputs to the enter
+ // nodes)
+ DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size());
+ for (int i = 0; i < dx.size(); ++i) {
+ Node* enter_node = while_ctx->enter_nodes()[i];
+ for (const Edge* e : enter_node->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()}));
+ }
+ }
+ return Status::OK();
+}
+
Status SymbolicGradientBuilder::AddGradients() {
// Initialize backprops.
TF_RETURN_IF_ERROR(Initialize());
@@ -346,6 +410,18 @@ Status SymbolicGradientBuilder::AddGradients() {
continue;
}
+ // Special case: if we find an exit node, process the associated while loop.
+ // Note that ProcessWhileLoop() calls BackpropAlongEdge() if necessary
+ // (which updates ready_), and we skip all the regular processing below
+ // after calling it.
+ if (n->IsExit()) {
+ DCHECK_EQ(dy.size(), 1);
+ TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0]));
+ continue;
+ }
+ // All loop-specific control flow ops should have been handled above
+ DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString();
+
const size_t num_no_grad = no_grad_dy_indices.size();
if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
// No grad defined for this op, or all outputs returned 'NoGradient':
diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc
new file mode 100644
index 0000000000..8234d5bea4
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/while_gradients.h"
+
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+
+namespace tensorflow {
+namespace {
+
+using ops::BodyGraphBuilderFn;
+using ops::BuildWhileLoop;
+using ops::CondGraphBuilderFn;
+
+Output ToOutput(OutputTensor output_tensor) {
+ return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
+}
+
+std::vector<Output> ToOutputVector(
+ const std::vector<OutputTensor>& output_tensors) {
+ size_t n = output_tensors.size();
+ std::vector<Output> result(n);
+ for (int i = 0; i < n; ++i) result[i] = ToOutput(output_tensors[i]);
+ return result;
+}
+
+// The backprop loop counter and main backprop loop run in their own execution
+// frame (conceptually, the main forward loop and forward loop counter run
+// together in a frame, then the backprop loop counter and backprop loop run
+// together in a different frame). This returns the frame name to use for the
+// backprop while loops.
+// TODO(skyewm): make sure this is unique among existing frame names
+string BackPropFrameName(const string& forward_frame_name) {
+ return strings::StrCat(forward_frame_name, "_backprop");
+}
+
+// Creates a loop that counts the number of iterations performed by the
+// while loop associated with `while_ctx`. The returned output yields the
+// iteration count.
+Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
+ Output* count) {
+ // Create while loop:
+ // i = 0
+ // while forward loop predicate is true:
+ // ++i
+
+ Output zero = ops::Const(scope, 0, {});
+
+ // Condition function that returns condition output from original while loop.
+ CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ *output = ToOutput(while_ctx->cond_output());
+ return Status::OK();
+ };
+
+ // Body function that adds one to input.
+ BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ DCHECK_EQ(inputs.size(), 1);
+ outputs->emplace_back(ops::Add(scope, inputs[0], 1));
+ return scope.status();
+ };
+
+ // Note that this loop runs in the same execution frame as the forward loop.
+ std::vector<Output> outputs;
+ TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
+ while_ctx->frame_name(), &outputs,
+ /* create_while_ctx */ false));
+ *count = outputs[0];
+ return Status::OK();
+}
+
+// Creates a loop that executes `loop_count` times. The returned output is the
+// boolean predicate indicating if the loop is still executing. This is used to
+// drive the gradient computation for the while loop associated with
+// `while_ctx`.
+Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
+ const Scope& scope,
+ Output* backprop_execution_pred) {
+ // Create while loop:
+ // n = loop_count
+ // while n > 0:
+ // --n
+
+ // Condition function that returns input > 0.
+ CondGraphBuilderFn cond_fn = [](const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ DCHECK_EQ(inputs.size(), 1);
+ *output = ops::Greater(scope, inputs[0], 0);
+ return scope.status();
+ };
+
+ // Body function that subtracts one from input.
+ BodyGraphBuilderFn body_fn = [](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ DCHECK_EQ(inputs.size(), 1);
+ outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
+ return scope.status();
+ };
+
+ string frame_name = BackPropFrameName(while_ctx->frame_name());
+ std::vector<Output> outputs; // unused
+ TF_RETURN_IF_ERROR(BuildWhileLoop(
+ scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
+ /* create_while_ctx */ false, backprop_execution_pred));
+ return Status::OK();
+}
+
+// Creates the main backprop loop that computes the gradient of the loop
+// associated with `while_ctx`. `grad_inputs` are the partial derivatives
+// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
+// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
+// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
+// returned in `grad_outputs`.
+Status AddWhileGradientLoop(WhileContext* while_ctx,
+ const std::vector<Output>& grad_inputs,
+ const Output& backprop_execution_pred,
+ const Scope& parent_scope,
+ std::vector<Output>* grad_outputs) {
+ DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
+ DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
+
+ Scope scope = parent_scope.NewSubScope("while");
+
+ // Create while loop:
+ // while backprop_execution_pred:
+ // forward loop body gradient
+
+ // Condition function that returns 'backprop_execution_pred'.
+ CondGraphBuilderFn cond_fn = [backprop_execution_pred](
+ const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ *output = backprop_execution_pred;
+ return Status::OK();
+ };
+
+ // Body function that builds while body gradient subgraph.
+ BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ std::vector<Output> body_outputs =
+ ToOutputVector(while_ctx->body_outputs());
+ std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
+ return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
+ outputs);
+ };
+
+ string frame_name = BackPropFrameName(while_ctx->frame_name());
+ TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
+ frame_name, grad_outputs,
+ /* create_while_ctx */ false));
+ return Status::OK();
+}
+
+} // namespace
+
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ Output forward_loop_count;
+ TF_RETURN_IF_ERROR(AddForwardLoopCounter(
+ while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
+
+ // TODO(skyewm): can we combine the backprop loop counter and main gradient
+ // loop into a single loop? The original Python code doesn't combine the
+ // loops, but I'm not sure why.
+ Output backprop_counter_cond;
+ TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
+ while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
+ &backprop_counter_cond));
+
+ return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
+ scope, grad_outputs);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h
new file mode 100644
index 0000000000..8f592accc9
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.h
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/core/graph/while_context.h"
+
+// Utility functions for constructing while loop gradients
+
+namespace tensorflow {
+
+// Adds the gradient computation for the while loop associated with
+// `while_ctx`. `grad_inputs` are the partial derivatives w.r.t. the loop
+// outputs, i.e. the exit nodes. The partial derivatives w.r.t. the loop
+// inputs, i.e. the input loop vars, are returned in `grad_outputs`.
+// `grad_inputs` and `grad_outputs` are both in loop-variable order, as defined
+// by the original inputs to BuildWhileLoop().
+// TODO(skyewm): maybe comment on NoGradient once it's supported
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
diff --git a/tensorflow/cc/framework/while_gradients_test.cc b/tensorflow/cc/framework/while_gradients_test.cc
new file mode 100644
index 0000000000..39fa7477c5
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients_test.cc
@@ -0,0 +1,233 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/testutil.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+class WhileGradientsTest : public ::testing::Test {
+ protected:
+ WhileGradientsTest() : scope_(Scope::NewRootScope()) {}
+
+ void Init(int num_inputs, DataType dtype = DT_INT32) {
+ for (int i = 0; i < num_inputs; ++i) {
+ inputs_.push_back(ops::Placeholder(scope_, dtype));
+ }
+ }
+
+ void CreateLoop(const ops::CondGraphBuilderFn& cond,
+ const ops::BodyGraphBuilderFn& body,
+ const std::vector<Output>* inputs = nullptr) {
+ if (inputs == nullptr) inputs = &inputs_;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope_, *inputs, cond, body, "test_loop",
+ &outputs_));
+ }
+
+ void CreateBackprop() {
+ TF_ASSERT_OK(
+ AddSymbolicGradients(scope_, outputs_, inputs_, &grad_outputs_));
+ ASSERT_EQ(grad_outputs_.size(), inputs_.size());
+ }
+
+ template <typename T>
+ void Run(const std::vector<Input::Initializer>& input_values,
+ const std::vector<T>& expected_grad_values) {
+ Run<T>(ClientSession(scope_), input_values, expected_grad_values);
+ }
+
+ template <typename T>
+ void Run(const ClientSession& session,
+ const std::vector<Input::Initializer>& input_values,
+ const std::vector<T>& expected_grad_values,
+ const RunOptions& run_options = RunOptions(),
+ RunMetadata* run_metadata = nullptr) {
+ DCHECK_EQ(input_values.size(), inputs_.size());
+ ClientSession::FeedType feeds;
+ for (int i = 0; i < inputs_.size(); ++i) {
+ feeds.emplace(inputs_[i], input_values[i]);
+ }
+
+ std::vector<Operation> run_outputs;
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(run_options, feeds, grad_outputs_, run_outputs,
+ &out_tensors, run_metadata));
+ ASSERT_EQ(out_tensors.size(), grad_outputs_.size());
+
+ DCHECK_EQ(expected_grad_values.size(), out_tensors.size());
+ for (int i = 0; i < out_tensors.size(); ++i) {
+ test::ExpectTensorEqual<T>(
+ out_tensors[i], test::AsTensor<T>({expected_grad_values[i]}, {}));
+ }
+ }
+
+ Scope scope_;
+ std::vector<Output> inputs_;
+ std::vector<Output> outputs_;
+ std::vector<Output> grad_outputs_;
+};
+
+TEST_F(WhileGradientsTest, Basic) {
+ // Create loop: while (i < 10) i += 1
+ Init(1);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ // Use AddN, rather than Add, because the gradient function doesn't
+ // depend on the input shapes, and thus we do not need to store
+ // intermediate values in a stack.
+ outputs->push_back(ops::AddN(s, {inputs[0], 1}));
+ return s.status();
+ });
+ CreateBackprop();
+
+ Run<int>({1}, {1});
+ Run<int>({11}, {1});
+}
+
+TEST_F(WhileGradientsTest, MultipleLoopVars) {
+ // Create loop: while (i < 10) i += j; j += 1; k = k
+ Init(3);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::AddN(s, {inputs[0], inputs[1]}));
+ outputs->push_back(ops::AddN(s, {inputs[1], 1}));
+ outputs->push_back(inputs[2]);
+ return s.status();
+ });
+ CreateBackprop();
+
+ // The following execution traces illustrate why we expect dF/dj to be 5:
+ //
+ // i j k
+ // ---------
+ // 0 1 2 <-- initial values
+ // 1 2 2
+ // 3 3 2
+ // 6 4 2
+ // 10 5 2 <-- while output values
+ // outputs sum = 17
+ //
+ // i j k
+ // ---------
+ // 0 2 2 <-- initial values (add 1 to j)
+ // 2 3 2
+ // 5 4 2
+ // 9 5 2
+ // 14 6 2 <-- while output values
+ // outputs sum = 22
+ //
+ // Calculate the "slope" between j=1 and j=2:
+ // 22 - 17 = 5 => dF/dj = 5
+ Run<int>({0, 1, 2}, {1, 5, 1});
+
+ Run<int>({1, 1, 0}, {1, 5, 1});
+ Run<int>({0, 0, 0}, {1, 6, 1});
+}
+
+TEST_F(WhileGradientsTest, Chaining) {
+ Init(2, DT_DOUBLE);
+
+ // Multiply each input by 2 before passing to while loop to make sure chaining
+ // works properly
+ std::vector<Output> loop_inputs = {ops::Multiply(scope_, inputs_[0], 2.0),
+ ops::Multiply(scope_, inputs_[1], 2.0)};
+
+ // Create loop: while (i > 0 && j > 0) i -= 1
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::LogicalAnd(s, ops::Greater(s, inputs[0], 0.0),
+ ops::Greater(s, inputs[1], 0.0));
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::AddN(s, {inputs[0], -1.0}));
+ outputs->push_back(inputs[1]);
+ return s.status();
+ },
+ &loop_inputs);
+
+ // Take negative of first output to make sure chaining works properly
+ outputs_[0] = ops::Neg(scope_, outputs_[0]);
+
+ CreateBackprop();
+
+ Run<double>({1.0, 1.0}, {-2.0, 2.0});
+ Run<double>({0.0, 0.0}, {-2.0, 2.0});
+}
+
+TEST_F(WhileGradientsTest, MultipleDevices) {
+ // Make sure loop is created on cpu0
+ scope_ = scope_.WithDevice("/cpu:0");
+
+ // Create loop: while (i < 10) i += j
+ Init(2);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ // Place body on cpu1
+ Scope cpu1_scope = s.WithDevice("/cpu:1");
+ outputs->push_back(ops::AddN(cpu1_scope, {inputs[0], inputs[1]}));
+ outputs->push_back(inputs[1]);
+ return cpu1_scope.status();
+ });
+
+ // Build gradient graph on cpu1
+ Scope cpu1_scope = scope_.WithDevice("/cpu:1");
+ TF_ASSERT_OK(
+ AddSymbolicGradients(cpu1_scope, outputs_, inputs_, &grad_outputs_));
+ ASSERT_EQ(grad_outputs_.size(), inputs_.size());
+
+ // Run with two CPU devices and output partition graphs
+ SessionOptions session_options;
+ (*session_options.config.mutable_device_count())["CPU"] = 2;
+ RunOptions run_options;
+ run_options.set_output_partition_graphs(true);
+ RunMetadata run_metadata;
+ Run<int>(ClientSession(scope_, session_options), {0, 1}, {1, 11}, run_options,
+ &run_metadata);
+
+ // Check that at least one node ran on each device
+ ASSERT_EQ(run_metadata.partition_graphs().size(), 2);
+ for (const GraphDef& partition_graph : run_metadata.partition_graphs()) {
+ EXPECT_GE(partition_graph.node().size(), 1);
+ }
+}
+
+} // namespace
+} // namespace tensorflow