diff options
-rw-r--r-- | tensorflow/c/eager/tape.h | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 734e712daa..1adb0458c3 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient( } } else { any_gradient_nonzero = true; - auto new_gradients = vspace.AggregateGradients(grad_it->second); + Gradient* new_gradients = nullptr; + if (grad_it->second.size() == 1) { + new_gradients = grad_it->second.at(0); + } else { + new_gradients = vspace.AggregateGradients(grad_it->second); + } if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); } else { |