author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-02-21 17:31:57 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-02-21 17:52:15 -0800
commit    4891c01b1cadf085a915a3eac5dd1b8d8cdee203 (patch)
tree      87ec00e1927877ba26a2ffb69bc4f74f25c36f6a /tensorflow
parent    123c2bb0af532d5fdaa05358158da33497d4bfe6 (diff)
Allow (safe) in-place computation in TensorFlow C++ ops. When at least one input tensor has the same size and type as the output, and the underlying buffer is owned by the op (i.e., its refcount is 1 at the time the op's Compute method executes), the computation can be performed in place, avoiding allocation of the output buffer.
I updated the following ops to perform in-place computation automatically when possible:
* All standard coefficient-wise unary and binary operators (including with broadcasting) inheriting from base classes in kernels/cwise_ops_common.h.
* Unary and binary operators inheriting from base classes in framework/numeric_op.h. This is mostly old code for the Relu family and associated gradients.
* All linear algebra ops inheriting from linalg_common.
* Misc. individual files/ops: softmax, select, bias, aggregate ops, batch_norm & fused_batch_norm, adjust_hue, constant, depthwise_conv_grad, fractional_avg_pool, misc. pooling ops, matrix_set_diag, xent & sparse_xent, unique_op.

Change: 148166936
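To make the pattern concrete before the diffs: below is a minimal sketch of how an updated kernel uses the new API. The IllustrativeSquare op is hypothetical and not part of this change; it assumes a float-only element-wise op, but the forwarding call and the fallback allocation mirror the kernels changed below.

// Hypothetical element-wise kernel illustrating the forwarding pattern.
// Not part of this commit; assumes a float-only op for brevity.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

REGISTER_OP("IllustrativeSquare")
    .Input("x: float")
    .Output("y: float")
    .SetShapeFn(shape_inference::UnchangedShape);

class IllustrativeSquareOp : public OpKernel {
 public:
  explicit IllustrativeSquareOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    Tensor* output = nullptr;
    // Alias the input buffer into output slot 0 when it is exclusively
    // owned (refcount 1), not a ref, and matches the output's dtype and
    // memory type; otherwise fall back to a fresh allocation.
    if (!ctx->forward_input_to_output(0, 0, &output)) {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    }
    // Element-wise, so computing in place is safe even when output aliases
    // input: each element is read once and then overwritten.
    output->flat<float>() = input.flat<float>().square();
  }
};

REGISTER_KERNEL_BUILDER(Name("IllustrativeSquare").Device(DEVICE_CPU),
                        IllustrativeSquareOp);

}  // namespace tensorflow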
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/framework/numeric_op.h | 15
-rw-r--r--  tensorflow/core/framework/op_kernel.cc | 100
-rw-r--r--  tensorflow/core/framework/op_kernel.h | 37
-rw-r--r--  tensorflow/core/framework/tensor.cc | 8
-rw-r--r--  tensorflow/core/framework/tensor.h | 4
-rw-r--r--  tensorflow/core/kernels/adjust_hue_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/aggregate_ops.cc | 8
-rw-r--r--  tensorflow/core/kernels/batch_norm_op.cc | 20
-rw-r--r--  tensorflow/core/kernels/bias_op.cc | 12
-rw-r--r--  tensorflow/core/kernels/constant_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/cwise_op_select.cc | 15
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.cc | 9
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.h | 21
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_grad_op.cc | 13
-rw-r--r--  tensorflow/core/kernels/fractional_avg_pool_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/fractional_max_pool_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/linalg_ops_common.cc | 31
-rw-r--r--  tensorflow/core/kernels/matrix_set_diag_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/maxpooling_op.cc | 12
-rw-r--r--  tensorflow/core/kernels/softmax_op.h | 6
-rw-r--r--  tensorflow/core/kernels/sparse_xent_op.cc | 12
-rw-r--r--  tensorflow/core/kernels/unique_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/xent_op.cc | 8
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 97
-rw-r--r--  tensorflow/python/kernel_tests/slice_op_test.py | 9
26 files changed, 365 insertions, 108 deletions
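The diffs below also add a with-shape variant, forward_input_to_output_with_shape, which can alias an input into an output of a different shape as long as the element counts match. A hedged sketch of how a kernel might use it (the IllustrativeFlatten kernel is hypothetical and not part of this commit; op registration is omitted for brevity):

// Hypothetical kernel body using the with-shape variant; not part of this
// commit. Flattening preserves the element count, so the input buffer can
// be aliased even though the output shape differs.
class IllustrativeFlattenOp : public OpKernel {
 public:
  explicit IllustrativeFlattenOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const TensorShape flat_shape({input.NumElements()});
    Tensor* output = nullptr;
    if (!ctx->forward_input_to_output_with_shape(0, 0, flat_shape, &output)) {
      // No aliasing possible; allocate and copy instead.
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, flat_shape, &output));
      output->flat<float>() = input.flat<float>();
    }
  }
};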
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
index f24bcfead3..891e077657 100644
--- a/tensorflow/core/framework/numeric_op.h
+++ b/tensorflow/core/framework/numeric_op.h
@@ -56,9 +56,11 @@ class UnaryElementWiseOp : public UnaryOp<T> {
void Compute(OpKernelContext* context) override {
// Output shape is the same as input shape.
const Tensor& input = context->input(0);
- Tensor* output;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input.shape(), &output));
+ Tensor* output = nullptr;
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ }
static_cast<CHILD*>(this)->Operate(context, input, output);
}
};
@@ -77,8 +79,11 @@ class BinaryElementWiseOp : public BinaryOp<T> {
return;
}
- Tensor* output;
- OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output));
+ Tensor* output = nullptr;
+ if (!context->forward_input_to_output(0, 0, &output) &&
+ !context->forward_input_to_output(1, 0, &output)) {
+ OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output));
+ }
// Dispatch to the descendant's Operate() function.
switch (a.dims()) {
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index b35e4ac243..a56b8cb4b3 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -347,6 +347,106 @@ void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
(*params_->inputs)[input_index].tensor);
}
+bool OpKernelContext::forward_input_to_output(int input_index, int output_index,
+ Tensor** output) {
+ DCHECK_GE(input_index, 0);
+ DCHECK_LT(input_index, params_->inputs->size());
+ const TensorValue& input = (*params_->inputs)[input_index];
+ if (input.tensor == nullptr) {
+ return false;
+ }
+ return forward_input_to_output_with_shape(input_index, output_index,
+ input.tensor->shape(), output);
+}
+
+Status OpKernelContext::forward_input_to_output(StringPiece input_name,
+ StringPiece output_name,
+ Tensor** output) {
+ int input_index, output_index, stop;
+ TF_RETURN_IF_ERROR(
+ params_->op_kernel->InputRange(input_name, &input_index, &stop));
+ if (stop != input_index + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ input_name,
+ "' when single-valued input was "
+ "expected");
+ }
+ TF_RETURN_IF_ERROR(
+ params_->op_kernel->OutputRange(output_name, &output_index, &stop));
+ if (stop != output_index + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ output_name,
+ "' when single-valued output was "
+ "expected");
+ }
+ if (!forward_input_to_output(input_index, output_index, output)) {
+    return errors::FailedPrecondition("OpKernel could not forward input '",
+                                      input_name, "' to output '",
+                                      output_name, "'");
+ }
+ return Status::OK();
+}
+
+bool OpKernelContext::forward_input_to_output_with_shape(
+ int input_index, int output_index, const TensorShape& output_shape,
+ Tensor** output) {
+ DCHECK_GE(input_index, 0);
+ DCHECK_LT(input_index, params_->inputs->size());
+ const TensorValue& input = (*params_->inputs)[input_index];
+  // Check that the input tensor exists, is not a ref, and has no other consumers.
+ if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) {
+ return false;
+ }
+ DCHECK_GE(output_index, 0);
+ DCHECK_LT(output_index, num_outputs());
+ // Check that input and output types match.
+ if (expected_output_dtype(output_index) != input_dtype(input_index)) {
+ return false;
+ }
+ // Check that the input and output sizes are compatible.
+ if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
+ return false;
+ }
+  // Check that input and output memory types match, i.e. that they either
+  // both live in host memory or both live in device memory.
+ if (op_kernel().output_memory_types()[output_index] !=
+ op_kernel().input_memory_types()[input_index]) {
+ return false;
+ }
+ Tensor* output_tensor = new Tensor();
+ CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
+ outputs_[output_index] = TensorValue(output_tensor);
+ *output = outputs_[output_index].tensor;
+ return true;
+}
+
+Status OpKernelContext::forward_input_to_output_with_shape(
+ StringPiece input_name, StringPiece output_name,
+ const TensorShape& output_shape, Tensor** output) {
+ int input_index, output_index, stop;
+ TF_RETURN_IF_ERROR(
+ params_->op_kernel->InputRange(input_name, &input_index, &stop));
+ if (stop != input_index + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ input_name,
+ "' when single-valued input was "
+ "expected");
+ }
+ TF_RETURN_IF_ERROR(
+ params_->op_kernel->OutputRange(output_name, &output_index, &stop));
+ if (stop != output_index + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ output_name,
+ "' when single-valued output was "
+ "expected");
+ }
+ if (!forward_input_to_output_with_shape(input_index, output_index,
+ output_shape, output)) {
+    return errors::FailedPrecondition("OpKernel could not forward input '",
+                                      input_name, "' to output '",
+                                      output_name, "'");
+ }
+ return Status::OK();
+}
+
void OpKernelContext::delete_ref_input(int index, bool lock_held) {
DCHECK_GE(index, 0);
DCHECK_LT(index, params_->inputs->size());
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 75ad4bb7fc..201e247615 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -643,12 +643,6 @@ class OpKernelContext {
Status replace_ref_input(StringPiece name, const Tensor& tensor,
bool lock_held);
- // Set the output Ref Tensor at output_index to be an alias of the
- // input Ref Tensor at input_index.
- // REQUIRES: IsRefType(input_dtype(input_index)).
- // REQUIRES: IsRefType(output_dtype(output_index)).
- void forward_ref_input_to_ref_output(int input_index, int output_index);
-
// Deletes the Tensor object used as the Ref Input at
// input_index. This is not usually necessary and should be used
// with caution. If !lock_held the input mutex will be acquired
@@ -667,6 +661,37 @@ class OpKernelContext {
// Usage: if (!context->ValidateInputsAreSameShape(this)) return;
bool ValidateInputsAreSameShape(OpKernel* op);
+ // Input to output forwarding.
+
+ // Set the output Ref Tensor at output_index to be an alias of the
+ // input Ref Tensor at input_index.
+ // REQUIRES: IsRefType(input_dtype(input_index)).
+ // REQUIRES: IsRefType(output_dtype(output_index)).
+ void forward_ref_input_to_ref_output(int input_index, int output_index);
+
+  // Returns true when an alias to input[input_index] that is safe to use for
+  // in-place computation was written to *output. Returns false if
+  // input[input_index] has a refcount greater than one or if its type does
+  // not match the expected output type of output[output_index].
+ bool forward_input_to_output(int input_index, int output_index,
+ Tensor** output);
+ Status forward_input_to_output(StringPiece input_name,
+ StringPiece output_name, Tensor** output);
+
+  // Returns true when an alias to input[input_index], reshaped to
+  // output_shape, that is safe to use for in-place computation was written
+  // to *output. Returns false if input[input_index] has a refcount greater
+  // than one, if its type does not match the expected output type of
+  // output[output_index], or if the number of elements in input[input_index]
+  // does not equal the number of elements in output_shape.
+ bool forward_input_to_output_with_shape(int input_index, int output_index,
+ const TensorShape& output_shape,
+ Tensor** output);
+ Status forward_input_to_output_with_shape(StringPiece input_name,
+ StringPiece output_name,
+ const TensorShape& output_shape,
+ Tensor** output);
+
// Output
// Returns the named list-valued output in "list", as defined in the OpDef.
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index f622d031f2..68c6817448 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -526,6 +526,14 @@ void Tensor::UnsafeCopyFromInternal(const Tensor& other, DataType dtype,
}
}
+// Notice that buf_ either points to a regular TensorBuffer or to a SubBuffer.
+// In the latter case, we have to make sure that the refcount is one both for
+// the SubBuffer _and_ the underlying TensorBuffer.
+bool Tensor::RefCountIsOne() const {
+ return buf_ != nullptr && buf_->RefCountIsOne() &&
+ buf_->root_buffer()->RefCountIsOne();
+}
+
// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index c9ddad3bdb..d9b22525c4 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -414,6 +414,9 @@ class Tensor {
const TensorShape&);
private:
+ // Returns true if the refcount on buf_ and any possible underlying root
+ // buffer is one.
+ bool RefCountIsOne() const;
void CheckType(DataType expected_dtype) const;
void CheckTypeAndIsAligned(DataType expected_dtype) const;
void CheckIsAlignedAndSingleElement() const;
@@ -439,6 +442,7 @@ class Tensor {
friend class TensorTestHelper; // For access to set_shape
template <typename Device, typename T>
friend class CreateVariableOp;
+ friend class OpKernelContext; // For access to RefCountIsOne().
// Creates a tensor with the input datatype, shape and buf.
//
diff --git a/tensorflow/core/kernels/adjust_hue_op.cc b/tensorflow/core/kernels/adjust_hue_op.cc
index 98934b4e5b..144bde2889 100644
--- a/tensorflow/core/kernels/adjust_hue_op.cc
+++ b/tensorflow/core/kernels/adjust_hue_op.cc
@@ -58,8 +58,10 @@ class AdjustHueOpBase : public OpKernel {
channels, " channels."));
Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input.shape(), &output));
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ }
if (input.NumElements() > 0) {
const int64 channel_count = input.NumElements() / channels;
diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc
index 50d0cc1727..0f5186eb07 100644
--- a/tensorflow/core/kernels/aggregate_ops.cc
+++ b/tensorflow/core/kernels/aggregate_ops.cc
@@ -49,7 +49,13 @@ class AddNOp : public OpKernel {
}
Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
+ bool reused_input_buffer = false;
+ for (int i = 0; i < num && !reused_input_buffer; ++i) {
+ reused_input_buffer = ctx->forward_input_to_output(i, 0, &output);
+ }
+ if (!reused_input_buffer) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
+ }
auto To = output->flat<T>();
#define I(IDX) ctx->input(IDX).flat<T>()
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc
index f4aa759643..7c95d4dd20 100644
--- a/tensorflow/core/kernels/batch_norm_op.cc
+++ b/tensorflow/core/kernels/batch_norm_op.cc
@@ -115,15 +115,25 @@ class BatchNormGradOp : public OpKernel {
out_backprop.shape().DebugString()));
Tensor* dx = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx));
+ if (!context->forward_input_to_output(0, 0, &dx)) {
+ OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx));
+ }
Tensor* dm = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm));
+ if (!context->forward_input_to_output(1, 1, &dm)) {
+ OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm));
+ }
Tensor* dv = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv));
+ if (!context->forward_input_to_output(2, 2, &dv)) {
+ OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv));
+ }
Tensor* db = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
+ if (!context->forward_input_to_output(3, 3, &db)) {
+ OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
+ }
Tensor* dg = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
+ if (!context->forward_input_to_output(4, 4, &dg)) {
+ OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
+ }
// Scratch buffer of [depth] dimension, aka the 4th dimension of input,
// which is dim_size(3), for calculating various combinations of
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 46e12cff2a..92696f8c07 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -74,8 +74,10 @@ class BiasOp<CPUDevice, T> : public BinaryOp<T> {
bias.shape().DebugString(), " vs. ", input.shape().DebugString()));
Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input.shape(), &output));
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ }
if (input.NumElements() == 0) return;
switch (input.shape().dims()) {
@@ -271,8 +273,10 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
bias.shape().DebugString(), " vs. ", channel, " in ",
input.shape().DebugString()));
Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input.shape(), &output));
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ }
if (input.NumElements() > 0) {
BiasGPU<T>::compute(context->template eigen_device<Device>(),
input.flat<T>().data(), bias.flat<T>().data(),
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 306736fe54..0de6f38451 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -228,7 +228,9 @@ class ZerosLikeOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
Tensor* out = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
+ if (!ctx->forward_input_to_output(0, 0, &out)) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
+ }
functor::SetZeroFunctor<Device, T> f;
f(ctx->eigen_device<Device>(), out->flat<T>());
}
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 8160fb74c2..0404d0e997 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -92,7 +92,10 @@ class SelectOp : public OpKernel {
else_->shape().DebugString()));
Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+ if (!ctx->forward_input_to_output("t", "output", &output).ok() &&
+ !ctx->forward_input_to_output("e", "output", &output).ok()) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+ }
if (output->NumElements() > 0) {
functor::BatchSelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
@@ -105,7 +108,10 @@ class SelectOp : public OpKernel {
const Tensor* then, const Tensor* else_) {
if (!ctx->ValidateInputsAreSameShape(this)) return;
Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+ if (!ctx->forward_input_to_output("t", "output", &output).ok() &&
+ !ctx->forward_input_to_output("e", "output", &output).ok()) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+ }
if (output->NumElements() > 0) {
functor::SelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
@@ -123,7 +129,10 @@ class SelectOp : public OpKernel {
else_->shape().DebugString()));
Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+ if (!ctx->forward_input_to_output("t", "output", &output).ok() &&
+ !ctx->forward_input_to_output("e", "output", &output).ok()) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
+ }
if (output->NumElements() > 0) {
functor::SelectScalarFunctor<Device, T> func;
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index c675faeea1..0a3b29b970 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -55,11 +55,14 @@ BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
in1.shape().DebugString()));
return;
}
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(0, BCast::ToShape(bcast.output_shape()), &out));
- out_num_elements = out->NumElements();
+ const TensorShape output_shape = BCast::ToShape(bcast.output_shape());
+ out_num_elements = output_shape.num_elements();
in0_num_elements = in0.NumElements();
in1_num_elements = in1.NumElements();
+ if (!ctx->forward_input_to_output_with_shape(0, 0, output_shape, &out) &&
+ !ctx->forward_input_to_output_with_shape(1, 0, output_shape, &out)) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
+ }
ndims = static_cast<int>(bcast.x_reshape().size());
}
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index c825a91fb1..fdcbba680d 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -48,7 +48,9 @@ class BinaryOpShared : public OpKernel {
protected:
struct BinaryOpState {
// Sets up bcast with the shape of in0 and in1, ensures that the bcast
- // is valid, and if so, allocates out using ctx->output(...).
+    // is valid, and if so, sets out, either by allocating a new buffer using
+    // ctx->allocate_output(...) or by creating an alias of an owned input
+    // buffer for in-place computation.
// Caller must check ctx->status() upon return for non-ok status.
// If ctx->status().ok() is true, then out is guaranteed to be allocated.
BinaryOpState(OpKernelContext* ctx);
@@ -168,14 +170,18 @@ class SimpleBinaryOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& in0 = ctx->input(0);
const Tensor& in1 = ctx->input(1);
-
- Tensor* out;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
- auto out_flat = out->flat<Tout>();
auto in0_flat = in0.flat<Tin>();
auto in1_flat = in1.flat<Tin>();
const Device& eigen_device = ctx->eigen_device<Device>();
+ Tensor* out = nullptr;
+ if (!std::is_same<Tin, Tout>::value) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
+ } else if (!ctx->forward_input_to_output(0, 0, &out) &&
+ !ctx->forward_input_to_output(1, 0, &out)) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
+ }
+ auto out_flat = out->flat<Tout>();
functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
in0_flat, in1_flat);
}
@@ -200,7 +206,10 @@ class UnaryOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& inp = ctx->input(0);
Tensor* out = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
+ if (!std::is_same<Tin, Tout>::value ||
+ !ctx->forward_input_to_output(0, 0, &out)) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
+ }
functor::UnaryFunctor<Device, Functor>()(
ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
}
diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
index f9076cb903..a55365cb49 100644
--- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc
@@ -542,9 +542,10 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
Tensor* in_backprop = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, input_shape, &in_backprop));
-
+ if (!context->forward_input_to_output(0, 0, &in_backprop)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+ }
auto out_backprop_ptr = out_backprop.template flat<T>().data();
auto filter_ptr = filter.template flat<T>().data();
auto in_backprop_ptr = in_backprop->template flat<T>().data();
@@ -925,8 +926,10 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropFilter");
Tensor* filter_backprop = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, filter_shape, &filter_backprop));
+ if (!context->forward_input_to_output(1, 0, &filter_backprop)) {
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, filter_shape, &filter_backprop));
+ }
auto out_backprop_ptr = out_backprop.template flat<T>().data();
auto input_ptr = input.template flat<T>().data();
diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc
index 9bba6712a2..4a3ef59211 100644
--- a/tensorflow/core/kernels/fractional_avg_pool_op.cc
+++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc
@@ -323,8 +323,10 @@ class FractionalAvgPoolGradOp : public OpKernel {
// Depending on the type, cast double to type T.
Tensor* in_backprop_tensor = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, in_shape, &in_backprop_tensor));
+ if (!context->forward_input_to_output(0, 0, &in_backprop_tensor)) {
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, in_shape, &in_backprop_tensor));
+ }
auto in_backprop_tensor_flat = in_backprop_tensor->flat<T>();
auto in_backprop_tensor_temp_flat = in_backprop_tensor_temp.flat<double>();
for (int64 i = 0; i < in_backprop_tensor_flat.size(); ++i) {
diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc
index a422433ecf..45567461e2 100644
--- a/tensorflow/core/kernels/fractional_max_pool_op.cc
+++ b/tensorflow/core/kernels/fractional_max_pool_op.cc
@@ -343,8 +343,10 @@ class FractionalMaxPoolGradOp : public OpKernel {
}
Tensor* output = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, tensor_in.shape(), &output));
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, tensor_in.shape(), &output));
+ }
output->flat<T>().setZero();
auto out_backprop_flat = out_backprop.flat<T>();
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 31570e2fc8..43d31fc221 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -520,7 +520,9 @@ class FusedBatchNormOp : public OpKernel {
}
Tensor* y = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &y));
+ if (!context->forward_input_to_output(0, 0, &y)) {
+ OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &y));
+ }
Tensor* batch_mean = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(1, scale.shape(), &batch_mean));
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
index 5fde696963..3ecd3182ff 100644
--- a/tensorflow/core/kernels/linalg_ops_common.cc
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -171,15 +171,20 @@ void LinearAlgebraOp<Scalar>::PrepareOutputs(
num_outputs, context->num_outputs()));
// Allocate outputs.
- for (int i = 0; i < context->num_outputs(); ++i) {
- TensorShape output_tensor_shape({0});
- if (i < num_outputs) {
+ std::set<int> unused_inputs;
+ for (int input_idx = 0; input_idx < context->num_inputs(); ++input_idx) {
+ unused_inputs.insert(input_idx);
+ }
+ for (int output_idx = 0; output_idx < context->num_outputs(); ++output_idx) {
+ TensorShape output_tensor_shape({});
+ if (output_idx < num_outputs) {
// This output is used, set up output shape and allocate it.
- const TensorShape& output_matrix_shape = output_matrix_shapes->at(i);
+ const TensorShape& output_matrix_shape =
+ output_matrix_shapes->at(output_idx);
OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
errors::InvalidArgument(
"Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
- i, output_matrix_shape.dims()));
+ output_idx, output_matrix_shape.dims()));
// The final output has the shape of the outer batch dimensions
// concatenated with the output_matrix_shape (if the output is not
@@ -190,8 +195,20 @@ void LinearAlgebraOp<Scalar>::PrepareOutputs(
}
}
Tensor* out = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(i, output_tensor_shape, &out));
+ // See if there is an input buffer we can reuse for this output.
+ bool reused_input = false;
+ for (int input_idx : unused_inputs) {
+ if (context->forward_input_to_output_with_shape(
+ input_idx, output_idx, output_tensor_shape, &out)) {
+ reused_input = true;
+ unused_inputs.erase(input_idx);
+ break;
+ }
+ }
+ if (!reused_input) {
+ OP_REQUIRES_OK(context, context->allocate_output(
+ output_idx, output_tensor_shape, &out));
+ }
outputs->push_back(out);
}
}
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc
index 952da7d8df..1754e4ad69 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op.cc
@@ -78,8 +78,10 @@ class MatrixSetDiagOp : public OpKernel {
auto diag_reshaped = diag.flat_inner_dims<T, 2>();
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &output));
-
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &output));
+ }
auto output_reshaped = output->flat_inner_dims<T, 3>();
Tensor scratch_tensor;
OP_REQUIRES_OK(context,
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index 98b4558a3a..669597e382 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -290,7 +290,10 @@ class MaxPoolingGradOp : public OpKernel {
}
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, output_shape, &output));
+ }
SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>(
context, &tensor_out_dup, &tensor_out_arg_max, output, tensor_in,
@@ -319,9 +322,10 @@ static void MaxPoolingBackwardCustomKernel(
const std::vector<int32>& stride, Padding padding, const Tensor* tensor_in,
const Tensor& out_backprop, const TensorShape& tensor_in_shape) {
Tensor* output = nullptr;
-
- OP_REQUIRES_OK(context,
- context->allocate_output(0, tensor_in_shape, &output));
+ if (!context->forward_input_to_output(0, 0, &output)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, tensor_in_shape, &output));
+ }
PoolParameters params{context, size, stride,
padding, FORMAT_NHWC, tensor_in_shape};
diff --git a/tensorflow/core/kernels/softmax_op.h b/tensorflow/core/kernels/softmax_op.h
index dc61e26809..e9dbafd589 100644
--- a/tensorflow/core/kernels/softmax_op.h
+++ b/tensorflow/core/kernels/softmax_op.h
@@ -40,8 +40,10 @@ class SoftmaxOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
errors::InvalidArgument("logits must be 2-dimensional"));
Tensor* softmax_out = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(0, logits_in.shape(), &softmax_out));
+ if (!context->forward_input_to_output(0, 0, &softmax_out)) {
+ OP_REQUIRES_OK(context, context->allocate_output(0, logits_in.shape(),
+ &softmax_out));
+ }
if (logits_in.NumElements()) {
functor::SoftmaxFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 9c39841fee..4a61e31e8d 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -78,11 +78,15 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
labels.shape(), &scratch));
Tensor* loss_out = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, labels.shape(), &loss_out));
+ if (!context->forward_input_to_output(1, 0, &loss_out)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, labels.shape(), &loss_out));
+ }
Tensor* back_out = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(1, logits.shape(), &back_out));
+ if (!context->forward_input_to_output(0, 1, &back_out)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, logits.shape(), &back_out));
+ }
if (logits.dim_size(0) > 0) {
if (std::is_same<Device, CPUDevice>::value) {
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 6aa9d6accb..e06fe20e79 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -46,7 +46,9 @@ class UniqueOp : public OpKernel {
const int64 N = static_cast<int64>(Tin.size());
Tensor* idx = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(1, input.shape(), &idx));
+ if (!context->forward_input_to_output(0, 1, &idx)) {
+ OP_REQUIRES_OK(context, context->allocate_output(1, input.shape(), &idx));
+ }
auto idx_vec = idx->template vec<int32>();
std::unordered_map<T, int32> uniq;
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index 639bad5f04..2a0ef63eab 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -61,9 +61,11 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
context->allocate_output(
0, TensorShape({logits_in.dim_size(0)}), &loss_out));
Tensor* back_out = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(1, logits_in.shape(), &back_out));
-
+ // Try to reuse the logits_in buffer for the backprop output.
+ if (!context->forward_input_to_output(0, 1, &back_out)) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, logits_in.shape(), &back_out));
+ }
functor::XentFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index e9db47716d..6c7cbbff9c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -326,9 +326,8 @@ class ControlFlowTest(test.TestCase):
def testFetchables(self):
with self.test_session() as sess:
x = array_ops.placeholder(dtypes.float32)
- control_flow_ops.cond(constant_op.constant(True),
- lambda: x + 2,
- lambda: x + 0)
+ control_flow_ops.cond(
+ constant_op.constant(True), lambda: x + 2, lambda: x + 0)
tensor_names = all_fetchables()
for name in tensor_names:
sess.run(name, feed_dict={x: 3})
@@ -388,11 +387,12 @@ class ControlFlowTest(test.TestCase):
rv = resource_variable_ops.ResourceVariable(True)
variables.global_variables_initializer().run()
t = ops.convert_to_tensor(1.0)
+
def case():
- assign = resource_variable_ops.assign_variable_op(
- rv.handle, False)
+ assign = resource_variable_ops.assign_variable_op(rv.handle, False)
with ops.control_dependencies([assign]):
return array_ops.identity(t)
+
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
def testCondIndexedSlicesDifferentTypes(self):
@@ -544,13 +544,15 @@ class ControlFlowTest(test.TestCase):
with self.test_session() as sess:
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
+
def true_branch():
with ops.control_dependencies([control_holder]):
_ = a + 1
return a + 2
- r = control_flow_ops.cond(constant_op.constant(True),
- true_branch,
- lambda: constant_op.constant(1))
+
+ r = control_flow_ops.cond(
+ constant_op.constant(True), true_branch,
+ lambda: constant_op.constant(1))
self.assertEqual(5, r.eval())
def testUninitializedRefIdentity(self):
@@ -770,16 +772,37 @@ class ControlFlowTest(test.TestCase):
o = ops.convert_to_tensor([0])
x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
s = array_ops.size(x)
- r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
- compute, [i, c, o], [
- i.get_shape(),
- tensor_shape.unknown_shape(),
- tensor_shape.unknown_shape()
- ])
+ r = control_flow_ops.while_loop(
+ lambda i, c, o: math_ops.less(i, s), compute, [i, c, o], [
+ i.get_shape(), tensor_shape.unknown_shape(),
+ tensor_shape.unknown_shape()
+ ])
result = r[2].eval()
self.assertTrue(check_op_order(i.graph))
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ def testBufferForwarding(self):
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+
+ with self.test_session() as sess:
+ with ops.device("/cpu:0"):
+ c = constant_op.constant(2)
+ i0 = constant_op.constant(0)
+ r = control_flow_ops.while_loop(lambda i: i < 1000,
+ lambda i: math_ops.square(c) + i, [i0])
+ r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
+ self.assertEqual(1000, r_val)
+ self.assertTrue(run_metadata.HasField("step_stats"))
+ unique_allocs = set()
+ for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
+ for output in node_stat.output:
+ unique_allocs.add(
+ output.tensor_description.allocation_description.ptr)
+ # Prior to cl/147536680, the number of unique allocations was about 1005.
+ self.assertLess(len(unique_allocs), 756)
+
def _testWhile_Gpu_1(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
n = constant_op.constant(1.0)
@@ -1368,8 +1391,9 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(45, rx.eval())
def _testWhileGrad_ColocateGradients(self, colocate):
- gpu_dev_name = test.gpu_device_name() if test.is_gpu_available() else "/gpu:0"
- gpu_short_name = gpu_dev_name.split('/')[-1]
+ gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
+ ) else "/gpu:0"
+ gpu_short_name = gpu_dev_name.split("/")[-1]
with self.test_session(graph=ops.Graph()) as sess:
v = constant_op.constant(2.0, name="v")
@@ -1485,16 +1509,21 @@ class ControlFlowTest(test.TestCase):
def _testNestedWhileCondWhileGrad(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
+
def inner_loop(s):
z = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 4)
b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
return control_flow_ops.while_loop(c, b, [z, s])
+
c = lambda x: math_ops.less(x, 128.0)
+
def b(x):
- return control_flow_ops.cond(constant_op.constant(True),
- lambda: math_ops.square(inner_loop(x)[1]),
- lambda: math_ops.multiply(x, 2.0))
+ return control_flow_ops.cond(
+ constant_op.constant(True),
+ lambda: math_ops.square(inner_loop(x)[1]),
+ lambda: math_ops.multiply(x, 2.0))
+
r = control_flow_ops.while_loop(c, b, [v])
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
@@ -1550,10 +1579,9 @@ class ControlFlowTest(test.TestCase):
with self.test_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
- named(
- a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
- (constant_op.constant(2.0), constant_op.constant(3.0)),
- constant_op.constant(4.0)
+ named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
+ (constant_op.constant(2.0),
+ constant_op.constant(3.0)), constant_op.constant(4.0)
]
c = lambda lv0, _1, _2: lv0.a < 100.0
@@ -1578,10 +1606,9 @@ class ControlFlowTest(test.TestCase):
with self.test_session():
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
- named(
- a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
- (constant_op.constant(2.0), constant_op.constant(3.0)),
- constant_op.constant(4.0)
+ named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
+ (constant_op.constant(2.0),
+ constant_op.constant(3.0)), constant_op.constant(4.0)
]
c = lambda lv0, _1, _2: lv0.a < 100.0
@@ -2522,15 +2549,11 @@ class TupleTest(test.TestCase):
with self.test_session():
v1 = variables.Variable([1.0])
add1 = math_ops.add(
- control_flow_ops.with_dependencies(
- [v1.initializer],
- v1._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
2.0)
v2 = variables.Variable([10.0])
add2 = math_ops.add(
- control_flow_ops.with_dependencies(
- [v2.initializer],
- v2._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access
20.0)
t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
@@ -2558,18 +2581,14 @@ class TupleTest(test.TestCase):
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
v1_at_1 = ops.IndexedSlices(
- control_flow_ops.with_dependencies(
- [v1.initializer],
- v1._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
constant_op.constant([1]))
v2 = variables.Variable(
np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
np.float32))
v2_at_1 = ops.IndexedSlices(
- control_flow_ops.with_dependencies(
- [v2.initializer],
- v2._ref()), # pylint: disable=protected-access
+ control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access
constant_op.constant([1]))
st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 29f76a2182..c11f78b77e 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -269,6 +269,15 @@ class SliceTest(test.TestCase):
c = array_ops.slice(a, [begin, 0], [-1, 2])
self.assertEqual([None, 2], c.get_shape().as_list())
+ def testSliceOfSlice(self):
+ with self.test_session(use_gpu=True):
+ a = constant_op.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+ b = a[1:, :]
+ c = b[:-1, :]
+ d = c[1, :]
+ res = 2 * d - c[1, :] + a[2, :] - 2 * b[-2, :]
+ self.assertAllEqual([0, 0, 0], res.eval())
+
if __name__ == "__main__":
test.main()