diff options
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r-- | tensorflow/core/kernels/xent_op.cc | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index 639bad5f04..2a0ef63eab 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -61,9 +61,11 @@ class SoftmaxXentWithLogitsOp : public OpKernel { context->allocate_output( 0, TensorShape({logits_in.dim_size(0)}), &loss_out)); Tensor* back_out = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(1, logits_in.shape(), &back_out)); - + // Try to reuse the logits_in buffer for the backprop output. + if (!context->forward_input_to_output(0, 1, &back_out)) { + OP_REQUIRES_OK(context, + context->allocate_output(1, logits_in.shape(), &back_out)); + } functor::XentFunctor<Device, T> functor; functor(context->eigen_device<Device>(), logits_in.matrix<T>(), labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(), |