diff options
author | Alexandre Passos <apassos@google.com> | 2018-05-21 16:37:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-21 16:39:58 -0700 |
commit | 753cc5b3f7461b0b3f59605cba10b965aca0e3ad (patch) | |
tree | 33b6ab1f8cd9c6e8b9d412f88d405c5ae2fa561c /tensorflow/c/eager | |
parent | 433bb8e1fa0b300961a430c2652ad0776dcef187 (diff) |
Fixes issue with gradient tape when asking for the gradient of an intermediate tensor.
PiperOrigin-RevId: 197481473
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/tape.h | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index dcc2357b71..1833b25fea 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -104,6 +104,10 @@ class VSpace { gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result) const = 0; + // Marks the following gradient as a result so it's not consumed by backward + // functions. + virtual void MarkAsResult(Gradient* gradient) const = 0; + // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; @@ -356,8 +360,7 @@ BackpropInitialState<BackwardFunction> PrepareBackprop( count_it->second++; } else { result.tensor_usage_counts[it] = 1; - if (sources_set.find(it) == sources_set.end() && - tensor_tape.find(it) != tensor_tape.end()) { + if (tensor_tape.find(it) != tensor_tape.end()) { tensor_stack.push_back(it); } } @@ -522,10 +525,15 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( } } else { any_gradient_nonzero = true; - out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + auto new_gradients = vspace.AggregateGradients(grad_it->second); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); + } else { + grad_it->second.clear(); + grad_it->second.push_back(new_gradients); + vspace.MarkAsResult(new_gradients); } + out_gradients.push_back(new_gradients); } } std::vector<Gradient*> in_gradients; |