aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/xent_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r--tensorflow/core/kernels/xent_op.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index 639bad5f04..2a0ef63eab 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -61,9 +61,11 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
context->allocate_output(
0, TensorShape({logits_in.dim_size(0)}), &loss_out));
Tensor* back_out = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(1, logits_in.shape(), &back_out));
-
+ // Try to reuse the logits_in buffer for the backprop output.
+ if (!context->forward_input_to_output(0, 1, &back_out)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, logits_in.shape(), &back_out));
+ }
functor::XentFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),