diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-08-30 10:19:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-30 10:27:38 -0700 |
commit | 862d753d4edac42b2af440ae999b61b80e94e000 (patch) | |
tree | 1aa2bb662eb19a1f509793939f17492df7236da3 /tensorflow/c | |
parent | ee89fccfd1db25563dadd0e3b4336612d7c52e0a (diff) |
Skip the zeros call in the backprop for SparseSoftmaxCrossEntropyWithLogits when it is not required
See https://github.com/tensorflow/tensorflow/blob/065f9b833ffbb3b2f03d63febb186275674ba133/tensorflow/python/ops/nn_grad.py#L482
Should help with #20218
PiperOrigin-RevId: 210933185
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/eager/tape.h | 17 |
1 file changed, 11 insertions, 6 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 1adb0458c3..ce038a4b57 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -440,6 +440,15 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace, return Status::OK(); } +gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() { + static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + return m; +} + } // namespace // If over kMinAggregateCount gradients are accumulated and the total @@ -485,10 +494,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( VLOG(1) << " " << t; } } - gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); while (!op_stack.empty()) { const int64 op = op_stack.back(); VLOG(1) << "Popped " << op; @@ -509,8 +514,8 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = - functions_accept_none_for_indices.find(trace.op_type); - if (func_name_it != functions_accept_none_for_indices.end() && + FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() && func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { |