diff options
author | Alexandre Passos <apassos@google.com> | 2018-04-30 09:29:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-30 09:32:36 -0700 |
commit | aa2405ee79dbcfabb8862ef3e1f8ca60e52159a0 (patch) | |
tree | 842265cc624b9c2f201fb5bc75264c39ea92740a /tensorflow/c/eager | |
parent | a5a51ad3a1200e2e5ef46c140bab717422e41ca2 (diff) |
Fixes to tape gradient for providing outputs and having multiple targets.
PiperOrigin-RevId: 194796304
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/tape.h | 65 |
1 file changed, 27 insertions, 38 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 97c323b872..8026076b9e 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -380,49 +380,39 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace, gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape, const OpTape<BackwardFunction>& op_tape, - const gtl::FlatMap<int64, int64>& tensor_usage_counts, gtl::FlatMap<int64, std::vector<Gradient*>>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; - if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { - if (!output_gradients.empty() && output_gradients[i] != nullptr) { - // TODO(apassos) figure out how to print debugging information here. - return errors::InvalidArgument( - "A gradient was provided for a tensor which is used as part of the " - "computation."); - } - } else { - if (output_gradients.empty() || output_gradients[i] == nullptr) { - auto tensor_it = tensor_tape.find(id); - if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { - auto op_it = op_tape.find(tensor_it->second); - if (op_it == op_tape.end()) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "failed to find operation producing a tensor"); - } - bool found = false; - for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { - found = true; - (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); - break; - } - } - if (!found) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "none of operations outputs match expected tensor"); + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = 
op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].id == id) { + found = true; + (*result)[id].push_back( + vspace.Ones(op_it->second.output_tensor_info[j].shape, + op_it->second.output_tensor_info[j].dtype)); + break; } - } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); } } else { - (*result)[id].push_back(output_gradients[i]); + // No record of the target tensor found on the tape, so no gradient + // needs to be computed from it. Do nothing. } + } else { + (*result)[id].push_back(output_gradients[i]); } } return Status::OK(); @@ -451,8 +441,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap<int64, std::vector<Gradient*>> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, - tensor_tape_, state.op_tape, - state.tensor_usage_counts, &gradients); + tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { // Release all backprop functions |