aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-07-09 10:23:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 10:26:44 -0700
commit6732ec3dfffebd8e17250803e27f1a4c5dedefae (patch)
treeca09bd47de33cf095645e04727a406a8a8ddd2e7 /tensorflow/c
parent3b7edb3cedb6552ab77cc6c4e3e68387e12253ac (diff)
Skip calling back into python if only 1 gradient to aggregate
PiperOrigin-RevId: 203786896
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/eager/tape.h7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 734e712daa..1adb0458c3 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
}
} else {
any_gradient_nonzero = true;
- auto new_gradients = vspace.AggregateGradients(grad_it->second);
+ Gradient* new_gradients = nullptr;
+ if (grad_it->second.size() == 1) {
+ new_gradients = grad_it->second.at(0);
+ } else {
+ new_gradients = vspace.AggregateGradients(grad_it->second);
+ }
if (sources_set.find(grad_it->first) == sources_set.end()) {
gradients.erase(grad_it);
} else {