| author | Igor Ganichev <iga@google.com> | 2018-03-28 20:51:01 -0700 |
| --- | --- | --- |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-28 20:53:48 -0700 |
| commit | 3f7adc710495e1160acd956c482779247ef1f101 (patch) | |
| tree | b050b53db22e8f3f382a22c13ebe558f832fd168 /tensorflow/c/eager | |
| parent | bb582f1b6fad474bc446c78a6683247a8eb6048e (diff) | |
Support structured source in GradientTape.gradient
Before this change, it was easy to forget the [] around the source tensor.
This mistake led to GradientTape.gradient() returning a list of Nones.
A None normally tells the user that the source and the target are
not connected via differentiable operations, which is not the cause
of the error in this case.
Instead of adding a check that `sources` is a list of tensors, this CL
adds the ability to handle a structured source (including a lone tensor),
similarly to many existing TensorFlow APIs.
Also, with Alex's help, it fixes a bug where repeated tensors in
`sources` were not handled correctly.
PiperOrigin-RevId: 190878583
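The structured-source handling described above can be illustrated with the usual flatten/repack pattern (as in tf.nest): flatten `sources` to a list of leaves, compute one gradient per leaf, then pack the results back into the shape of `sources`. The sketch below is illustrative only, not TensorFlow's actual implementation; `gradient` and its `grads_per_leaf` lookup table are hypothetical stand-ins.

```python
def flatten(structure):
    """Flattens nested lists/tuples into a flat list of leaves."""
    if isinstance(structure, (list, tuple)):
        flat = []
        for item in structure:
            flat.extend(flatten(item))
        return flat
    return [structure]  # a lone (non-sequence) source is a single leaf

def pack_as(structure, flat):
    """Repacks a flat list of results into the shape of `structure`."""
    def _pack(struct, it):
        if isinstance(struct, (list, tuple)):
            return type(struct)(_pack(s, it) for s in struct)
        return next(it)
    return _pack(structure, iter(flat))

def gradient(grads_per_leaf, sources):
    """Hypothetical wrapper: one gradient per flattened leaf,
    returned in the same structure as `sources`."""
    leaves = flatten(sources)
    flat_grads = [grads_per_leaf.get(id(leaf)) for leaf in leaves]
    return pack_as(sources, flat_grads)
```

With this shape, `gradient(..., x)` on a lone tensor returns a lone gradient rather than a list of Nones, and `gradient(..., [x, [y]])` returns a matching nested list.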
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/tape.h | 21 |
1 file changed, 13 insertions, 8 deletions
```diff
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index c7bd3bdafd..97c323b872 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -601,23 +601,28 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
   }
   CHECK(state.op_tape.empty());
   result->reserve(source_tensor_ids.size());
+  gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
   for (auto is : source_tensor_ids) {
     auto grad_it = gradients.find(is);
     if (grad_it == gradients.end()) {
       result->push_back(nullptr);
     } else {
-      if (grad_it->second.size() == 1) {
-        result->push_back(grad_it->second[0]);
-      } else {
-        result->push_back(vspace.AggregateGradients(grad_it->second));
+      if (grad_it->second.size() > 1) {
+        Gradient* grad = vspace.AggregateGradients(grad_it->second);
+        grad_it->second.clear();
+        grad_it->second.push_back(grad);
       }
-      gradients.erase(grad_it);
+      result->push_back(grad_it->second[0]);
+      used_gradient_ids.insert(is);
     }
   }
-  VLOG(1) << "Final gradients size: " << gradients.size();
+  VLOG(1) << "Final gradients size: "
+          << gradients.size() - used_gradient_ids.size();
   for (auto grad_pair : gradients) {
-    for (const auto& g : grad_pair.second) {
-      vspace.DeleteGradient(g);
+    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
+      for (const auto& g : grad_pair.second) {
+        vspace.DeleteGradient(g);
+      }
     }
   }
   return Status::OK();
```
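The repeated-sources fix in the hunk above has three parts: aggregate each gradient list at most once (in place, so repeats reuse the same value), record which tensor ids were handed out, and delete only the gradients that were never returned. A pure-Python model of that logic (names like `compute_result` are illustrative, not from the C++ code):

```python
def compute_result(gradients, source_ids, aggregate, delete):
    """Models the tail of ComputeGradient after the fix.

    gradients:  dict mapping tensor id -> list of partial gradients
    source_ids: requested sources; may contain repeated ids
    aggregate:  combines a list of partial gradients into one
    delete:     frees a gradient that is not returned to the caller
    """
    result = []
    used = set()
    for tid in source_ids:
        parts = gradients.get(tid)
        if parts is None:
            result.append(None)  # source not connected to the target
            continue
        if len(parts) > 1:
            # Aggregate once, in place, so a repeated id reuses the
            # same aggregated value instead of re-aggregating.
            agg = aggregate(parts)
            parts.clear()
            parts.append(agg)
        result.append(parts[0])
        used.add(tid)
    # Free only gradients that were never handed to the caller; before
    # the fix, erasing on first use made repeated ids come back as None.
    for tid, parts in gradients.items():
        if tid not in used:
            for g in parts:
                delete(g)
    return result
```

For example, requesting ids `[1, 1, 9]` against `{1: [3, 4], 2: [5]}` with `aggregate=sum` yields the same aggregated gradient for both occurrences of 1, None for the unknown id 9, and deletes only the unused gradient for id 2.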