path: root/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc  141
1 file changed, 141 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
new file mode 100644
index 0000000000..637360d149
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific BatchMatMul Op.
+// The current implementation simply unrolls the computation along the batch
+// dimension.
+// TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BatchMatMulOp : public XlaOpKernel {
+ public:
+  explicit BatchMatMulOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_x", &adj_x_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_y", &adj_y_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const TensorShape x_shape = ctx->InputShape(0);
+    const TensorShape y_shape = ctx->InputShape(1);
+
+    // Check that both tensors have the same number of dimensions. There must be
+    // at least two (the batch dimensions can be empty).
+    OP_REQUIRES(ctx, x_shape.dims() == y_shape.dims(),
+                errors::InvalidArgument("In[0] and In[1] have different ndims: ",
+                                        x_shape.DebugString(), " vs. ",
+                                        y_shape.DebugString()));
+    const int ndims = x_shape.dims();
+    OP_REQUIRES(
+        ctx, ndims >= 2,
+        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
+
+    // The batch dimensions must be equal and the matrix dimensions must be
+    // valid.
+    std::vector<int64> dimensions;
+    int batch_count = 1;
+    for (int i = 0; i < ndims - 2; ++i) {
+      OP_REQUIRES(
+          ctx, x_shape.dim_size(i) == y_shape.dim_size(i),
+          errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", i,
+                                  ") must be the same: ", x_shape.DebugString(),
+                                  " vs ", y_shape.DebugString()));
+      dimensions.push_back(x_shape.dim_size(i));
+      batch_count *= x_shape.dim_size(i);
+    }
+
+    int x_inner_dim = adj_x_ ? (ndims - 2) : (ndims - 1);
+    int y_inner_dim = adj_y_ ? (ndims - 1) : (ndims - 2);
+    OP_REQUIRES(
+        ctx, x_shape.dim_size(x_inner_dim) == y_shape.dim_size(y_inner_dim),
+        errors::InvalidArgument(
+            "In[0] mismatch In[1] shape: ", x_shape.dim_size(x_inner_dim),
+            " vs. ", y_shape.dim_size(y_inner_dim), ": ", x_shape.DebugString(),
+            " ", y_shape.DebugString(), " ", adj_x_, " ", adj_y_));
+
+    int x_outer_dim = adj_x_ ? (ndims - 1) : (ndims - 2);
+    int y_outer_dim = adj_y_ ? (ndims - 2) : (ndims - 1);
+    dimensions.push_back(x_shape.dim_size(x_outer_dim));
+    dimensions.push_back(y_shape.dim_size(y_outer_dim));
+
+    xla::ComputationBuilder* builder = ctx->builder();
+
+    xla::ComputationDataHandle x_handle = ctx->Input(0);
+    xla::ComputationDataHandle y_handle = ctx->Input(1);
+
+    // Reshape input tensors into 3D tensors by flattening the batch
+    // dimensions. This makes it easier to unroll the batch dimension.
+    auto x_flat =
+        builder->Reshape(x_handle, {batch_count, x_shape.dim_size(ndims - 2),
+                                    x_shape.dim_size(ndims - 1)});
+    auto y_flat =
+        builder->Reshape(y_handle, {batch_count, y_shape.dim_size(ndims - 2),
+                                    y_shape.dim_size(ndims - 1)});
+
+    // Slice batches into individual matrices and multiply them.
+    std::vector<xla::ComputationDataHandle> out_slices;
+    for (int i = 0; i < batch_count; ++i) {
+      // Slice off individual matrices and reshape to 2D tensors.
+      auto x_slice = builder->Slice(
+          x_flat, {i, 0, 0},
+          {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)});
+      x_slice = builder->Reshape(
+          x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)});
+      auto y_slice = builder->Slice(
+          y_flat, {i, 0, 0},
+          {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)});
+      y_slice = builder->Reshape(
+          y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)});
+
+      // Transpose if needed.
+      auto lhs = adj_x_ ? builder->Transpose(x_slice, {1, 0}) : x_slice;
+      auto rhs = adj_y_ ? builder->Transpose(y_slice, {1, 0}) : y_slice;
+
+      // Multiply matrices and add an outer singleton dimension to the output
+      // so we can concatenate along the flattened batch dimension later.
+      auto out = builder->Dot(lhs, rhs);
+      out = builder->Reshape(out,
+                             {1, dimensions[ndims - 2], dimensions[ndims - 1]});
+      out_slices.push_back(out);
+    }
+
+    // Concatenate output slices and reshape to original number of dimensions.
+    xla::ComputationDataHandle data;
+    if (out_slices.empty()) {
+      // It is illegal to pass an empty list to ConcatInDim.
+      // The batch count is zero, so both inputs must have zero elements.
+      // Arbitrarily use the left input as the argument to Reshape().
+      data = x_handle;
+    } else {
+      data = builder->ConcatInDim(out_slices, 0);
+    }
+    data = builder->Reshape(data, dimensions);
+
+    ctx->SetOutput(0, data);
+  }
+
+ private:
+  bool adj_x_;
+  bool adj_y_;
+};
+
+REGISTER_XLA_OP("BatchMatMul", BatchMatMulOp);
+
+} // namespace
+} // namespace tensorflow
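
Aside (not part of the change above): the kernel's strategy of "unrolling along the batch dimension" can be hard to picture from the builder calls alone. The sketch below is a plain C++ illustration, independent of XLA, of the same computation on row-major buffers; the helper name BatchMatMul, the double element type, and the [batch, rows, cols] layout are assumptions made for this example, not anything defined by the diff. It multiplies one matrix pair per batch element, optionally treating either operand as transposed, which mirrors the kernel's per-batch Slice, optional Transpose, and Dot.

#include <cstddef>
#include <iostream>
#include <vector>

// Multiplies batches of row-major matrices stored contiguously.
// x holds [batch, m, k] values (or [batch, k, m] when adj_x is true),
// y holds [batch, k, n] values (or [batch, n, k] when adj_y is true),
// and the result holds [batch, m, n] values.
std::vector<double> BatchMatMul(const std::vector<double>& x,
                                const std::vector<double>& y, int batch, int m,
                                int k, int n, bool adj_x, bool adj_y) {
  std::vector<double> out(static_cast<size_t>(batch) * m * n, 0.0);
  for (int b = 0; b < batch; ++b) {  // Unroll along the batch dimension.
    const double* xb = x.data() + static_cast<size_t>(b) * m * k;
    const double* yb = y.data() + static_cast<size_t>(b) * k * n;
    double* ob = out.data() + static_cast<size_t>(b) * m * n;
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        double sum = 0.0;
        for (int p = 0; p < k; ++p) {
          // adj_x / adj_y swap the two trailing dimensions of the operand,
          // mirroring the Transpose the kernel inserts before Dot.
          const double xv = adj_x ? xb[p * m + i] : xb[i * k + p];
          const double yv = adj_y ? yb[j * k + p] : yb[p * n + j];
          sum += xv * yv;
        }
        ob[i * n + j] = sum;
      }
    }
  }
  return out;
}

int main() {
  // Two batches of 2x2 matrices; y holds the identity and twice the identity.
  const std::vector<double> x = {1, 2, 3, 4, /*batch 1:*/ 5, 6, 7, 8};
  const std::vector<double> y = {1, 0, 0, 1, /*batch 1:*/ 2, 0, 0, 2};
  const std::vector<double> out =
      BatchMatMul(x, y, /*batch=*/2, /*m=*/2, /*k=*/2, /*n=*/2,
                  /*adj_x=*/false, /*adj_y=*/false);
  for (double v : out) std::cout << v << " ";  // Prints: 1 2 3 4 10 12 14 16
  std::cout << "\n";
  return 0;
}

The kernel performs the analogous steps symbolically: Reshape both inputs to [batch_count, rows, cols], Slice and Reshape one matrix per batch element, apply Transpose when adj_x_ or adj_y_ is set, call Dot, Reshape each product to add a leading dimension of size 1, ConcatInDim the slices along that dimension, and finally Reshape back to the original batch dimensions.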