aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-08-30 10:19:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 10:27:38 -0700
commit862d753d4edac42b2af440ae999b61b80e94e000 (patch)
tree1aa2bb662eb19a1f509793939f17492df7236da3 /tensorflow/c
parentee89fccfd1db25563dadd0e3b4336612d7c52e0a (diff)
Skip zeros call if unrequired in backprop for SparseSoftmaxCrossEntropyWithLogits
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/eager/tape.h17
1 files changed, 11 insertions, 6 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 1adb0458c3..ce038a4b57 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -440,6 +440,15 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
return Status::OK();
}
+gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
+ static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
+ {"SoftmaxCrossEntropyWithLogits", {1}},
+ {"SparseSoftmaxCrossEntropyWithLogits", {1}},
+ {"FusedBatchNorm", {1, 2, 3, 4}},
+ });
+ return m;
+}
+
} // namespace
// If over kMinAggregateCount gradients are accumulated and the total
@@ -485,10 +494,6 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
VLOG(1) << " " << t;
}
}
- gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
- {"SoftmaxCrossEntropyWithLogits", {1}},
- {"FusedBatchNorm", {1, 2, 3, 4}},
- });
while (!op_stack.empty()) {
const int64 op = op_stack.back();
VLOG(1) << "Popped " << op;
@@ -509,8 +514,8 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
- functions_accept_none_for_indices.find(trace.op_type);
- if (func_name_it != functions_accept_none_for_indices.end() &&
+ FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
+ if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {