aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/xent_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-21 17:31:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-21 17:52:15 -0800
commit4891c01b1cadf085a915a3eac5dd1b8d8cdee203 (patch)
tree87ec00e1927877ba26a2ffb69bc4f74f25c36f6a /tensorflow/core/kernels/xent_op.cc
parent123c2bb0af532d5fdaa05358158da33497d4bfe6 (diff)
Allow (safe) in-place computation in TensorFlow C++ ops. When at least one input tensor has the same size and type as the output, and the underlying buffer is owned by the op, i.e. when its refcount is 1 at the time the op's Compute method executes, the computation can be performed in place and allocation of the output buffer avoided.
I updated the following ops to perform in-place computation automatically when possible: * All standard coefficient-wise unary and binary operators (including with broadcasting) inheriting from base classes in kernels/cwise_ops_common.h. * unary and binary operators inheriting from base classes in framework/numeric_op.h. This is mostly old code for the Relu family and associated gradients. * All linear algebra ops inheriting from linalg_common. * Misc individual files/ops: softmax, select, bias, aggregate ops, batch_norm & fused_batch_norm, adjust_hue, constant, depthwise_conv_grad, fractional_avg_pool, misc. pooling ops, matrix_set_diag, xent & sparse_xent, unique_op. Change: 148166936
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r--tensorflow/core/kernels/xent_op.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index 639bad5f04..2a0ef63eab 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -61,9 +61,11 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
context->allocate_output(
0, TensorShape({logits_in.dim_size(0)}), &loss_out));
Tensor* back_out = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(1, logits_in.shape(), &back_out));
-
+ // Try to reuse the logits_in buffer for the backprop output.
+ if (!context->forward_input_to_output(0, 1, &back_out)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, logits_in.shape(), &back_out));
+ }
functor::XentFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),