about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-05-21 16:37:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-21 16:39:58 -0700
commit753cc5b3f7461b0b3f59605cba10b965aca0e3ad (patch)
tree33b6ab1f8cd9c6e8b9d412f88d405c5ae2fa561c /tensorflow/c/eager
parent433bb8e1fa0b300961a430c2652ad0776dcef187 (diff)
Fixes issue with gradient tape when asking for the gradient of an intermediate tensor.
PiperOrigin-RevId: 197481473
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- tensorflow/c/eager/tape.h | 14
1 file changed, 11 insertions, 3 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index dcc2357b71..1833b25fea 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -104,6 +104,10 @@ class VSpace {
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) const = 0;
+ // Marks the following gradient as a result so it's not consumed by backward
+ // functions.
+ virtual void MarkAsResult(Gradient* gradient) const = 0;
+
// Deletes the input tensor.
virtual void DeleteGradient(Gradient* gradient) const = 0;
@@ -356,8 +360,7 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
count_it->second++;
} else {
result.tensor_usage_counts[it] = 1;
- if (sources_set.find(it) == sources_set.end() &&
- tensor_tape.find(it) != tensor_tape.end()) {
+ if (tensor_tape.find(it) != tensor_tape.end()) {
tensor_stack.push_back(it);
}
}
@@ -522,10 +525,15 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
}
} else {
any_gradient_nonzero = true;
- out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
+ auto new_gradients = vspace.AggregateGradients(grad_it->second);
if (sources_set.find(grad_it->first) == sources_set.end()) {
gradients.erase(grad_it);
+ } else {
+ grad_it->second.clear();
+ grad_it->second.push_back(new_gradients);
+ vspace.MarkAsResult(new_gradients);
}
+ out_gradients.push_back(new_gradients);
}
}
std::vector<Gradient*> in_gradients;