From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001
From: Manjunath Kudlur <keveman@gmail.com>
Date: Fri, 6 Nov 2015 16:27:58 -0800
Subject: TensorFlow: Initial commit of TensorFlow library.

TensorFlow is an open source software library for numerical computation using
data flow graphs.

Base CL: 107276108
---
 tensorflow/core/kernels/sparse_matmul_op.cc | 192 ++++++++++++++++++++++++++++
 1 file changed, 192 insertions(+)
 create mode 100644 tensorflow/core/kernels/sparse_matmul_op.cc

(limited to 'tensorflow/core/kernels/sparse_matmul_op.cc')

diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
new file mode 100644
index 0000000000..919e129ff8
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -0,0 +1,192 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+void PrefetchBlockNTA(const T& tensor, int si, int ei, int sj, int ej) {
+  for (int i = si; i < ei; ++i) {
+    for (int j = sj; j < ej; j = j + 16) {
+      port::prefetch<port::PREFETCH_HINT_NTA>(&tensor(i, j));
+    }
+  }
+}
+
+template <typename T>
+void PrefetchBlockT1(const T& tensor, int si, int ei, int sj, int ej) {
+  for (int i = si; i < ei; ++i) {
+    for (int j = sj; j < ej; j = j + 16) {
+      port::prefetch<port::PREFETCH_HINT_T1>(&tensor(i, j));
+    }
+  }
+}
+
+struct Block {
+  Block(int sm, int em, int sk, int ek, int sn, int en)
+      : startm(sm), endm(em), startk(sk), endk(ek), startn(sn), endn(en) {}
+
+  int startm;
+  int endm;
+  int startk;
+  int endk;
+  int startn;
+  int endn;
+};
+
+bool NextBlock(const int Bm, const int Bk, const int Bn, const int m_start,
+               const int m, const int k, const int n, const Block& b,
+               Block* next) {
+  *next = b;
+  if (b.endk < k) {
+    next->startk = b.endk;
+    next->endk = std::min(b.endk + Bk, k);
+  } else {
+    next->startk = 0;
+    next->endk = std::min(Bk, k);
+    if (b.endm < m) {
+      next->startm = b.endm;
+      next->endm = std::min(b.endm + Bm, m);
+    } else {
+      next->startm = m_start;
+      next->endm = std::min(m_start + Bm, m);
+      next->startn = b.endn;
+      next->endn = std::min(b.endn + Bn, n);
+    }
+  }
+  return next->startn == next->endn;
+}
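+
+// Note on the traversal order implemented by NextBlock above: the k range
+// advances first, then the m range, and the n range advances only once the
+// m and k ranges are exhausted. The function returns true when the computed
+// next block has an empty n range, i.e. the shard's entire (m, k, n)
+// iteration space has been covered and there is no further block to process.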
+
+class SparseMatMulOp : public OpKernel {
+ public:
+  explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& a = ctx->input(0);
+    const Tensor& b = ctx->input(1);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
+                errors::InvalidArgument("a is not a matrix"));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
+                errors::InvalidArgument("b is not a matrix"));
+
+    auto left = a.matrix<float>();
+    auto right_mat = b.matrix<float>();
+    const int m = transpose_a_ ? left.dimension(1) : left.dimension(0);
+    const int k = transpose_a_ ? left.dimension(0) : left.dimension(1);
+    const int n =
+        transpose_b_ ? right_mat.dimension(0) : right_mat.dimension(1);
+    const int k2 =
+        transpose_b_ ? right_mat.dimension(1) : right_mat.dimension(0);
+
+    OP_REQUIRES(ctx, k == k2,
+                errors::InvalidArgument("Matrix size incompatible: a: ",
+                                        a.shape().DebugString(), ", b: ",
+                                        b.shape().DebugString()));
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
+    auto out = output->matrix<float>();
+
+    if (!a_is_sparse_) {
+      // Fallback to Eigen contract.
+      // Note that we currently don't optimize the case where only right is
+      // sparse. That can generally be handled by transposing the order of the
+      // matmul.
+      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+      dim_pair[0].first = transpose_a_ ? 0 : 1;
+      dim_pair[0].second = transpose_b_ ? 1 : 0;
+      out.device(ctx->template eigen_device<CPUDevice>()) =
+          left.contract(right_mat, dim_pair);
+      return;
+    }
+    typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix;
+    std::unique_ptr<Matrix> right_tr_mat;
+    std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map;
+    if (transpose_b_) {
+      right_tr_mat.reset(new Matrix(k, n));
+      Eigen::array<int, 2> perm({1, 0});
+      right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) =
+          right_mat.shuffle(perm);
+      right_tr_map.reset(new TTypes<float>::ConstMatrix(
+          right_tr_mat->data(), right_tr_mat->dimensions()));
+    }
+    TTypes<float>::ConstMatrix& right =
+        transpose_b_ ? *right_tr_map : right_mat;
+
+    const bool transpose_a = transpose_a_;
+
+    typedef Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
+                             Eigen::Unaligned> TensorMap;
+    typedef Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>,
+                             Eigen::Unaligned> ConstTensorMap;
+    typedef Eigen::DSizes<Eigen::DenseIndex, 1> DSizes;
+    const int Bm = 16;
+    const int Bk = 16;
+    const int Bn = 1024;
+
+    auto work_shard = [m, n, k, transpose_a, Bm, Bk, Bn, &left, &right, &out](
+        int64 start64, int64 end64) {
+      const int start = static_cast<int>(start64);
+      const int end = static_cast<int>(end64);
+      Block curr(start, std::min(start + Bm, end), 0, std::min(Bk, k), 0,
+                 std::min(Bn, n));
+      Block next(curr);
+      bool done = false;
+      for (int i = start; i < end; ++i) {
+        out.chip<0>(i).setZero();
+      }
+      while (true) {
+        done = NextBlock(Bm, Bk, Bn, start, end, k, n, curr, &next);
+
+        PrefetchBlockT1(right, curr.startk, curr.endk, curr.startn, curr.endn);
+
+        // Process current block
+        for (int i = curr.startm; i < curr.endm; ++i) {
+          PrefetchBlockNTA(left, i, i + 1, curr.startk, curr.endk);
+          PrefetchBlockNTA(out, i, i + 1, curr.startn, curr.endn);
+          DSizes out_slice_shape(curr.endn - curr.startn);
+          TensorMap out_i(&out(i, curr.startn), out_slice_shape);
+          for (int j = curr.startk; j < curr.endk; ++j) {
+            const float l = transpose_a ? left(j, i) : left(i, j);
+            if (l == 0) continue;
+            ConstTensorMap right_j(&right(j, curr.startn), out_slice_shape);
+            out_i += right_j * l;
+          }
+        }
+        if (done) break;
+        curr = next;
+      }
+    };
+    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+    Shard(worker_threads.num_threads, worker_threads.workers, m, 2 * k * n,
+          work_shard);
+  }
+
+ private:
+  bool transpose_a_;
+  bool transpose_b_;
+  bool a_is_sparse_;
+  bool b_is_sparse_;
+  TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU),
+                        SparseMatMulOp);
+
+}  // end namespace tensorflow
--
cgit v1.2.3