From 862d753d4edac42b2af440ae999b61b80e94e000 Mon Sep 17 00:00:00 2001
From: Akshay Modi
Date: Thu, 30 Aug 2018 10:19:52 -0700
Subject: Skip zeros call if unrequired in backprop for SparseSoftmaxCrossEntropyWithLogits

See https://github.com/tensorflow/tensorflow/blob/065f9b833ffbb3b2f03d63febb186275674ba133/tensorflow/python/ops/nn_grad.py#L482

Should help with #20218

PiperOrigin-RevId: 210933185
---
 tensorflow/c/eager/tape.h | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 1adb0458c3..ce038a4b57 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -440,6 +440,15 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
   return Status::OK();
 }
 
+gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
+  static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
+      {"SoftmaxCrossEntropyWithLogits", {1}},
+      {"SparseSoftmaxCrossEntropyWithLogits", {1}},
+      {"FusedBatchNorm", {1, 2, 3, 4}},
+  });
+  return m;
+}
+
 }  // namespace
 
 // If over kMinAggregateCount gradients are accumulated and the total
@@ -485,10 +494,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
       VLOG(1) << "  " << t;
     }
   }
-  gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
-      {"SoftmaxCrossEntropyWithLogits", {1}},
-      {"FusedBatchNorm", {1, 2, 3, 4}},
-  });
   while (!op_stack.empty()) {
     const int64 op = op_stack.back();
     VLOG(1) << "Popped " << op;
@@ -509,8 +514,8 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
         auto grad_it = gradients.find(id);
         if (grad_it == gradients.end()) {
           auto func_name_it =
-              functions_accept_none_for_indices.find(trace.op_type);
-          if (func_name_it != functions_accept_none_for_indices.end() &&
+              FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
+          if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
               func_name_it->second.find(i) != func_name_it->second.end()) {
             out_gradients.push_back(nullptr);
           } else {
-- 
cgit v1.2.3
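
For context, the change does two things: it hoists the op-name to output-index table into a function-local static so ComputeGradient no longer rebuilds it on every backward pass, and it adds SparseSoftmaxCrossEntropyWithLogits so the tape can pass a null gradient (None on the Python side) for that op's second output instead of materializing a zeros tensor, which is the allocation the commit message wants to skip. Below is a minimal sketch of that pattern, not the actual tape.h code: std::unordered_map and std::unordered_set stand in for gtl::FlatMap and gtl::FlatSet, and the helper names are made up for illustration.

#include <string>
#include <unordered_map>
#include <unordered_set>

// Built once on first use and intentionally never destroyed, so the backward
// pass does not reconstruct the table each time it runs -- this is the
// function-local-static pattern the patch introduces.
// (Sketch only: real code uses gtl::FlatMap<string, gtl::FlatSet<int>>.)
const std::unordered_map<std::string, std::unordered_set<int>>&
OutputsThatMayReceiveNone() {
  static const auto* table =
      new std::unordered_map<std::string, std::unordered_set<int>>{
          {"SoftmaxCrossEntropyWithLogits", {1}},
          {"SparseSoftmaxCrossEntropyWithLogits", {1}},
          {"FusedBatchNorm", {1, 2, 3, 4}},
      };
  return *table;
}

// True if output `i` of `op_type` may be given a null/None gradient, in which
// case the caller can push nullptr instead of allocating a zeros tensor.
bool AcceptsNoneForIndex(const std::string& op_type, int i) {
  const auto& table = OutputsThatMayReceiveNone();
  auto it = table.find(op_type);
  return it != table.end() && it->second.count(i) > 0;
}

int main() {
  // The registered gradient for SparseSoftmaxCrossEntropyWithLogits ignores
  // the incoming gradient for output 1 (the fused backprop output), so the
  // tape can hand it None and skip the zeros allocation.
  return AcceptsNoneForIndex("SparseSoftmaxCrossEntropyWithLogits", 1) ? 0 : 1;
}

Allocating the table with new and never deleting it is the usual way to avoid static-destruction-order problems at process shutdown; the trade-off is a one-time allocation that lives for the lifetime of the program, which is cheap compared to rebuilding the map on every gradient computation.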