aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-03-28 20:51:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-28 20:53:48 -0700
commit3f7adc710495e1160acd956c482779247ef1f101 (patch)
treeb050b53db22e8f3f382a22c13ebe558f832fd168 /tensorflow/c/eager
parentbb582f1b6fad474bc446c78a6683247a8eb6048e (diff)
Support structured source in GradientTape.gradient
Before this change, it was easy to forget [] around the source tensor. This mistake led to GradientTape.gradient() returning a list of Nones. Nones normally tell the user that the source and the target are not connected via differentiable operations, which is not the source of the error in this case. Instead of adding a check that `sources` is a list of tensors, this CL adds the ability to handle a structured source (which includes a lone tensor), similarly to many existing TensorFlow APIs. Also, with Alex's help, it fixes a bug where repeated tensors in `sources` were not handled correctly. PiperOrigin-RevId: 190878583
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/tape.h21
1 file changed, 13 insertions, 8 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index c7bd3bdafd..97c323b872 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -601,23 +601,28 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
}
CHECK(state.op_tape.empty());
result->reserve(source_tensor_ids.size());
+ gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
for (auto is : source_tensor_ids) {
auto grad_it = gradients.find(is);
if (grad_it == gradients.end()) {
result->push_back(nullptr);
} else {
- if (grad_it->second.size() == 1) {
- result->push_back(grad_it->second[0]);
- } else {
- result->push_back(vspace.AggregateGradients(grad_it->second));
+ if (grad_it->second.size() > 1) {
+ Gradient* grad = vspace.AggregateGradients(grad_it->second);
+ grad_it->second.clear();
+ grad_it->second.push_back(grad);
}
- gradients.erase(grad_it);
+ result->push_back(grad_it->second[0]);
+ used_gradient_ids.insert(is);
}
}
- VLOG(1) << "Final gradients size: " << gradients.size();
+ VLOG(1) << "Final gradients size: "
+ << gradients.size() - used_gradient_ids.size();
for (auto grad_pair : gradients) {
- for (const auto& g : grad_pair.second) {
- vspace.DeleteGradient(g);
+ if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
+ for (const auto& g : grad_pair.second) {
+ vspace.DeleteGradient(g);
+ }
}
}
return Status::OK();