diff options
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/eager/tape.h | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 1adb0458c3..ce038a4b57 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -440,6 +440,15 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace, return Status::OK(); } +gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() { + static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + return m; +} + } // namespace // If over kMinAggregateCount gradients are accumulated and the total @@ -485,10 +494,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( VLOG(1) << " " << t; } } - gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); while (!op_stack.empty()) { const int64 op = op_stack.back(); VLOG(1) << "Popped " << op; @@ -509,8 +514,8 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = - functions_accept_none_for_indices.find(trace.op_type); - if (func_name_it != functions_accept_none_for_indices.end() && + FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() && func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { |