diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-07-09 10:23:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 10:26:44 -0700 |
commit | 6732ec3dfffebd8e17250803e27f1a4c5dedefae (patch) | |
tree | ca09bd47de33cf095645e04727a406a8a8ddd2e7 /tensorflow/c | |
parent | 3b7edb3cedb6552ab77cc6c4e3e68387e12253ac (diff) |
Skip calling back into python if only 1 gradient to aggregate
PiperOrigin-RevId: 203786896
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/eager/tape.h | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 734e712daa..1adb0458c3 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( } } else { any_gradient_nonzero = true; - auto new_gradients = vspace.AggregateGradients(grad_it->second); + Gradient* new_gradients = nullptr; + if (grad_it->second.size() == 1) { + new_gradients = grad_it->second.at(0); + } else { + new_gradients = vspace.AggregateGradients(grad_it->second); + } if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); } else { |