-rw-r--r--  tensorflow/c/while_loop_test.cc                 |  39
-rw-r--r--  tensorflow/cc/BUILD                             |  31
-rw-r--r--  tensorflow/cc/framework/gradients.cc            |  82
-rw-r--r--  tensorflow/cc/framework/while_gradients.cc      | 197
-rw-r--r--  tensorflow/cc/framework/while_gradients.h       |  40
-rw-r--r--  tensorflow/cc/framework/while_gradients_test.cc | 233
-rw-r--r--  tensorflow/cc/ops/while_loop.h                  |   7
-rw-r--r--  tensorflow/contrib/cmake/tf_cc_ops.cmake        |   2
-rw-r--r--  tensorflow/core/BUILD                           |   1
-rw-r--r--  tensorflow/core/graph/graph_partition_test.cc   |  37
10 files changed, 658 insertions, 11 deletions
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index 27be5d787f..4698560bbe 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -73,6 +73,11 @@ class CApiWhileLoopTest : public ::testing::Test {
   }
 
   void Run(std::initializer_list<int> input_values) {
+    Run(outputs_, input_values);
+  }
+
+  void Run(const std::vector<TF_Output>& run_outputs,
+           std::initializer_list<int> input_values) {
     DCHECK_EQ(inputs_.size(), input_values.size());
     std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
     int i = 0;
@@ -82,7 +87,7 @@ class CApiWhileLoopTest : public ::testing::Test {
     }
     csession_.reset(new CSession(graph_, s_));
     csession_->SetInputs(inputs);
-    csession_->SetOutputs(outputs_);
+    csession_->SetOutputs(run_outputs);
     csession_->Run(s_);
     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
   }
@@ -402,4 +407,36 @@ TEST_F(CApiWhileLoopTest, BadTypes) {
   TF_AbortWhile(params_.get());
 }
 
+// This is a basic test to make sure the C++ gradient code can handle while
+// loops created by the C API (which calls the C++ API under the hood). There
+// are more while loop gradient tests in cc/framework/while_gradients_test.cc.
+TEST_F(CApiWhileLoopTest, Gradients) {
+  Init(1);
+
+  // Create loop: while (i < 10) i += 1
+  TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
+  TF_Operation* less_than =
+      LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
+  DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+  params_->cond_output = {less_than, 0};
+
+  TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
+  TF_Operation* add =
+      Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+  params_->body_outputs[0] = {add, 0};
+
+  ExpectOK();
+
+  // Create backprop graph
+  TF_Output grad_output;
+  TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
+                  nullptr, s_, &grad_output);
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+  // Run gradient
+  Run({grad_output}, {0});
+  ExpectOutputValue(0, 1);
+}
+
 }  // namespace
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index b0c8cc3d0a..3682ebd943 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -19,13 +19,20 @@ load(
 
 cc_library(
     name = "gradients",
-    srcs = ["framework/gradients.cc"],
+    srcs = [
+        "framework/gradients.cc",
+        "framework/while_gradients.cc",
+        "framework/while_gradients.h",
+    ],
     hdrs = ["framework/gradients.h"],
     deps = [
         ":cc_ops",
+        ":cc_ops_internal",
         ":grad_op_registry",
         ":ops",
         ":scope",
+        ":scope_internal",
+        ":while_loop",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -52,6 +59,28 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "framework_while_gradients_test",
+    size = "small",
+    srcs = ["framework/while_gradients_test.cc"],
+    deps = [
+        ":cc_ops",
+        ":client_session",
+        ":grad_op_registry",
+        ":grad_ops",
+        ":gradients",
+        ":testutil",
+        ":while_loop",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 cc_library(
     name = "gradient_checker",
     srcs = ["framework/gradient_checker.cc"],
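The BUILD changes above compile the new while_gradients files into the existing `gradients` library and add a test target for them. The user-facing flow this enables is: build a loop with ops::BuildWhileLoop(), then call AddSymbolicGradients() on its outputs. A minimal sketch, condensed from while_gradients_test.cc later in this diff (the ops and the four-argument AddSymbolicGradients() overload are the ones used there; the loop body and the frame name "my_loop" are arbitrary examples):

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/ops/while_loop.h"

using namespace tensorflow;  // sketch only; TF code uses explicit namespaces

int main() {
  Scope scope = Scope::NewRootScope();
  std::vector<Output> inputs = {ops::Placeholder(scope, DT_INT32)};

  // Forward loop: while (i < 10) i += 1
  std::vector<Output> outputs;
  TF_CHECK_OK(ops::BuildWhileLoop(
      scope, inputs,
      [](const Scope& s, const std::vector<Output>& in, Output* out) {
        *out = ops::Less(s, in[0], 10);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& in,
         std::vector<Output>* out) {
        out->push_back(ops::AddN(s, {in[0], 1}));
        return s.status();
      },
      "my_loop", &outputs));

  // After this change, AddSymbolicGradients() handles the loop's Exit nodes.
  std::vector<Output> grad_outputs;
  TF_CHECK_OK(AddSymbolicGradients(scope, outputs, inputs, &grad_outputs));

  // d(output)/d(input) is 1 for this body, whatever the iteration count.
  std::vector<Tensor> grads;
  TF_CHECK_OK(ClientSession(scope).Run({{inputs[0], 0}}, grad_outputs, &grads));
  return 0;
}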
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index b665ce744d..9825b02586 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -16,8 +16,9 @@ limitations under the License.
 #include <deque>
 #include <vector>
 
-#include "tensorflow/cc/framework/gradients.h"
 #include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/while_gradients.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -25,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/while_context.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/macros.h"
@@ -82,6 +84,13 @@ class SymbolicGradientBuilder {
   // from outputs_. Keyed by node id.
   std::vector<bool> GetReachableNodes();
 
+  // Creates the gradient subgraph for a while loop (or just stores
+  // `summed_grads` if not all incoming gradients are available yet). All exit
+  // nodes (which are the first nodes of a loop encountered in the backwards
+  // pass) are passed to this function rather than processed normally.
+  // `summed_grads` is the sum of `exit_node`'s gradients.
+  Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads);
+
   const Scope& scope_;
   const ops::GradOpRegistry* registry_;
   const std::vector<Output>& outputs_;
@@ -89,8 +98,7 @@ class SymbolicGradientBuilder {
   const std::vector<Output>& grad_inputs_;
   std::vector<Output>* grad_outputs_;
 
-  // A vector of output endpoints which represents backpropagated
-  // gradients
+  // A vector of output endpoints which represents backpropagated gradients
   typedef std::vector<Output> BackpropedGradients;
 
   // backprops_ is a map from a node output to its accumulated
@@ -117,6 +125,12 @@ class SymbolicGradientBuilder {
   // frontier. Maps from Output -> index into `grad_outputs_`.
   std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_;
 
+  // For each while loop in the graph, collects the summed gradients for each of
+  // the loop's exit nodes. Note that unlike backprops_, this map contains the
+  // output of SumGradients(), not the input (i.e. each exit node may have
+  // multiple incoming gradients, but we only store the combined Output here).
+  std::map<WhileContext*, std::map<Node*, Output>> while_backprops_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
 };
 
@@ -150,6 +164,7 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad,
 std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
   std::vector<bool> reachable_nodes(scope_.graph()->num_node_ids(), false);
   std::deque<Node*> queue;
+  std::vector<bool> visited(scope_.graph()->num_node_ids(), false);
   for (const Output& out : outputs_) {
     if (!reachable_nodes[out.node()->id()]) {
       queue.push_back(out.node());
@@ -162,8 +177,10 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
     queue.pop_front();
     for (const Edge* e : n->in_edges()) {
       if (e->IsControlEdge()) continue;
+      if (visited[e->src()->id()]) continue;
       queue.push_back(e->src());
       reachable_nodes[e->src()->id()] = true;
+      visited[e->src()->id()] = true;
     }
   }
   return reachable_nodes;
@@ -304,6 +321,53 @@ Status SymbolicGradientBuilder::CallGradFunction(
   return Status::OK();
 }
 
+Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node,
+                                                 const Output& summed_grads) {
+  // TODO(skyewm): detect second-order gradient and return bad status
+  // TODO(skyewm): handle (or at least detect) nested while loops
+
+  // TODO(skyewm): handle NoGradient in while loop
+  if (summed_grads == NoGradient()) {
+    return errors::Unimplemented(
+        "Missing gradient into while loop not yet implemented");
+  }
+
+  DCHECK(exit_node->IsExit());
+  WhileContext* while_ctx = exit_node->while_ctx();
+  DCHECK(while_ctx != nullptr);
+
+  // Record 'summed_grads' as the backprop input associated with 'exit_node'
+  std::map<Node*, Output>& backprops = while_backprops_[while_ctx];
+  DCHECK(backprops.find(exit_node) == backprops.end());
+  backprops[exit_node] = summed_grads;
+
+  // Wait until we have all exit nodes' backprops collected before processing
+  // the while loop.
+  // TODO(skyewm): what if not all the exit nodes are reachable?
+  if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK();
+
+  // We've seen all the exit nodes for this loop and have collected all the
+  // backprops. Create the gradient graph for the while loop.
+  Scope while_scope =
+      scope_.NewSubScope(strings::StrCat(while_ctx->frame_name(), "_grad"));
+  std::vector<Output> dy;
+  for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]);
+  std::vector<Output> dx;
+  TF_RETURN_IF_ERROR(AddWhileLoopGradient(while_ctx, while_scope, dy, &dx));
+
+  // Backprop along the in edges to the while loop (i.e. the inputs to the enter
+  // nodes)
+  DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size());
+  for (int i = 0; i < dx.size(); ++i) {
+    Node* enter_node = while_ctx->enter_nodes()[i];
+    for (const Edge* e : enter_node->in_edges()) {
+      if (e->IsControlEdge()) continue;
+      TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()}));
+    }
+  }
+  return Status::OK();
+}
+
 Status SymbolicGradientBuilder::AddGradients() {
   // Initialize backprops.
   TF_RETURN_IF_ERROR(Initialize());
@@ -346,6 +410,18 @@ Status SymbolicGradientBuilder::AddGradients() {
       continue;
     }
 
+    // Special case: if we find an exit node, process the associated while loop.
+    // Note that ProcessWhileLoop() calls BackpropAlongEdge() if necessary
+    // (which updates ready_), and we skip all the regular processing below
+    // after calling it.
+    if (n->IsExit()) {
+      DCHECK_EQ(dy.size(), 1);
+      TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0]));
+      continue;
+    }
+
+    // All loop-specific control flow ops should have been handled above
+    DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString();
+
     const size_t num_no_grad = no_grad_dy_indices.size();
     if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) {
       // No grad defined for this op, or all outputs returned 'NoGradient':
diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc
new file mode 100644
index 0000000000..8234d5bea4
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/while_gradients.h"
+
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/scope_internal.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+
+namespace tensorflow {
+namespace {
+
+using ops::BodyGraphBuilderFn;
+using ops::BuildWhileLoop;
+using ops::CondGraphBuilderFn;
+
+Output ToOutput(OutputTensor output_tensor) {
+  return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
+}
+
+std::vector<Output> ToOutputVector(
+    const std::vector<OutputTensor>& output_tensors) {
+  size_t n = output_tensors.size();
+  std::vector<Output> result(n);
+  for (int i = 0; i < n; ++i) result[i] = ToOutput(output_tensors[i]);
+  return result;
+}
+
+// The backprop loop counter and main backprop loop run in their own execution
+// frame (conceptually, the main forward loop and forward loop counter run
+// together in a frame, then the backprop loop counter and backprop loop run
+// together in a different frame). This returns the frame name to use for the
+// backprop while loops.
+// TODO(skyewm): make sure this is unique among existing frame names
+string BackPropFrameName(const string& forward_frame_name) {
+  return strings::StrCat(forward_frame_name, "_backprop");
+}
+
+// Creates a loop that counts the number of iterations performed by the
+// while loop associated with `while_ctx`. The returned output yields the
+// iteration count.
+Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
+                             Output* count) {
+  // Create while loop:
+  //   i = 0
+  //   while forward loop predicate is true:
+  //     ++i
+
+  Output zero = ops::Const(scope, 0, {});
+
+  // Condition function that returns condition output from original while loop.
+  CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           Output* output) {
+    *output = ToOutput(while_ctx->cond_output());
+    return Status::OK();
+  };
+
+  // Body function that adds one to input.
+  BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           std::vector<Output>* outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    outputs->emplace_back(ops::Add(scope, inputs[0], 1));
+    return scope.status();
+  };
+
+  // Note that this loop runs in the same execution frame as the forward loop.
+  std::vector<Output> outputs;
+  TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
+                                    while_ctx->frame_name(), &outputs,
+                                    /* create_while_ctx */ false));
+  *count = outputs[0];
+  return Status::OK();
+}
+
+// Creates a loop that executes `loop_count` times. The returned output is the
+// boolean predicate indicating if the loop is still executing. This is used to
+// drive the gradient computation for the while loop associated with
+// `while_ctx`.
+Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
+                              const Scope& scope,
+                              Output* backprop_execution_pred) {
+  // Create while loop:
+  //   n = loop_count
+  //   while n > 0:
+  //     --n
+
+  // Condition function that returns input > 0.
+  CondGraphBuilderFn cond_fn = [](const Scope& scope,
+                                  const std::vector<Output>& inputs,
+                                  Output* output) {
+    DCHECK_EQ(inputs.size(), 1);
+    *output = ops::Greater(scope, inputs[0], 0);
+    return scope.status();
+  };
+
+  // Body function that subtracts one from input.
+  BodyGraphBuilderFn body_fn = [](const Scope& scope,
+                                  const std::vector<Output>& inputs,
+                                  std::vector<Output>* outputs) {
+    DCHECK_EQ(inputs.size(), 1);
+    outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
+    return scope.status();
+  };
+
+  string frame_name = BackPropFrameName(while_ctx->frame_name());
+  std::vector<Output> outputs;  // unused
+  TF_RETURN_IF_ERROR(BuildWhileLoop(
+      scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
+      /* create_while_ctx */ false, backprop_execution_pred));
+  return Status::OK();
+}
+
+// Creates the main backprop loop that computes the gradient of the loop
+// associated with `while_ctx`. `grad_inputs` are the partial derivatives
+// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
+// the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
+// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
+// returned in `grad_outputs`.
+Status AddWhileGradientLoop(WhileContext* while_ctx,
+                            const std::vector<Output>& grad_inputs,
+                            const Output& backprop_execution_pred,
+                            const Scope& parent_scope,
+                            std::vector<Output>* grad_outputs) {
+  DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
+  DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
+
+  Scope scope = parent_scope.NewSubScope("while");
+
+  // Create while loop:
+  //   while backprop_execution_pred:
+  //     forward loop body gradient
+
+  // Condition function that returns 'backprop_execution_pred'.
+  CondGraphBuilderFn cond_fn = [backprop_execution_pred](
+                                   const Scope& scope,
+                                   const std::vector<Output>& inputs,
+                                   Output* output) {
+    *output = backprop_execution_pred;
+    return Status::OK();
+  };
+
+  // Body function that builds while body gradient subgraph.
+  BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
+                                           const std::vector<Output>& inputs,
+                                           std::vector<Output>* outputs) {
+    std::vector<Output> body_outputs =
+        ToOutputVector(while_ctx->body_outputs());
+    std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
+    return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
+                                outputs);
+  };
+
+  string frame_name = BackPropFrameName(while_ctx->frame_name());
+  TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
+                                    frame_name, grad_outputs,
+                                    /* create_while_ctx */ false));
+  return Status::OK();
+}
+
+}  // namespace
+
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+                            const std::vector<Output>& grad_inputs,
+                            std::vector<Output>* grad_outputs) {
+  Output forward_loop_count;
+  TF_RETURN_IF_ERROR(AddForwardLoopCounter(
+      while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
+
+  // TODO(skyewm): can we combine the backprop loop counter and main gradient
+  // loop into a single loop? The original Python code doesn't combine the
+  // loops, but I'm not sure why.
+  Output backprop_counter_cond;
+  TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
+      while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
+      &backprop_counter_cond));
+
+  return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
+                              scope, grad_outputs);
+}
+
+}  // namespace tensorflow
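AddWhileLoopGradient() above chains three loops: a forward iteration counter, a backprop countdown, and the gradient loop proper. A plain-C++ model of that scheme, specialized to the "while (i < 10) i += 1" loop used in the tests (ordinary scalars and control flow standing in for tensors and execution frames; illustrative only, not TensorFlow code):

#include <cstdio>

int main() {
  int i = 0;

  // Forward loop plus AddForwardLoopCounter(): the counter loop reuses the
  // forward predicate, so it counts exactly the forward iterations.
  int forward_loop_count = 0;
  while (i < 10) {
    i += 1;
    ++forward_loop_count;
  }

  // AddBackPropLoopCounter() counts down from that total; "n > 0" plays the
  // role of backprop_execution_pred. AddWhileGradientLoop() then applies the
  // body's gradient once per forward iteration, in reverse. For "i += 1" the
  // body gradient is the identity, so dy passes through unchanged.
  double dy = 1.0;  // incoming gradient at the loop's Exit node
  for (int n = forward_loop_count; n > 0; --n) {
    dy *= 1.0;  // d(i + 1)/di == 1
  }

  std::printf("iterations=%d  d(out)/d(in)=%g\n", forward_loop_count, dy);
  return 0;
}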
diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h
new file mode 100644
index 0000000000..8f592accc9
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients.h
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/core/graph/while_context.h"
+
+// Utility functions for constructing while loop gradients
+
+namespace tensorflow {
+
+// Adds the gradient computation for the while loop associated with
+// `while_ctx`. `grad_inputs` are the partial derivatives w.r.t. the loop
+// outputs, i.e. the exit nodes. The partial derivatives w.r.t. the loop
+// inputs, i.e. the input loop vars, are returned in `grad_outputs`.
+// `grad_inputs` and `grad_outputs` are both in loop-variable order, as defined
+// by the original inputs to BuildWhileLoop().
+// TODO(skyewm): maybe comment on NoGradient once it's supported
+Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
+                            const std::vector<Output>& grad_inputs,
+                            std::vector<Output>* grad_outputs);
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_
diff --git a/tensorflow/cc/framework/while_gradients_test.cc b/tensorflow/cc/framework/while_gradients_test.cc
new file mode 100644
index 0000000000..39fa7477c5
--- /dev/null
+++ b/tensorflow/cc/framework/while_gradients_test.cc
@@ -0,0 +1,233 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/testutil.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class WhileGradientsTest : public ::testing::Test {
+ protected:
+  WhileGradientsTest() : scope_(Scope::NewRootScope()) {}
+
+  void Init(int num_inputs, DataType dtype = DT_INT32) {
+    for (int i = 0; i < num_inputs; ++i) {
+      inputs_.push_back(ops::Placeholder(scope_, dtype));
+    }
+  }
+
+  void CreateLoop(const ops::CondGraphBuilderFn& cond,
+                  const ops::BodyGraphBuilderFn& body,
+                  const std::vector<Output>* inputs = nullptr) {
+    if (inputs == nullptr) inputs = &inputs_;
+    TF_ASSERT_OK(ops::BuildWhileLoop(scope_, *inputs, cond, body, "test_loop",
+                                     &outputs_));
+  }
+
+  void CreateBackprop() {
+    TF_ASSERT_OK(
+        AddSymbolicGradients(scope_, outputs_, inputs_, &grad_outputs_));
+    ASSERT_EQ(grad_outputs_.size(), inputs_.size());
+  }
+
+  template <typename T>
+  void Run(const std::vector<Input::Initializer>& input_values,
+           const std::vector<T>& expected_grad_values) {
+    Run<T>(ClientSession(scope_), input_values, expected_grad_values);
+  }
+
+  template <typename T>
+  void Run(const ClientSession& session,
+           const std::vector<Input::Initializer>& input_values,
+           const std::vector<T>& expected_grad_values,
+           const RunOptions& run_options = RunOptions(),
+           RunMetadata* run_metadata = nullptr) {
+    DCHECK_EQ(input_values.size(), inputs_.size());
+    ClientSession::FeedType feeds;
+    for (int i = 0; i < inputs_.size(); ++i) {
+      feeds.emplace(inputs_[i], input_values[i]);
+    }
+
+    std::vector<Operation> run_outputs;
+    std::vector<Tensor> out_tensors;
+    TF_ASSERT_OK(session.Run(run_options, feeds, grad_outputs_, run_outputs,
+                             &out_tensors, run_metadata));
+    ASSERT_EQ(out_tensors.size(), grad_outputs_.size());
+
+    DCHECK_EQ(expected_grad_values.size(), out_tensors.size());
+    for (int i = 0; i < out_tensors.size(); ++i) {
+      test::ExpectTensorEqual<T>(
+          out_tensors[i], test::AsTensor<T>({expected_grad_values[i]}, {}));
+    }
+  }
+
+  Scope scope_;
+  std::vector<Output> inputs_;
+  std::vector<Output> outputs_;
+  std::vector<Output> grad_outputs_;
+};
+
+TEST_F(WhileGradientsTest, Basic) {
+  // Create loop: while (i < 10) i += 1
+  Init(1);
+  CreateLoop(
+      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+        *output = ops::Less(s, inputs[0], 10);
+        return s.status();
+      },
+      [](const Scope& s, const std::vector<Output>& inputs,
+         std::vector<Output>* outputs) {
+        // Use AddN, rather than Add, because the gradient function doesn't
+        // depend on the input shapes, and thus we do not need to store
+        // intermediate values in a stack.
+        outputs->push_back(ops::AddN(s, {inputs[0], 1}));
+        return s.status();
+      });
+  CreateBackprop();
+
+  Run<int>({1}, {1});
+  Run<int>({11}, {1});
+}
+
+TEST_F(WhileGradientsTest, MultipleLoopVars) {
+  // Create loop: while (i < 10) i += j; j += 1; k = k
+  Init(3);
+  CreateLoop(
+      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+        *output = ops::Less(s, inputs[0], 10);
+        return s.status();
+      },
+      [](const Scope& s, const std::vector<Output>& inputs,
+         std::vector<Output>* outputs) {
+        outputs->push_back(ops::AddN(s, {inputs[0], inputs[1]}));
+        outputs->push_back(ops::AddN(s, {inputs[1], 1}));
+        outputs->push_back(inputs[2]);
+        return s.status();
+      });
+  CreateBackprop();
+
+  // The following execution traces illustrate why we expect dF/dj to be 5:
+  //
+  //  i  j  k
+  //  ---------
+  //  0  1  2 <-- initial values
+  //  1  2  2
+  //  3  3  2
+  //  6  4  2
+  // 10  5  2 <-- while output values
+  // outputs sum = 17
+  //
+  //  i  j  k
+  //  ---------
+  //  0  2  2 <-- initial values (add 1 to j)
+  //  2  3  2
+  //  5  4  2
+  //  9  5  2
+  // 14  6  2 <-- while output values
+  // outputs sum = 22
+  //
+  // Calculate the "slope" between j=1 and j=2:
+  //   22 - 17 = 5 => dF/dj = 5
+  Run<int>({0, 1, 2}, {1, 5, 1});
+
+  Run<int>({1, 1, 0}, {1, 5, 1});
+  Run<int>({0, 0, 0}, {1, 6, 1});
+}
+
+TEST_F(WhileGradientsTest, Chaining) {
+  Init(2, DT_DOUBLE);
+
+  // Multiply each input by 2 before passing to while loop to make sure
+  // chaining works properly
+  std::vector<Output> loop_inputs = {ops::Multiply(scope_, inputs_[0], 2.0),
+                                     ops::Multiply(scope_, inputs_[1], 2.0)};
+
+  // Create loop: while (i > 0 && j > 0) i -= 1
+  CreateLoop(
+      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+        *output = ops::LogicalAnd(s, ops::Greater(s, inputs[0], 0.0),
+                                  ops::Greater(s, inputs[1], 0.0));
+        return s.status();
+      },
+      [](const Scope& s, const std::vector<Output>& inputs,
+         std::vector<Output>* outputs) {
+        outputs->push_back(ops::AddN(s, {inputs[0], -1.0}));
+        outputs->push_back(inputs[1]);
+        return s.status();
+      },
+      &loop_inputs);
+
+  // Take negative of first output to make sure chaining works properly
+  outputs_[0] = ops::Neg(scope_, outputs_[0]);
+
+  CreateBackprop();
+
+  Run<double>({1.0, 1.0}, {-2.0, 2.0});
+  Run<double>({0.0, 0.0}, {-2.0, 2.0});
}
+
+TEST_F(WhileGradientsTest, MultipleDevices) {
+  // Make sure loop is created on cpu0
+  scope_ = scope_.WithDevice("/cpu:0");
+
+  // Create loop: while (i < 10) i += j
+  Init(2);
+  CreateLoop(
+      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+        *output = ops::Less(s, inputs[0], 10);
+        return s.status();
+      },
+      [](const Scope& s, const std::vector<Output>& inputs,
+         std::vector<Output>* outputs) {
+        // Place body on cpu1
+        Scope cpu1_scope = s.WithDevice("/cpu:1");
+        outputs->push_back(ops::AddN(cpu1_scope, {inputs[0], inputs[1]}));
+        outputs->push_back(inputs[1]);
+        return cpu1_scope.status();
+      });
+
+  // Build gradient graph on cpu1
+  Scope cpu1_scope = scope_.WithDevice("/cpu:1");
+  TF_ASSERT_OK(
+      AddSymbolicGradients(cpu1_scope, outputs_, inputs_, &grad_outputs_));
+  ASSERT_EQ(grad_outputs_.size(), inputs_.size());
+
+  // Run with two CPU devices and output partition graphs
+  SessionOptions session_options;
+  (*session_options.config.mutable_device_count())["CPU"] = 2;
+  RunOptions run_options;
+  run_options.set_output_partition_graphs(true);
+  RunMetadata run_metadata;
+  Run<int>(ClientSession(scope_, session_options), {0, 1}, {1, 11}, run_options,
+           &run_metadata);
+
+  // Check that at least one node ran on each device
+  ASSERT_EQ(run_metadata.partition_graphs().size(), 2);
+  for (const GraphDef& partition_graph : run_metadata.partition_graphs()) {
+    EXPECT_GE(partition_graph.node().size(), 1);
+  }
+}
+
+}  // namespace
+}  // namespace tensorflow
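The dF/dj = 5 expectation in MultipleLoopVars above follows from the two execution traces in its comment. They can be replayed with a standalone C++ snippet (plain arithmetic, not TensorFlow; F is the sum of the loop's output values, and the sequential updates below match the loop body because j's update does not read i):

#include <cstdio>

// Initial state (i, j, k); the loop body computes i += j; j += 1; k = k.
int F(int i, int j, int k) {
  while (i < 10) {
    i += j;
    j += 1;
  }
  return i + j + k;
}

int main() {
  std::printf("F(0, 1, 2) = %d\n", F(0, 1, 2));  // 17, first trace
  std::printf("F(0, 2, 2) = %d\n", F(0, 2, 2));  // 22, second trace
  // Slope between j=1 and j=2: (22 - 17) / 1 = 5, i.e. dF/dj = 5.
  return 0;
}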
#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/random_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -72,10 +73,13 @@ void Partition(const GraphDef& graph_def, GraphConstructorOptions opts; TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g)); - // Assigns devices to each node. Uses 1st letter of the node name as - // the device index. + // Assigns devices to each node. Uses 1st letter of the node name as the + // device index if no device is specified. for (Node* node : g.nodes()) { - node->set_assigned_device_name(DeviceName(node)); + string device_name = !node->requested_device().empty() + ? node->requested_device() + : DeviceName(node); + node->set_assigned_device_name(device_name); } PartitionOptions popts; @@ -368,7 +372,7 @@ TEST_F(GraphPartitionTest, CrossDevice_DataControl) { ExpectMatchB(); } -TEST_F(GraphPartitionTest, CrossDeviceLoop) { +TEST_F(GraphPartitionTest, CrossDeviceLoopSimple) { using namespace ::tensorflow::ops; // NOLINT(build/namespaces) auto a1 = BoolInput(in_.WithOpName("A1")); auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("A2"), a1, "foo"); @@ -382,7 +386,7 @@ TEST_F(GraphPartitionTest, CrossDeviceLoop) { CheckLoopConstruction(ToGraphDef()); } -TEST_F(GraphPartitionTest, CrossDeviceLoop1) { +TEST_F(GraphPartitionTest, CrossDeviceLoopSimple1) { using namespace ::tensorflow::ops; // NOLINT(build/namespaces) auto a1 = BoolInput(in_.WithOpName("A1")); auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("B2"), a1, "foo"); @@ -407,6 +411,29 @@ TEST_F(GraphPartitionTest, CrossDeviceLoop1) { } } +TEST_F(GraphPartitionTest, CrossDeviceLoopFull) { + Scope cpu0 = in_.WithDevice("/job:a/replica:0/task:0/cpu:0"); + auto p1 = ops::Placeholder(cpu0, DT_INT32); + auto p2 = ops::Placeholder(cpu0, DT_INT32); + OutputList outputs; + // while i1 < 10: i1 += i2 + TF_ASSERT_OK(ops::BuildWhileLoop( + cpu0, {p1, p2}, + [](const Scope& s, const std::vector<Output>& inputs, Output* output) { + *output = ops::Less(s, inputs[0], 10); + return s.status(); + }, + [](const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + Scope cpu1 = s.WithDevice("/job:a/replica:0/task:0/cpu:1"); + outputs->push_back(ops::AddN(cpu1, {inputs[0], inputs[1]})); + outputs->push_back(inputs[1]); + return s.status(); + }, + "test_loop", &outputs)); + CheckLoopConstruction(ToGraphDef()); +} + TEST_F(GraphPartitionTest, PartitionIncompleteGraph) { NodeDef ndef; Graph g(OpRegistry::Global()); |