-rw-r--r--  tensorflow/c/while_loop_test.cc                   39
-rw-r--r--  tensorflow/cc/BUILD                               31
-rw-r--r--  tensorflow/cc/framework/gradients.cc              82
-rw-r--r--  tensorflow/cc/framework/while_gradients.cc       197
-rw-r--r--  tensorflow/cc/framework/while_gradients.h         40
-rw-r--r--  tensorflow/cc/framework/while_gradients_test.cc  233
-rw-r--r--  tensorflow/cc/ops/while_loop.h                     7
-rw-r--r--  tensorflow/contrib/cmake/tf_cc_ops.cmake           2
-rw-r--r--  tensorflow/core/BUILD                              1
-rw-r--r--  tensorflow/core/graph/graph_partition_test.cc     37
10 files changed, 658 insertions(+), 11 deletions(-)
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index 27be5d787f..4698560bbe 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -73,6 +73,11 @@ class CApiWhileLoopTest : public ::testing::Test {
}
void Run(std::initializer_list<int> input_values) {
+ Run(outputs_, input_values);
+ }
+
+ void Run(const std::vector<TF_Output>& run_outputs,
+ std::initializer_list<int> input_values) {
DCHECK_EQ(inputs_.size(), input_values.size());
std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
int i = 0;
@@ -82,7 +87,7 @@ class CApiWhileLoopTest : public ::testing::Test {
}
csession_.reset(new CSession(graph_, s_));
csession_->SetInputs(inputs);
- csession_->SetOutputs(outputs_);
+ csession_->SetOutputs(run_outputs);
csession_->Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
@@ -402,4 +407,36 @@ TEST_F(CApiWhileLoopTest, BadTypes) {
TF_AbortWhile(params_.get());
}
+// This is a basic test to make sure the C++ gradient code can handle while
+// loops created by the C API (which calls the C++ API under the hood). There
+// are more while loop gradient tests in cc/framework/while_gradients_test.cc.
+TEST_F(CApiWhileLoopTest, Gradients) {
+ Init(1);
+
+ // Create loop: while (i < 10) i += 1
+ TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
+ TF_Operation* less_than =
+ LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
+ DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params_->cond_output = {less_than, 0};
+
+ TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
+ TF_Operation* add =
+ Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params_->body_outputs[0] = {add, 0};
+
+ ExpectOK();
+
+ // Create backprop graph
+ TF_Output grad_output;
+ TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
+ nullptr, s_, &grad_output);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Run gradient
+ Run({grad_output}, {0});
+ ExpectOutputValue(0, 1);
+}
+
} // namespace
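
For reference, the C API entry point exercised by the test above is TF_AddGradients(graph, y, ny, x, nx, dx, status, dy), which adds ops computing d(sum of y)/d(x); per the c_api.h comment, a null dx seeds the backprop with ones. Below is a minimal wrapper sketch; the helper name and structure are illustrative and not part of this change.

#include "tensorflow/c/c_api.h"

// Sketch: gradient of the sum of `y` w.r.t. a single `x`, using ones as the
// initial gradients (dx == nullptr). Assumes `graph` and `y` were already
// built, e.g. via TF_NewWhile()/TF_FinishWhile() as in the test above.
TF_Output SingleGradient(TF_Graph* graph, TF_Output* y, int ny, TF_Output x,
                         TF_Status* status) {
  TF_Output dy;
  TF_AddGradients(graph, y, ny, &x, /*nx=*/1, /*dx=*/nullptr, status, &dy);
  return dy;
}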
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index b0c8cc3d0a..3682ebd943 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -19,13 +19,20 @@ load(
cc_library(
name = "gradients",
- srcs = ["framework/gradients.cc"],
+ srcs = [
+ "framework/gradients.cc",
+ "framework/while_gradients.cc",
+ "framework/while_gradients.h",
+ ],
hdrs = ["framework/gradients.h"],
deps = [
":cc_ops",
+ ":cc_ops_internal",
":grad_op_registry",
":ops",
":scope",
+ ":scope_internal",
+ ":while_loop",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -52,6 +59,28 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "framework_while_gradients_test",
+ size = "small",
+ srcs = ["framework/while_gradients_test.cc"],
+ deps = [
+ ":cc_ops",
+ ":client_session",
+ ":grad_op_registry",
+ ":grad_ops",
+ ":gradients",
+ ":testutil",
+ ":while_loop",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
cc_library(
name = "gradient_checker",
srcs = ["framework/gradient_checker.cc"],
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index b665ce744d..9825b02586 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -16,8 +16,9 @@ limitations under the License.
#include <deque>
#include <vector>
-#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/while_gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -25,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
@@ -82,6 +84,13 @@ class SymbolicGradientBuilder {
// from outputs_. Keyed by node id.
std::vector<bool> GetReachableNodes();
+ // Creates the gradient subgraph for a while loop (or just stores
+ // `summed_grads` if not all incoming gradients are available yet). All exit
+ // nodes (which are the first nodes of a loop encountered in the backwards
+ // pass) are passed to this function rather than processed normally.
+ // `summed_grads` is the sum of `exit_node`'s gradients.
+ Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads);
+
const Scope& scope_;
const ops::GradOpRegistry* registry_;
const std::vector<Output>& outputs_;
@@ -89,8 +98,7 @@ class SymbolicGradientBuilder {
const std::vector<Output>& grad_inputs_;
std::vector<Output>* grad_outputs_;
- // A vector of output endpoints which represents backpropagated
- // gradients
+ // A vector of output endpoints which represents backpropagated gradients
typedef std::vector<Output> BackpropedGradients;
// backprops_ is a map from a node output to its accumulated
@@ -117,6 +125,12 @@ class SymbolicGradientBuilder {
// frontier. Maps from Output -> index into `grad_outputs_`.
std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_;
+ // For each while loop in the graph, collects the summed gradients for each of
+ // the loop's exit nodes. Note that unlike backprops_, this map contains the
+ // output of SumGradients(), not the input (i.e. each exit node may have
+ // multiple incoming gradients, but we only store the combined Output here).
+ std::map<WhileContext*, std::map<Node*, Output>> while_backprops_;
+
TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};
@@ -150,6 +164,7 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
std::deque<Node*> queue;
+ std::vector<bool> visited(scope_.graph()->num_node_ids(), false);
for (const Output& out : outputs_) {
if (!reachable_nodes[out.node()->id()]) {
queue.push_back(out.node());
@@ -162,8 +177,10 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
queue.pop_front();
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
+ if (visited[e->src()->id()]) continue;
queue.push_back(e->src());
reachable_nodes[e->src()->id()] = true;
+ visited[e->src()->id()] = true;
}
}
return reachable_nodes;
@@ -304,6 +321,53 @@ Status SymbolicGradientBuilder::CallGradFunction(
return Status::OK();
}
+Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node,
+ const Output& summed_grads) {
+ // TODO(skyewm): detect second-order gradient and return bad status
+ // TODO(skyewm): handle (or at least detect) nested while loops
+
+ // TODO(skyewm): handle NoGradient in while loop
+ if (summed_grads == NoGradient()) {
+ return errors::Unimplemented(
+ "Missing gradient into while loop not yet implemented");
+ }
+
+ DCHECK(exit_node->IsExit());
+ WhileContext* while_ctx = exit_node->while_ctx();
+ DCHECK(while_ctx != nullptr);
+
+ // Record 'summed_grads' as the backprop input associated with 'exit_node'
+ std::map<Node*, Output>& backprops = while_backprops_[while_ctx];
+ DCHECK(backprops.find(exit_node) == backprops.end());
+ backprops[exit_node] = summed_grads;
+
+ // Wait until we have all exit nodes' backprops collected before processing
+ // the while loop.
+ // TODO(skyewm): what if not all the exit nodes are reachable?
+ if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK();
+
+ // We've seen all the exit nodes for this loop and have collected all the
+ // backprops. Create the gradient graph for the while loop.
+ Scope while_scope =
+ scope_.NewSubScope(strings::StrCat(while_ctx->frame_name(), "_grad"));
+ std::vector<Output> dy;
+ for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]);
+ std::vector<Output> dx;
+ TF_RETURN_IF_ERROR(AddWhileLoopGradient(while_ctx, while_scope, dy, &dx));
+
+ // Backprop along the in edges to the while loop (i.e. the inputs to the enter
+ // nodes)
+ DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size());
+ for (int i = 0; i < dx.size(); ++i) {
+ Node* enter_node = while_ctx->enter_nodes()[i];
+ for (const Edge* e : enter_node->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()}));
+ }
+ }
+ return Status::OK();
+}
+
Status SymbolicGradientBuilder::AddGradients() {
// Initialize backprops.
TF_RETURN_IF_ERROR(Initialize());
@@ -346,6 +410,18 @@ Status SymbolicGradientBuilder::AddGradients() {
continue;
}
+ // Special case: if we find an exit node, process the associated while loop.
+ // Note that ProcessWhileLoop() calls BackpropAlongEdge() if necessary
+ // (which updates ready_), and we skip all the regular processing below
+ // after calling it.
+ if (n->IsExit()) {
+ DCHECK_EQ(dy.size(), 1);
+ TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0]));
+ continue;
+ }
+ // All loop-specific control flow ops should have been handled above
+ DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString();
+
const size_t num_no_grad = no_grad_dy_indices.size();
if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
// No grad defined for this op, or all outputs returned 'NoGradient':
diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc
new file mode 100644
index 0000000000..8234d5bea4
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/while_gradients.h"
+
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+
+namespace tensorflow {
+namespace {
+
+using ops::BodyGraphBuilderFn;
+using ops::BuildWhileLoop;
+using ops::CondGraphBuilderFn;
+
+Output ToOutput(OutputTensor output_tensor) {
+ return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
+}
+
+std::vector<Output> ToOutputVector(
+ const std::vector<OutputTensor>& output_tensors) {
+ size_t n = output_tensors.size();
+ std::vector<Output> result(n);
+ for (int i = 0; i < n; ++i) result[i] = ToOutput(output_tensors[i]);
+ return result;
+}
+
+// The backprop loop counter and main backprop loop run in their own execution
+// frame (conceptually, the main forward loop and forward loop counter run
+// together in a frame, then the backprop loop counter and backprop loop run
+// together in a different frame). This returns the frame name to use for the
+// backprop while loops.
+// TODO(skyewm): make sure this is unique among existing frame names
+string BackPropFrameName(const string& forward_frame_name) {
+ return strings::StrCat(forward_frame_name, "_backprop");
+}
+
+// Creates a loop that counts the number of iterations performed by the
+// while loop associated with `while_ctx`. The returned output yields the
+// iteration count.
+Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
+ Output* count) {
+ // Create while loop:
+ // i = 0
+ // while forward loop predicate is true:
+ // ++i
+
+ Output zero = ops::Const(scope, 0, {});
+
+ // Condition function that returns condition output from original while loop.
+ CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ *output = ToOutput(while_ctx->cond_output());
+ return Status::OK();
+ };
+
+ // Body function that adds one to input.
+ BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ DCHECK_EQ(inputs.size(), 1);
+ outputs->emplace_back(ops::Add(scope, inputs[0], 1));
+ return scope.status();
+ };
+
+ // Note that this loop runs in the same execution frame as the forward loop.
+ std::vector<Output> outputs;
+ TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
+ while_ctx->frame_name(), &outputs,
+ /* create_while_ctx */ false));
+ *count = outputs[0];
+ return Status::OK();
+}
+
+// Creates a loop that executes `loop_count` times. The returned output is the
+// boolean predicate indicating if the loop is still executing. This is used to
+// drive the gradient computation for the while loop associated with
+// `while_ctx`.
+Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
+ const Scope& scope,
+ Output* backprop_execution_pred) {
+ // Create while loop:
+ // n = loop_count
+ // while n > 0:
+ // --n
+
+ // Condition function that returns input > 0.
+ CondGraphBuilderFn cond_fn = [](const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ DCHECK_EQ(inputs.size(), 1);
+ *output = ops::Greater(scope, inputs[0], 0);
+ return scope.status();
+ };
+
+ // Body function that subtracts one from input.
+ BodyGraphBuilderFn body_fn = [](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ DCHECK_EQ(inputs.size(), 1);
+ outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
+ return scope.status();
+ };
+
+ string frame_name = BackPropFrameName(while_ctx->frame_name());
+ std::vector<Output> outputs; // unused
+ TF_RETURN_IF_ERROR(BuildWhileLoop(
+ scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
+ /* create_while_ctx */ false, backprop_execution_pred));
+ return Status::OK();
+}
+
+// Creates the main backprop loop that computes the gradient of the loop
+// associated with `while_ctx`. `grad_inputs` are the partial derivatives
+// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
+// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
+// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
+// returned in `grad_outputs`.
+Status AddWhileGradientLoop(WhileContext* while_ctx,
+ const std::vector<Output>& grad_inputs,
+ const Output& backprop_execution_pred,
+ const Scope& parent_scope,
+ std::vector<Output>* grad_outputs) {
+ DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
+ DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
+
+ Scope scope = parent_scope.NewSubScope("while");
+
+ // Create while loop:
+ // while backprop_execution_pred:
+ // forward loop body gradient
+
+ // Condition function that returns 'backprop_execution_pred'.
+ CondGraphBuilderFn cond_fn = [backprop_execution_pred](
+ const Scope& scope,
+ const std::vector<Output>& inputs,
+ Output* output) {
+ *output = backprop_execution_pred;
+ return Status::OK();
+ };
+
+ // Body function that builds while body gradient subgraph.
+ BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+ const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ std::vector<Output> body_outputs =
+ ToOutputVector(while_ctx->body_outputs());
+ std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
+ return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
+ outputs);
+ };
+
+ string frame_name = BackPropFrameName(while_ctx->frame_name());
+ TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
+ frame_name, grad_outputs,
+ /* create_while_ctx */ false));
+ return Status::OK();
+}
+
+} // namespace
+
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ Output forward_loop_count;
+ TF_RETURN_IF_ERROR(AddForwardLoopCounter(
+ while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
+
+ // TODO(skyewm): can we combine the backprop loop counter and main gradient
+ // loop into a single loop? The original Python code doesn't combine the
+ // loops, but I'm not sure why.
+ Output backprop_counter_cond;
+ TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
+ while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
+ &backprop_counter_cond));
+
+ return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
+ scope, grad_outputs);
+}
+
+} // namespace tensorflow
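
To make the three-loop construction above concrete, here is a plain C++ sketch (ordinary host code, not TensorFlow graph building) of what the generated subgraphs compute at run time for the loop `while (i < 10) i += 1`; the structure is illustrative only.

#include <iostream>

int main() {
  int i = 0;  // forward loop input
  // ForwardLoopCounter: reuses the forward predicate and yields the iteration
  // count N.
  int n = 0;
  for (int x = i; x < 10; x += 1) ++n;
  // BackPropLoopCounter: counts down from N; its predicate keeps the main
  // gradient loop alive for exactly N iterations.
  // Gradient loop: applies the body's gradient to the incoming gradient dy,
  // once per forward iteration. For the body i -> i + 1, d(body)/di == 1, so
  // dy stays 1 regardless of N.
  double dy = 1.0;
  for (int k = n; k > 0; --k) dy *= 1.0;
  std::cout << "iterations: " << n << ", d(output)/d(input): " << dy << "\n";
  return 0;
}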
diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h
new file mode 100644
index 0000000000..8f592accc9
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.h
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/core/graph/while_context.h"
+
+// Utility functions for constructing while loop gradients
+
+namespace tensorflow {
+
+// Adds the gradient computation for the while loop associated with
+// `while_ctx`. `grad_inputs` are the partial derivatives w.r.t. the loop
+// outputs, i.e. the exit nodes. The partial derivatives w.r.t. the loop
+// inputs, i.e. the input loop vars, are returned in `grad_outputs`.
+// `grad_inputs` and `grad_outputs` are both in loop-variable order, as defined
+// by the original inputs to BuildWhileLoop().
+// TODO(skyewm): maybe comment on NoGradient once it's supported
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
diff --git a/tensorflow/cc/framework/while_gradients_test.cc b/tensorflow/cc/framework/while_gradients_test.cc
new file mode 100644
index 0000000000..39fa7477c5
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients_test.cc
@@ -0,0 +1,233 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/testutil.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+class WhileGradientsTest : public ::testing::Test {
+ protected:
+ WhileGradientsTest() : scope_(Scope::NewRootScope()) {}
+
+ void Init(int num_inputs, DataType dtype = DT_INT32) {
+ for (int i = 0; i < num_inputs; ++i) {
+ inputs_.push_back(ops::Placeholder(scope_, dtype));
+ }
+ }
+
+ void CreateLoop(const ops::CondGraphBuilderFn& cond,
+ const ops::BodyGraphBuilderFn& body,
+ const std::vector<Output>* inputs = nullptr) {
+ if (inputs == nullptr) inputs = &inputs_;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope_, *inputs, cond, body, "test_loop",
+ &outputs_));
+ }
+
+ void CreateBackprop() {
+ TF_ASSERT_OK(
+ AddSymbolicGradients(scope_, outputs_, inputs_, &grad_outputs_));
+ ASSERT_EQ(grad_outputs_.size(), inputs_.size());
+ }
+
+ template <typename T>
+ void Run(const std::vector<Input::Initializer>& input_values,
+ const std::vector<T>& expected_grad_values) {
+ Run<T>(ClientSession(scope_), input_values, expected_grad_values);
+ }
+
+ template <typename T>
+ void Run(const ClientSession& session,
+ const std::vector<Input::Initializer>& input_values,
+ const std::vector<T>& expected_grad_values,
+ const RunOptions& run_options = RunOptions(),
+ RunMetadata* run_metadata = nullptr) {
+ DCHECK_EQ(input_values.size(), inputs_.size());
+ ClientSession::FeedType feeds;
+ for (int i = 0; i < inputs_.size(); ++i) {
+ feeds.emplace(inputs_[i], input_values[i]);
+ }
+
+ std::vector<Operation> run_outputs;
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(run_options, feeds, grad_outputs_, run_outputs,
+ &out_tensors, run_metadata));
+ ASSERT_EQ(out_tensors.size(), grad_outputs_.size());
+
+ DCHECK_EQ(expected_grad_values.size(), out_tensors.size());
+ for (int i = 0; i < out_tensors.size(); ++i) {
+ test::ExpectTensorEqual<T>(
+ out_tensors[i], test::AsTensor<T>({expected_grad_values[i]}, {}));
+ }
+ }
+
+ Scope scope_;
+ std::vector<Output> inputs_;
+ std::vector<Output> outputs_;
+ std::vector<Output> grad_outputs_;
+};
+
+TEST_F(WhileGradientsTest, Basic) {
+ // Create loop: while (i < 10) i += 1
+ Init(1);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ // Use AddN, rather than Add, because the gradient function doesn't
+ // depend on the input shapes, and thus we do not need to store
+ // intermediate values in a stack.
+ outputs->push_back(ops::AddN(s, {inputs[0], 1}));
+ return s.status();
+ });
+ CreateBackprop();
+
+ Run<int>({1}, {1});
+ Run<int>({11}, {1});
+}
+
+TEST_F(WhileGradientsTest, MultipleLoopVars) {
+ // Create loop: while (i < 10) i += j; j += 1; k = k
+ Init(3);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::AddN(s, {inputs[0], inputs[1]}));
+ outputs->push_back(ops::AddN(s, {inputs[1], 1}));
+ outputs->push_back(inputs[2]);
+ return s.status();
+ });
+ CreateBackprop();
+
+ // The following execution traces illustrate why we expect dF/dj to be 5:
+ //
+ // i j k
+ // ---------
+ // 0 1 2 <-- initial values
+ // 1 2 2
+ // 3 3 2
+ // 6 4 2
+ // 10 5 2 <-- while output values
+ // outputs sum = 17
+ //
+ // i j k
+ // ---------
+ // 0 2 2 <-- initial values (add 1 to j)
+ // 2 3 2
+ // 5 4 2
+ // 9 5 2
+ // 14 6 2 <-- while output values
+ // outputs sum = 22
+ //
+ // Calculate the "slope" between j=1 and j=2:
+ // 22 - 17 = 5 => dF/dj = 5
+ Run<int>({0, 1, 2}, {1, 5, 1});
+
+ Run<int>({1, 1, 0}, {1, 5, 1});
+ Run<int>({0, 0, 0}, {1, 6, 1});
+}
+
+TEST_F(WhileGradientsTest, Chaining) {
+ Init(2, DT_DOUBLE);
+
+ // Multiply each input by 2 before passing to while loop to make sure chaining
+ // works properly
+ std::vector<Output> loop_inputs = {ops::Multiply(scope_, inputs_[0], 2.0),
+ ops::Multiply(scope_, inputs_[1], 2.0)};
+
+ // Create loop: while (i > 0 && j > 0) i -= 1
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::LogicalAnd(s, ops::Greater(s, inputs[0], 0.0),
+ ops::Greater(s, inputs[1], 0.0));
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::AddN(s, {inputs[0], -1.0}));
+ outputs->push_back(inputs[1]);
+ return s.status();
+ },
+ &loop_inputs);
+
+ // Take negative of first output to make sure chaining works properly
+ outputs_[0] = ops::Neg(scope_, outputs_[0]);
+
+ CreateBackprop();
+
+ Run<double>({1.0, 1.0}, {-2.0, 2.0});
+ Run<double>({0.0, 0.0}, {-2.0, 2.0});
+}
+
+TEST_F(WhileGradientsTest, MultipleDevices) {
+ // Make sure loop is created on cpu0
+ scope_ = scope_.WithDevice("/cpu:0");
+
+ // Create loop: while (i < 10) i += j
+ Init(2);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ // Place body on cpu1
+ Scope cpu1_scope = s.WithDevice("/cpu:1");
+ outputs->push_back(ops::AddN(cpu1_scope, {inputs[0], inputs[1]}));
+ outputs->push_back(inputs[1]);
+ return cpu1_scope.status();
+ });
+
+ // Build gradient graph on cpu1
+ Scope cpu1_scope = scope_.WithDevice("/cpu:1");
+ TF_ASSERT_OK(
+ AddSymbolicGradients(cpu1_scope, outputs_, inputs_, &grad_outputs_));
+ ASSERT_EQ(grad_outputs_.size(), inputs_.size());
+
+ // Run with two CPU devices and output partition graphs
+ SessionOptions session_options;
+ (*session_options.config.mutable_device_count())["CPU"] = 2;
+ RunOptions run_options;
+ run_options.set_output_partition_graphs(true);
+ RunMetadata run_metadata;
+ Run<int>(ClientSession(scope_, session_options), {0, 1}, {1, 11}, run_options,
+ &run_metadata);
+
+ // Check that at least one node ran on each device
+ ASSERT_EQ(run_metadata.partition_graphs().size(), 2);
+ for (const GraphDef& partition_graph : run_metadata.partition_graphs()) {
+ EXPECT_GE(partition_graph.node().size(), 1);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h
index 82181516d6..a04476056a 100644
--- a/tensorflow/cc/ops/while_loop.h
+++ b/tensorflow/cc/ops/while_loop.h
@@ -49,7 +49,12 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
// * outputs: output param that returns final loop variable outputs in non-error
// case. Must be non-null and empty.
// * create_while_ctx: if true, a WhileContext is created and populated for this
-// loop. See core/graph/while_context.h for more details.
+// loop. See core/graph/while_context.h for more details on
+// WhileContexts. This is set to false for loops used as part of gradient
+// computations, since they're part of the gradient for a loop in the
+// forward pass.
+// TODO(skyewm): revisit this. Should we create WhileContexts for all loops,
+// even if we don't need them?
// * cond_output: if non-null, the output of the predicate is returned. This
// will always be a LoopCond node.
//
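
A minimal end-to-end sketch of how BuildWhileLoop() composes with AddSymbolicGradients() from user code, mirroring the Basic test above; the loop, names, and TF_CHECK_OK error handling are illustrative.

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"

using namespace tensorflow;  // NOLINT(build/namespaces)

int main() {
  Scope scope = Scope::NewRootScope();
  Output i = ops::Placeholder(scope, DT_INT32);

  // Build loop: while (i < 10) i += 1. With create_while_ctx left at its
  // default (true), a WhileContext is recorded so the gradient builder can
  // recover the loop structure.
  std::vector<Output> outputs;
  TF_CHECK_OK(ops::BuildWhileLoop(
      scope, {i},
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        outputs->push_back(ops::AddN(s, {inputs[0], 1}));
        return s.status();
      },
      "example_loop", &outputs));

  // d(outputs[0])/d(i): this exercises the new while-loop gradient path added
  // in framework/while_gradients.cc.
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(scope, outputs, {i}, &grads));

  ClientSession session(scope);
  ClientSession::FeedType feeds;
  feeds.emplace(i, Input::Initializer(0));
  std::vector<Tensor> out;
  TF_CHECK_OK(session.Run(feeds, grads, &out));
  // Expect a scalar 1: the body is i -> i + 1, whose derivative w.r.t. i is 1.
  return 0;
}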
diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake
index 6632433087..a5f5ae5478 100644
--- a/tensorflow/contrib/cmake/tf_cc_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake
@@ -135,6 +135,8 @@ set(tf_cc_srcs
"${tensorflow_source_dir}/tensorflow/cc/framework/gradient_checker.cc"
"${tensorflow_source_dir}/tensorflow/cc/framework/gradients.h"
"${tensorflow_source_dir}/tensorflow/cc/framework/gradients.cc"
+ "${tensorflow_source_dir}/tensorflow/cc/framework/while_gradients.h"
+ "${tensorflow_source_dir}/tensorflow/cc/framework/while_gradients.cc"
)
file(GLOB_RECURSE tf_cc_test_srcs
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5502eebd7f..5ca5ef916b 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2613,6 +2613,7 @@ tf_cc_tests(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
+ "//tensorflow/cc:while_loop",
"//tensorflow/core/kernels:ops_util",
"//third_party/eigen3",
],
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index 8dde7320ed..858ef8ac01 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
@@ -72,10 +73,13 @@ void Partition(const GraphDef& graph_def,
GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g));
- // Assigns devices to each node. Uses 1st letter of the node name as
- // the device index.
+ // Assigns devices to each node. Uses 1st letter of the node name as the
+ // device index if no device is specified.
for (Node* node : g.nodes()) {
- node->set_assigned_device_name(DeviceName(node));
+ string device_name = !node->requested_device().empty()
+ ? node->requested_device()
+ : DeviceName(node);
+ node->set_assigned_device_name(device_name);
}
PartitionOptions popts;
@@ -368,7 +372,7 @@ TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
ExpectMatchB();
}
-TEST_F(GraphPartitionTest, CrossDeviceLoop) {
+TEST_F(GraphPartitionTest, CrossDeviceLoopSimple) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
auto a1 = BoolInput(in_.WithOpName("A1"));
auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("A2"), a1, "foo");
@@ -382,7 +386,7 @@ TEST_F(GraphPartitionTest, CrossDeviceLoop) {
CheckLoopConstruction(ToGraphDef());
}
-TEST_F(GraphPartitionTest, CrossDeviceLoop1) {
+TEST_F(GraphPartitionTest, CrossDeviceLoopSimple1) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
auto a1 = BoolInput(in_.WithOpName("A1"));
auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("B2"), a1, "foo");
@@ -407,6 +411,29 @@ TEST_F(GraphPartitionTest, CrossDeviceLoop1) {
}
}
+TEST_F(GraphPartitionTest, CrossDeviceLoopFull) {
+ Scope cpu0 = in_.WithDevice("/job:a/replica:0/task:0/cpu:0");
+ auto p1 = ops::Placeholder(cpu0, DT_INT32);
+ auto p2 = ops::Placeholder(cpu0, DT_INT32);
+ OutputList outputs;
+ // while i1 < 10: i1 += i2
+ TF_ASSERT_OK(ops::BuildWhileLoop(
+ cpu0, {p1, p2},
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+ },
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ Scope cpu1 = s.WithDevice("/job:a/replica:0/task:0/cpu:1");
+ outputs->push_back(ops::AddN(cpu1, {inputs[0], inputs[1]}));
+ outputs->push_back(inputs[1]);
+ return s.status();
+ },
+ "test_loop", &outputs));
+ CheckLoopConstruction(ToGraphDef());
+}
+
TEST_F(GraphPartitionTest, PartitionIncompleteGraph) {
NodeDef ndef;
Graph g(OpRegistry::Global());