aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2017-10-04 00:07:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 00:11:41 -0700
commitd016cb020583b1ecbc260c1492e347c2731b1c29 (patch)
treecffe8e044b948ef3a5c0dde1c752b550d4e193de /tensorflow/cc/framework
parentf9f037c1c489d6a72ef682e3bce01e6f154222e4 (diff)
Fix c++ gradients issue where multiple dependent outputs result in incorrect answer.
The issue is that we incorrectly calculate the pending num_expected_backprops for outputs nodes when one output transitively depends on another. this is because we use output nodes as an indicator of when we need to end our traversal. Instead we should only use output nodes that don't transitively get consumed by other output nodes as end indicators for our traversal. This change implements that fix. Fixes #13190 PiperOrigin-RevId: 170971937
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r--tensorflow/cc/framework/gradients.cc90
-rw-r--r--tensorflow/cc/framework/gradients_test.cc40
2 files changed, 118 insertions, 12 deletions
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index 0ec5b9a1bd..affd90b1bc 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -91,6 +91,13 @@ class SymbolicGradientBuilder {
// `summed_grads` is the sum of `exit_node`s gradients.
Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads);
+ // Gets the set of node ids at which to stop backprop. These are all elements
+ // of `outputs_` that do not get transitively consumed by other `outputs_`.
+ // Used to identify nodes at which to stop backprop.
+ std::unordered_set<int> GetStopBackpropNodes(
+ const std::vector<bool>& reachable_nodes,
+ std::unordered_set<int> output_nodes);
+
const Scope& scope_;
const ops::GradOpRegistry* registry_;
const std::vector<Output>& outputs_;
@@ -117,10 +124,6 @@ class SymbolicGradientBuilder {
// gradients from `grad_inputs_`.
std::deque<Node*> ready_;
- // The set of node ids in `outputs_`. Used to identify nodes at which to stop
- // backprop.
- std::unordered_set<int> output_nodes_;
-
// The set of node ids in `inputs_`. Used to identify nodes at backprop
// frontier. Maps from Output -> index into `grad_outputs_`.
std::unordered_map<Output, int, OutputHash, OutputEq> input_nodes_;
@@ -186,6 +189,63 @@ std::vector<bool> SymbolicGradientBuilder::GetReachableNodes() {
return reachable_nodes;
}
+std::unordered_set<int> SymbolicGradientBuilder::GetStopBackpropNodes(
+ const std::vector<bool>& reachable_nodes,
+ std::unordered_set<int> output_nodes) {
+ // Output nodes that get transitively consumed by other `outputs_` are stored
+ // in `internal_outputs`.
+ std::unordered_set<int> internal_outputs;
+ std::unordered_set<Node*> visited;
+ // Initialize `queue` for BFS traversal. Nodes in `queue` hold upcoming nodes
+ // along with the last Node in `output_` encountered along that path. If no
+ // `output_` node was encountered, pair.second will be nullptr.
+ std::deque<std::pair<Node*, Node*>> queue;
+ for (const Output& nout : inputs_) {
+ if (visited.find(nout.node()) == visited.end()) {
+ queue.push_back(std::make_pair(nout.node(), static_cast<Node*>(nullptr)));
+ visited.insert(nout.node());
+ }
+ }
+ // BFS from nodes in 'inputs_' along out edges for the entire graph. Internal
+ // output nodes are recorded during the traversal. All nodes that are output
+ // nodes but not internal output nodes are considered the frontier of the
+ // output nodes, and thus our stop backprop nodes.
+ while (!queue.empty()) {
+ std::pair<Node*, Node*> p = queue.front();
+ Node* n = p.first;
+ queue.pop_front();
+ for (const Edge* e : n->out_edges()) {
+ // If a node is not reachable from outputs_, we can stop.
+ if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue;
+ if (visited.find(e->dst()) != visited.end()) continue;
+
+ int node_id = e->dst()->id();
+ Node* last_output_node = p.second;
+ if (output_nodes.find(node_id) != output_nodes.end()) {
+ // We reached an output node.
+ if (last_output_node != nullptr) {
+ // If we had already found an output node on this path so we mark
+ // it as an internal output.
+ internal_outputs.insert(last_output_node->id());
+ }
+ // Mark this newly found output node to insert in the queue.
+ last_output_node = e->dst();
+ }
+ queue.push_back(std::make_pair(e->dst(), last_output_node));
+ visited.insert(e->dst());
+ }
+ }
+ // Finally, we set stop_backprop_nodes to all output_nodes that aren't also
+ // internal_outputs.
+ std::unordered_set<int> stop_backprop_nodes;
+ for (int output_node : output_nodes) {
+ if (internal_outputs.find(output_node) == internal_outputs.end()) {
+ stop_backprop_nodes.insert(output_node);
+ }
+ }
+ return stop_backprop_nodes;
+}
+
Status SymbolicGradientBuilder::Initialize() {
if (outputs_.size() != grad_inputs_.size()) {
return errors::InvalidArgument(
@@ -202,11 +262,16 @@ Status SymbolicGradientBuilder::Initialize() {
}
grad_outputs_->clear();
grad_outputs_->resize(inputs_.size());
- // Populate `output_nodes_` from node ids in `outputs_`.
- output_nodes_.reserve(outputs_.size());
+
+ std::unordered_set<int> output_nodes;
+ output_nodes.reserve(outputs_.size());
for (size_t i = 0; i < outputs_.size(); ++i) {
- output_nodes_.insert(outputs_[i].node()->id());
+ output_nodes.insert(outputs_[i].node()->id());
}
+
+ std::unordered_set<int> stop_backprop_nodes =
+ GetStopBackpropNodes(reachable_nodes, output_nodes);
+
// Populate `input_nodes_` from Outputs in `inputs_`.
input_nodes_.reserve(inputs_.size());
for (size_t i = 0; i < inputs_.size(); ++i) {
@@ -237,7 +302,7 @@ Status SymbolicGradientBuilder::Initialize() {
backprops_[{n, i}].clear();
}
int num_expected_backprops = 0;
- if (output_nodes_.find(n->id()) == output_nodes_.end()) {
+ if (stop_backprop_nodes.find(n->id()) == stop_backprop_nodes.end()) {
// Internal node: continue BFS along connected outputs.
for (const Edge* e : n->out_edges()) {
// If a node is not reachable from outputs_,
@@ -250,9 +315,10 @@ Status SymbolicGradientBuilder::Initialize() {
}
++num_expected_backprops;
}
- } else {
- // Output node: stop BFS and update `num_expected_backprops` for
- // each Output in `outputs_` that references `n`.
+ }
+ if (output_nodes.find(n->id()) != output_nodes.end()) {
+ // Output node: update `num_expected_backprops` for each Output in
+ // `outputs_` that references `n`.
for (const Output& output : outputs_) {
if (output.node() == n) {
++num_expected_backprops;
@@ -323,7 +389,7 @@ Status SymbolicGradientBuilder::CallGradFunction(
Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node,
const Output& summed_grads) {
- // TOOD(skyewm): detect second-order gradient and return bad status
+ // TODO(skyewm): detect second-order gradient and return bad status
// TODO(skyewm): handle (or at least detect) nested while loops
// TODO(skyewm): handle NoGradient in while loop
diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc
index dcaf10c340..07a062e704 100644
--- a/tensorflow/cc/framework/gradients_test.cc
+++ b/tensorflow/cc/framework/gradients_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
@@ -453,6 +454,45 @@ TEST_F(GradientsTest, UnreachableInput) {
" for node 'z' as it's unreachable from the output node(s).");
}
+TEST_F(GradientsTest, DependentOutputs) {
+ auto x = Placeholder(scope_test_, DT_FLOAT);
+ auto y0 = Square(scope_test_, x);
+ auto y1 = Square(scope_test_, y0);
+ auto y2 = Square(scope_test_, y1);
+ // Requesting the gradients for y0 and y2 should return the sum of their
+ // individual gradients.
+ std::vector<Output> grad_outputs;
+ TF_EXPECT_OK(AddSymbolicGradients(scope_test_, {y0, y2}, {x}, &grad_outputs));
+ ClientSession session(scope_test_);
+ std::vector<Tensor> grad_result;
+ TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result));
+ EXPECT_EQ(grad_result.size(), 1);
+ EXPECT_EQ(grad_result[0].NumElements(), 1);
+ EXPECT_EQ(grad_result[0].flat<float>()(0), 17502.0f);
+}
+
+TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) {
+ auto x = Placeholder(scope_test_, DT_FLOAT);
+ auto y0 = Square(scope_test_, x);
+ // y1, y2, and y3 all use y0. This means the backwards pass will need to wait
+ // for the gradient for all three.
+ auto y1 = Square(scope_test_, y0);
+ auto y2 = Square(scope_test_, y0);
+ auto y3 = Square(scope_test_, y2);
+ std::vector<Output> grad_outputs;
+ // By requesting y0, y1, and y3 we test that the computation correctly waits
+ // for all the points in backprop where gradients need to be summed from
+ // multiple branches.
+ TF_EXPECT_OK(
+ AddSymbolicGradients(scope_test_, {y0, y1, y3}, {x}, &grad_outputs));
+ ClientSession session(scope_test_);
+ std::vector<Tensor> grad_result;
+ TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result));
+ EXPECT_EQ(grad_result.size(), 1);
+ EXPECT_EQ(grad_result[0].NumElements(), 1);
+ EXPECT_EQ(grad_result[0].flat<float>()(0), 17610.0f);
+}
+
// StopGradientSingleOutputMultiEdgeTest tests combinations of valid and
// 'NoGradient' (induced by StopGradient op) returned along multiple edges from
// a single nodes output.