author    Yangzihao Wang <yangzihao@google.com>  2017-07-21 09:22:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-07-21 09:38:31 -0700
commit    3e3306ef0009b5b21050139f9b8e5f4868c4c0c7 (patch)
tree      c7e25f278d93e9ce1ab9e2984df7b97c0f27c6d0 /tensorflow/core/kernels/matmul_op.cc
parent    4729180d24af3126d736a7045c43fcbf031b5bef (diff)
Let GetBlasGemmAlgorithms() always return true.
PiperOrigin-RevId: 162748507
Diffstat (limited to 'tensorflow/core/kernels/matmul_op.cc')
-rw-r--r--  tensorflow/core/kernels/matmul_op.cc  312
1 file changed, 248 insertions(+), 64 deletions(-)
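
The patch below adds cuBLAS GEMM autotuning to MatMul: the first time a given
(transpose_a, transpose_b, m, n, k, dtype, device) combination is seen, every
algorithm reported by GetBlasGemmAlgorithms() (plus the plain Gemm/Gemv paths)
is profiled once and the fastest result is cached; later calls reuse the cached
choice. The following is only a minimal, stand-alone C++ sketch of that caching
pattern, not the code in the patch; MatmulKey, ProfileGemm and PickAlgorithm
are hypothetical stand-ins for MatmulParameters, ThenBlasGemmWithAlgorithm /
ProfileResult, and the AutoTuneSingleton lookup used below.

#include <cstdint>
#include <limits>
#include <map>
#include <tuple>
#include <vector>

// Cache key: the same fields as MatmulParameters in the patch.
using MatmulKey =
    std::tuple<bool, bool, uint64_t, uint64_t, uint64_t, int /*dtype*/, int /*device_id*/>;

// Stand-in for one profiled GEMM launch; returns elapsed time in ms.
// (A stub for the sketch; the real code runs the kernel on the GPU stream.)
double ProfileGemm(int64_t algorithm) { return static_cast<double>(algorithm % 7); }

int64_t PickAlgorithm(const MatmulKey& key, const std::vector<int64_t>& algorithms) {
  // Plays the role of AutoTuneSingleton (which is mutex-protected in the real code).
  static std::map<MatmulKey, int64_t> cache;
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;   // already autotuned: reuse the result

  int64_t best = -1;                          // -1 plays the role of kNoAlgorithm
  double best_ms = std::numeric_limits<double>::max();
  for (int64_t algo : algorithms) {           // one profiling pass per parameter set
    double ms = ProfileGemm(algo);
    if (ms < best_ms) { best_ms = ms; best = algo; }
  }
  cache[key] = best;                          // even "no valid algorithm" is cached
  return best;
}

int main() {
  MatmulKey key{false, false, 128, 128, 64, /*dtype=*/1, /*device_id=*/0};
  std::vector<int64_t> algorithms{0, 1, 2, 3};
  int64_t chosen = PickAlgorithm(key, algorithms);  // a second call hits the cache
  return chosen >= 0 ? 0 : 1;
}
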
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 8003f7ff67..62c5ecfe81 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -23,27 +23,15 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
-
+#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
-#if GOOGLE_CUDA
-
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
- return typed;
-}
-} // namespace
-
-#endif // GOOGLE_CUDA
-
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
@@ -123,10 +111,16 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
template <typename Device, typename T>
struct LaunchMatMulBase {
+#if GOOGLE_CUDA
+ typedef perftools::gputools::blas::AlgorithmType AlgorithmType;
+#else
+ typedef int64 AlgorithmType;
+#endif // GOOGLE_CUDA
+
static void launch(
- OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
+ OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
- Tensor* out) {
+      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
// An explicit vector-matrix multiply is much better optimized than an
// implicit one and this is a bottleneck during non-batched inference.
@@ -140,6 +134,10 @@ struct LaunchMatMulBase {
}
#endif // TENSORFLOW_USE_SYCL
}
+
+ static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
+ std::vector<int64>* algorithms,
+ bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
@@ -159,24 +157,39 @@ struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#if GOOGLE_CUDA
namespace {
+
template <typename T>
struct LaunchBlasGemv {
- static void Compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- bool trans, uint64 m, uint64 n,
- const perftools::gputools::DeviceMemory<T>& a,
- const perftools::gputools::DeviceMemory<T>& b,
- perftools::gputools::DeviceMemory<T>* c) {
+ static void Compute(
+ OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
+ uint64 m, uint64 n, const perftools::gputools::DeviceMemory<T>& a,
+ const perftools::gputools::DeviceMemory<T>& b,
+ perftools::gputools::DeviceMemory<T>* c,
+ perftools::gputools::blas::ProfileResult* output_profile) {
const auto blas_trans =
trans ? perftools::gputools::blas::Transpose::kTranspose
: perftools::gputools::blas::Transpose::kNoTranspose;
- bool blas_launch_status =
- stream
- ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
- static_cast<T>(0.0), c, 1)
- .ok();
- if (!blas_launch_status) {
- ctx->SetStatus(
- errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
+ if (output_profile == nullptr) {
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
+ static_cast<T>(0.0), c, 1)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(
+ errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
+ }
+ } else {
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
+ a, m, b, 1, static_cast<T>(0.0), c, 1,
+ output_profile)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "Blas GEMV with profiling launch failed: m=", m, ", n=", n));
+ }
}
}
@@ -188,7 +201,8 @@ void LaunchBlasGemv<Eigen::half>::Compute(
OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
const perftools::gputools::DeviceMemory<Eigen::half>& b,
- perftools::gputools::DeviceMemory<Eigen::half>* c) {
+ perftools::gputools::DeviceMemory<Eigen::half>* c,
+ perftools::gputools::blas::ProfileResult* output_profile) {
ctx->SetStatus(errors::Internal(
"Blas GEMV launch failed: GEMV is not implemented for float16."));
}
@@ -200,15 +214,55 @@ bool LaunchBlasGemv<Eigen::half>::IsSupported() {
} // namespace
+bool GetCublasAutotuneComputationType(
+ const DataType& dtype,
+ perftools::gputools::blas::ComputationType* compute_type) {
+ using perftools::gputools::blas::ComputationType;
+ bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
+ switch (dtype) {
+ case DT_HALF:
+ case DT_BFLOAT16:
+ if (use_f32_for_f16_computation) {
+ *compute_type = ComputationType::kF32;
+ } else {
+ *compute_type = ComputationType::kF16;
+ }
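+      // Returning false here means fp16/bf16 matmuls are not autotuned and
+      // fall back to the plain BlasGemm()/BlasGemv() path.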
+ return false;
+ case DT_FLOAT:
+ *compute_type = ComputationType::kF32;
+ return true;
+ case DT_DOUBLE:
+ *compute_type = ComputationType::kF64;
+ return true;
+ default:
+      // Unsupported dtype: no autotuned computation type, return false.
+ return false;
+ }
+}
+
+// A dummy type to group matmul autotune results together.
+struct MatmulAutoTuneGroup {
+ static string name() { return "Matmul"; }
+};
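+// Process-wide cache mapping MatmulParameters to the best AlgorithmConfig
+// found by autotuning, shared by all MatMul kernels in the process.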
+typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
+ perftools::gputools::blas::AlgorithmConfig>
+ AutoTuneMatmul;
+
template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
static void launch(
- OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
+ OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
- Tensor* out) {
- perftools::gputools::blas::Transpose trans[] = {
- perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose};
+ std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
+ using perftools::gputools::blas::AlgorithmConfig;
+ using perftools::gputools::blas::ComputationType;
+ using perftools::gputools::blas::ProfileResult;
+ using perftools::gputools::blas::Transpose;
+ using perftools::gputools::blas::kDefaultAlgorithm;
+ using perftools::gputools::blas::kDefaultBlasGemm;
+ using perftools::gputools::blas::kDefaultBlasGemv;
+ using perftools::gputools::blas::kNoAlgorithm;
+ Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
const uint64 m = a.dim_size(1 - dim_pair[0].first);
const uint64 k = a.dim_size(dim_pair[0].first);
const uint64 n = b.dim_size(1 - dim_pair[0].second);
@@ -220,35 +274,156 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
- auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
- auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
- auto c_ptr = AsDeviceMemory(out->template flat<T>().data());
- // Cublas does
- // C = A x B
- // where A, B and C are assumed to be in column major.
- // We want the output to be in row-major, so we can compute
- // C' = B' x A' (' stands for transpose)
- if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
- // This is a matrix*vector multiply so use GEMV to compute A * b.
- // Here we are multiplying in the natural order, so we have to flip
- // the transposition flag to compensate for the tensor being stored
- // row-major.
- LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a, transpose_a ? m : k,
- transpose_a ? k : m, a_ptr, b_ptr, &c_ptr);
- } else {
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
- b_ptr, transpose_b ? k : n, a_ptr,
- transpose_a ? m : k, 0.0f, &c_ptr, n)
- .ok();
- if (!blas_launch_status) {
- ctx->SetStatus(errors::Internal(
- "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
- a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
- "), m=", m, ", n=", n, ", k=", k));
+ auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
+ a.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
+ b.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
+ out->template flat<T>().size());
+ auto alpha = static_cast<T>(1.0);
+ auto beta = static_cast<T>(0.0);
+
+ int device_id = stream->parent()->device_ordinal();
+ DataType dtype = a.dtype();
+ MatmulParameters matmul_parameters = {
+ transpose_a, transpose_b, m, n, k, dtype, device_id,
+ };
+ AlgorithmConfig algorithm_config(kNoAlgorithm);
+
+ ComputationType computation_type;
+ bool compute_type_supported =
+ GetCublasAutotuneComputationType(dtype, &computation_type);
+ if (use_autotune && compute_type_supported && !algorithms->empty()) {
+ ProfileResult best_result;
+ // TODO(yangzihao): Unify this code with conv autotuning.
+ if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
+ &algorithm_config)) {
+ ProfileResult profile_result;
+ for (auto profile_algorithm : (*algorithms)) {
+ // Cublas does
+ // C = A x B
+ // where A, B and C are assumed to be in column major.
+ // We want the output to be in row-major, so we can compute
+ // C' = B' x A' (' stands for transpose)
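+            // (A row-major matrix is the column-major view of its transpose,
+            // and (A x B)' = B' x A', hence the swapped operand order below.)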
+ bool cublas_launch_status =
+ stream
+ ->ThenBlasGemmWithAlgorithm(
+ blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
+ transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
+ &c_ptr, n, computation_type, profile_algorithm,
+ &profile_result)
+ .ok();
+ if (cublas_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ }
+ // Try BlasGemmWithProfiling
+ bool cublas_launch_status =
+ stream
+ ->ThenBlasGemmWithProfiling(
+ blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
+ transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
+ &c_ptr, n, &profile_result)
+ .ok();
+ if (cublas_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ // Try BlasGemvWithProfiling
+ if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
+ LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
+ transpose_a ? m : k, transpose_a ? k : m,
+ a_ptr, b_ptr, &c_ptr, &profile_result);
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ }
+      // We make sure that each matmul parameter set gets only one autotune
+      // pass. If a best result is found, assign it to algorithm_config and
+      // insert it into the autotune map. If all internal kernels of
+      // cublasGemmEx() return invalid results, we add kNoAlgorithm to the
+      // autotune map.
+ if (best_result.is_valid()) {
+ algorithm_config.set_algorithm(best_result.algorithm());
+ }
+ AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
+ algorithm_config);
+ if (algorithm_config.algorithm() != kNoAlgorithm &&
+ algorithm_config.algorithm() != kDefaultBlasGemm &&
+ algorithm_config.algorithm() != kDefaultBlasGemv) {
+ bool cublas_launch_status =
+ stream
+ ->ThenBlasGemmWithAlgorithm(
+ blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
+ transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
+ &c_ptr, n, computation_type, algorithm_config.algorithm(),
+ nullptr)
+ .ok();
+ if (!cublas_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "Blas GEMM with algorithm launch failed : a.shape=(",
+ a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
+ ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
+ }
}
}
+    // We fall back to the normal BlasGemm()/BlasGemv() path if any of the
+    // following holds:
+    // 1) autotune is disabled (use_autotune is false);
+    // 2) the compute type does not support autotune;
+    // 3) no algorithm was found;
+    // 4) all internal kernels tried during autotune returned invalid results.
+ if (!use_autotune || !compute_type_supported || algorithms->empty() ||
+ algorithm_config.algorithm() == kNoAlgorithm ||
+ algorithm_config.algorithm() == kDefaultBlasGemm ||
+ algorithm_config.algorithm() == kDefaultBlasGemv) {
+ if (algorithm_config.algorithm() == kDefaultBlasGemv) {
+ // This is a matrix*vector multiply so use GEMV to compute A * b.
+ // Here we are multiplying in the natural order, so we have to flip
+ // the transposition flag to compensate for the tensor being stored
+ // row-major.
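+        // E.g. with transpose_a == false, the row-major [m, k] buffer is the
+        // column-major view of A', so GEMV with the flag flipped computes A * b.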
+ // TODO(yangzihao): Add Gemv as an autotuning option too.
+ LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
+ transpose_a ? m : k, transpose_a ? k : m,
+ a_ptr, b_ptr, &c_ptr, nullptr);
+ } else {
+ // Use C' = B' x A' (' stands for transpose)
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+ 1.0f, b_ptr, transpose_b ? k : n, a_ptr,
+ transpose_a ? m : k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
+ a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
+ "), m=", m, ", n=", n, ", k=", k));
+ }
+ }
+ }
+ }
+
+ static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
+ std::vector<int64>* algorithms,
+ bool* algorithm_set_flag) {
+    if (!*algorithm_set_flag) {
+ auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
+ stream->parent()->GetBlasGemmAlgorithms(algorithms);
+ *algorithm_set_flag = true;
+ }
}
};
@@ -257,9 +432,14 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
public:
- explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ explicit MatMulOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), algorithms_set_already_(false) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
+
+ LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
+ ctx, &algorithms_, &algorithms_set_already_);
+ use_autotune_ = MatmulAutotuneEnable();
}
void Compute(OpKernelContext* ctx) override {
@@ -302,10 +482,14 @@ class MatMulOp : public OpKernel {
return;
}
- LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
+ LaunchMatMul<Device, T, USE_CUBLAS>::launch(
+ ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
}
private:
+ std::vector<int64> algorithms_;
+ bool algorithms_set_already_;
+ bool use_autotune_;
bool transpose_a_;
bool transpose_b_;
};