author    Alexandre Passos <apassos@google.com>  2018-04-30 09:29:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-30 09:32:36 -0700
commit aa2405ee79dbcfabb8862ef3e1f8ca60e52159a0 (patch)
tree   842265cc624b9c2f201fb5bc75264c39ea92740a /tensorflow/c/eager
parent a5a51ad3a1200e2e5ef46c140bab717422e41ca2 (diff)
Fixes to the tape gradient for providing output gradients and for having multiple targets.
PiperOrigin-RevId: 194796304
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--  tensorflow/c/eager/tape.h | 65
1 file changed, 27 insertions(+), 38 deletions(-)
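
The behavioral change is in the removed branch of InitialGradients below: previously, a caller-supplied output gradient was rejected with InvalidArgument whenever the target tensor also appeared in tensor_usage_counts (i.e. fed further computation), which broke providing output gradients and differentiating with respect to multiple targets that feed one another. After the change, each target is simply seeded from the supplied gradient if one is given, and from a ones-like gradient otherwise. The following standalone sketch is not the tape.h API; SeedGradient, OpEntry, and the string stand-in for Gradient* are hypothetical names used only to illustrate the per-target decision after this commit.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the tape's bookkeeping; the real structures
// live in tensorflow/c/eager/tape.h and carry shapes, dtypes, and ops.
using TensorId = int64_t;
using OpId = int64_t;

struct OutputInfo { TensorId id; };                  // shape/dtype elided
struct OpEntry { std::vector<OutputInfo> outputs; };

// Per-target seeding after the change:
//  - if the caller supplied an output gradient, use it (no usage-count check);
//  - else, if the tape recorded the op producing the tensor, seed with ones;
//  - else, the target never appeared on the tape: seed nothing.
std::optional<std::string> SeedGradient(
    TensorId target,
    const std::map<TensorId, OpId>& tensor_tape,
    const std::map<OpId, OpEntry>& op_tape,
    const std::map<TensorId, std::string>& provided) {
  auto provided_it = provided.find(target);
  if (provided_it != provided.end()) return provided_it->second;
  auto tensor_it = tensor_tape.find(target);
  if (tensor_it == tensor_tape.end() || tensor_it->second == -1) {
    return std::nullopt;  // not on the tape: nothing to seed
  }
  auto op_it = op_tape.find(tensor_it->second);
  if (op_it == op_tape.end()) return std::nullopt;  // internal error in the real code
  for (const OutputInfo& out : op_it->second.outputs) {
    if (out.id == target) return std::string("ones_like");
  }
  return std::nullopt;  // "none of operations outputs match" in the real code
}

int main() {
  std::map<TensorId, OpId> tensor_tape = {{7, 1}};
  std::map<OpId, OpEntry> op_tape = {{1, OpEntry{{{7}}}}};
  // No provided gradient for target 7 -> seeded with ones.
  std::cout << SeedGradient(7, tensor_tape, op_tape, {}).value_or("none") << "\n";
  // Provided gradient for target 7 -> used directly, even if the tensor also
  // feeds later computation (the old tensor_usage_counts check rejected this).
  std::cout << SeedGradient(7, tensor_tape, op_tape, {{7, "dy"}}).value_or("none")
            << "\n";
}
```
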
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 97c323b872..8026076b9e 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -380,49 +380,39 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<Gradient*> output_gradients,
const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
- const gtl::FlatMap<int64, int64>& tensor_usage_counts,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
- if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
- if (!output_gradients.empty() && output_gradients[i] != nullptr) {
- // TODO(apassos) figure out how to print debugging information here.
- return errors::InvalidArgument(
- "A gradient was provided for a tensor which is used as part of the "
- "computation.");
- }
- } else {
- if (output_gradients.empty() || output_gradients[i] == nullptr) {
- auto tensor_it = tensor_tape.find(id);
- if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
- auto op_it = op_tape.find(tensor_it->second);
- if (op_it == op_tape.end()) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "failed to find operation producing a tensor");
- }
- bool found = false;
- for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
- found = true;
- (*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
- break;
- }
- }
- if (!found) {
- return errors::Internal(
- "Internal state of the gradient tape is invalid: "
- "none of operations outputs match expected tensor");
+ if (output_gradients.empty() || output_gradients[i] == nullptr) {
+ auto tensor_it = tensor_tape.find(id);
+ if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+ auto op_it = op_tape.find(tensor_it->second);
+ if (op_it == op_tape.end()) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "failed to find operation producing a tensor");
+ }
+ bool found = false;
+ for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+ if (op_it->second.output_tensor_info[j].id == id) {
+ found = true;
+ (*result)[id].push_back(
+ vspace.Ones(op_it->second.output_tensor_info[j].shape,
+ op_it->second.output_tensor_info[j].dtype));
+ break;
}
- } else {
- // No record of the target tensor found on the tape, so no gradient
- // needs to be computed from it. Do nothing.
+ }
+ if (!found) {
+ return errors::Internal(
+ "Internal state of the gradient tape is invalid: "
+ "none of operations outputs match expected tensor");
}
} else {
- (*result)[id].push_back(output_gradients[i]);
+ // No record of the target tensor found on the tape, so no gradient
+ // needs to be computed from it. Do nothing.
}
+ } else {
+ (*result)[id].push_back(output_gradients[i]);
}
}
return Status::OK();
@@ -451,8 +441,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
- tensor_tape_, state.op_tape,
- state.tensor_usage_counts, &gradients);
+ tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() {
if (!persistent_) {
// Release all backprop functions