author    Yangzihao Wang <yangzihao@google.com>    2017-07-21 09:22:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-07-21 09:38:31 -0700
commit    3e3306ef0009b5b21050139f9b8e5f4868c4c0c7 (patch)
tree      c7e25f278d93e9ce1ab9e2984df7b97c0f27c6d0
parent    4729180d24af3126d736a7045c43fcbf031b5bef (diff)
Let GetBlasGemmAlgorithms() always return true.
PiperOrigin-RevId: 162748507
 tensorflow/contrib/fused_conv/BUILD               |   2
 tensorflow/core/BUILD                             |   1
 tensorflow/core/kernels/BUILD                     |  14
 tensorflow/core/kernels/conv_ops_gpu.h            | 124
 tensorflow/core/kernels/gpu_utils.h               | 165
 tensorflow/core/kernels/matmul_op.cc              | 312
 tensorflow/core/kernels/matmul_op.h               |  64
 tensorflow/core/util/matmul_autotune.cc           |  51
 tensorflow/core/util/matmul_autotune.h            |  28
 tensorflow/python/BUILD                           |  42
 tensorflow/python/kernel_tests/matmul_op_test.py  |   3
 tensorflow/python/ops/matmul_benchmark.py         | 143
 tensorflow/python/ops/matmul_benchmark_test.py    | 122
 tensorflow/stream_executor/blas.cc                |   4
 tensorflow/stream_executor/blas.h                 | 138
 tensorflow/stream_executor/cuda/cuda_blas.cc      | 177
 tensorflow/stream_executor/cuda/cuda_blas.h       |  17
 tensorflow/stream_executor/stream.cc              | 178
 tensorflow/stream_executor/stream.h               |  63
 19 files changed, 1456 insertions(+), 192 deletions(-)
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 026ee3df07..5a9eeea70e 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -68,6 +68,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:bounds_check_lib",
"//tensorflow/core/kernels:conv_2d_hdrs",
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
+ "//tensorflow/core/kernels:gpu_util_hdrs",
"//tensorflow/core/kernels:ops_util_hdrs",
"//third_party/eigen3",
],
@@ -86,6 +87,7 @@ tf_custom_op_library(
"//tensorflow/core/kernels:bounds_check_lib",
"//tensorflow/core/kernels:conv_2d_hdrs",
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
+ "//tensorflow/core/kernels:gpu_util_hdrs",
"//tensorflow/core/kernels:ops_util_hdrs",
],
)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5885d4ed52..eea59e8e0e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -407,6 +407,7 @@ tf_cuda_library(
"util/tensor_slice_reader_cache.h",
"util/tensor_slice_writer.h",
"util/use_cudnn.h",
+ "util/matmul_autotune.h",
"util/util.h",
"util/work_sharder.h",
] + select({
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 49aa3a11ba..e676d5a367 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -156,6 +156,7 @@ cc_library(
hdrs = ["conv_2d.h"],
deps = [
":eigen_helpers",
+ ":gpu_util_hdrs",
"//tensorflow/core:framework",
"//third_party/eigen3",
],
@@ -265,6 +266,15 @@ cc_library(
],
)
+cc_library(
+ name = "gpu_util_hdrs",
+ hdrs = ["gpu_utils.h"],
+ deps = [
+ ":eigen_helpers",
+ "//third_party/eigen3",
+ ],
+)
+
tf_cc_test(
name = "ops_util_test",
size = "small",
@@ -2424,7 +2434,9 @@ tf_kernel_library(
],
"//conditions:default": [],
}),
- deps = MATH_DEPS + select({
+ deps = MATH_DEPS + [
+ ":gpu_util_hdrs",
+ ] + select({
":xsmm": [
"@libxsmm_archive//:xsmm_avx",
],
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 34d50fdc27..83ef4f01ca 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -21,28 +21,12 @@ limitations under the License.
#include <tuple>
#include <unordered_map>
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
-namespace dnn = ::perftools::gputools::dnn;
-
-// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is wide
-// spread.
-template <typename T>
-inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
- uint64 size) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
- size * sizeof(T));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
- return typed;
-}
// Get the Cudnn workspace limit from the environment variable, which is in MB.
// Return the workspace memory limit in bytes. If no value is set, return the
@@ -185,112 +169,6 @@ class ConvParameters {
typedef Eigen::GpuDevice GPUDevice;
-// A helper class that looks up the best autotuned config from parameters.
-// Due to the noisy nature of autotune, especially with multiple devices, it
-// only accepts a config if its margin exceeds a threshold.
-// For the same shape configs, if a new best config matches the previous best,
-// they get promoted; otherwise, the winner gets demoted. This process stops
-// when the winner's score exceeds the threshold.
-// In a bad case when two configs are very close to each other and flips
-// back and forth randomly, the expected number of experiments before autotune
-// settles is O(threshold ^ 2). So we recommend that number of warmup runs
-// for any benchmarks.
-template <typename Parameters, typename Config>
-class AutoTuneMap {
- public:
- bool Find(const Parameters& params, Config* config) const {
- mutex_lock lock(mu_);
- auto iter = params_config_map_.find(params);
- if (iter == params_config_map_.end() ||
- iter->second.score < min_score_threshold_) {
- return false;
- }
- *config = iter->second.config;
- return true;
- }
- void Insert(const ConvParameters& params, const Config& config) {
- mutex_lock lock(mu_);
- auto iter = params_config_map_.find(params);
- int new_score = 0;
- if (iter == params_config_map_.end()) {
- // Create a new entry if params is new.
- VLOG(1) << GetActionSummary("creates", params, config);
- params_config_map_.insert(std::make_pair(params, ValueType{config, 1}));
- new_score = 1;
- } else if (iter->second.score < min_score_threshold_) {
- DCHECK_GT(iter->second.score, 0);
- if (iter->second.config != config) {
- // If it is different from the current winner, demotes the winner.
- VLOG(1) << GetActionSummary("demotes", params, config);
- new_score = --iter->second.score;
- if (new_score <= 0) {
- VLOG(1) << GetActionSummary("erases", params, config);
- params_config_map_.erase(iter);
- }
- } else {
- // If it is the same as the current winner, promotes the winner.
- VLOG(1) << GetActionSummary("promotes", params, config);
- new_score = ++iter->second.score;
- }
- }
- if (new_score >= min_score_threshold_) {
- VLOG(1) << GetActionSummary("accepts", params, config);
- }
- }
-
- private:
- AutoTuneMap(const string& name) : name_(name) {
- min_score_threshold_ = 1;
- const char* threshold_str = getenv("TF_AUTOTUNE_THRESHOLD");
- if (threshold_str != nullptr) {
- strings::safe_strto32(threshold_str, &min_score_threshold_);
- }
- min_score_threshold_ = std::max(min_score_threshold_, 1);
- }
-
- template <class Group, class Params, class Cfg>
- friend class AutoTuneSingleton;
-
- struct Hasher {
- std::size_t operator()(const Parameters& parameter) const {
- return parameter.hash();
- }
- };
-
- string GetActionSummary(StringPiece action, const Parameters& params,
- const Config& config) {
- return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
- action.ToString().c_str(), params.ToString().c_str(),
- config.ToString().c_str());
- }
-
- mutable mutex mu_;
- struct ValueType {
- Config config;
- int32 score;
- };
- std::unordered_map<Parameters, ValueType, Hasher> params_config_map_
- GUARDED_BY(mu_);
- string name_;
- int32 min_score_threshold_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(AutoTuneMap);
-};
-
-// A Singleton helper that manages the global autotune results by groups.
-// The caller specified arbitrary Group type that can distinguish between
-// different autotune results, even if their Parameters and Configs are the
-// same.
-template <class Group, typename Parameters, typename Config>
-class AutoTuneSingleton {
- public:
- typedef AutoTuneMap<Parameters, Config> AutoTuneType;
- static AutoTuneType* GetInstance() {
- static AutoTuneType* instance = new AutoTuneType(Group::name());
- return instance;
- }
-};
-
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
new file mode 100644
index 0000000000..366877bcf5
--- /dev/null
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -0,0 +1,165 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
+
+#if GOOGLE_CUDA
+
+#include <unordered_map>
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+
+namespace tensorflow {
+
+template <typename T>
+inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+ uint64 size) {
+ perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
+ size * sizeof(T));
+ perftools::gputools::DeviceMemory<T> typed(wrapped);
+ return typed;
+}
+
+// A helper class that looks up the best autotuned config from parameters.
+// Due to the noisy nature of autotune, especially with multiple devices, it
+// only accepts a config if its margin exceeds a threshold.
+// For configs of the same shape, if a new best config matches the previous
+// best, the winner gets promoted; otherwise it gets demoted. This process
+// stops once the winner's score exceeds the threshold.
+// In the bad case where two configs are very close and flip back and forth
+// randomly, the expected number of experiments before autotune settles is
+// O(threshold^2), so we recommend at least that many warmup runs for any
+// benchmark.
+template <typename Parameters, typename Config>
+class AutoTuneMap {
+ public:
+ bool Find(const Parameters& params, Config* config) const {
+ mutex_lock lock(mu_);
+ auto iter = params_config_map_.find(params);
+ if (iter == params_config_map_.end() ||
+ (iter->second.score < min_score_threshold_ &&
+ iter->second.count <= max_autotune_count_)) {
+ return false;
+ }
+ *config = iter->second.config;
+ return true;
+ }
+ void Insert(const Parameters& params, const Config& config) {
+ mutex_lock lock(mu_);
+ auto iter = params_config_map_.find(params);
+ int new_score = 0;
+ if (iter == params_config_map_.end()) {
+ // Create a new entry if params is new.
+ VLOG(1) << GetActionSummary("creates", params, config);
+ params_config_map_.insert(
+ std::make_pair(params, ValueType{config, 1, 1}));
+ new_score = 1;
+ } else if (iter->second.score < min_score_threshold_ &&
+ iter->second.count <= max_autotune_count_) {
+ DCHECK_GT(iter->second.score, 0);
+ if (iter->second.config != config) {
+ // If it is different from the current winner, demotes the winner.
+ VLOG(1) << GetActionSummary("demotes", params, config);
+ new_score = --iter->second.score;
+ ++iter->second.count;
+ if (new_score <= 0) {
+ VLOG(1) << GetActionSummary("erases", params, config);
+ params_config_map_.erase(iter);
+ }
+ } else {
+ // If it is the same as the current winner, promotes the winner.
+ VLOG(1) << GetActionSummary("promotes", params, config);
+ new_score = ++iter->second.score;
+ ++iter->second.count;
+ }
+ }
+ if (new_score >= min_score_threshold_) {
+ VLOG(1) << GetActionSummary("accepts", params, config);
+ }
+ }
+
+ private:
+ AutoTuneMap(const string& name) : name_(name) {
+ min_score_threshold_ = 1;
+ int min_warmup_iterations = 10;
+ const char* threshold_str = getenv("TF_AUTOTUNE_THRESHOLD");
+ if (threshold_str != nullptr) {
+ strings::safe_strto32(threshold_str, &min_score_threshold_);
+ }
+ const char* min_warmup_iteration_str =
+ getenv("TF_AUTOTUNE_MIN_WARMUP_ITERATIONS");
+ if (min_warmup_iteration_str != nullptr) {
+ strings::safe_strto32(min_warmup_iteration_str, &min_warmup_iterations);
+ }
+ min_score_threshold_ = std::max(min_score_threshold_, 1);
+ max_autotune_count_ = std::max(
+ 5 * min_score_threshold_ * min_score_threshold_, min_warmup_iterations);
+ }
+
+ template <class Group, class Params, class Cfg>
+ friend class AutoTuneSingleton;
+
+ struct Hasher {
+ std::size_t operator()(const Parameters& parameter) const {
+ return parameter.hash();
+ }
+ };
+
+ string GetActionSummary(StringPiece action, const Parameters& params,
+ const Config& config) {
+ return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
+ action.ToString().c_str(), params.ToString().c_str(),
+ config.ToString().c_str());
+ }
+
+ mutable mutex mu_;
+ struct ValueType {
+ Config config;
+ int32 score;
+ int32 count;
+ };
+ std::unordered_map<Parameters, ValueType, Hasher> params_config_map_
+ GUARDED_BY(mu_);
+ string name_;
+ int32 min_score_threshold_;
+ int32 max_autotune_count_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(AutoTuneMap);
+};
+
+// A Singleton helper that manages the global autotune results by groups.
+// The caller specified arbitrary Group type that can distinguish between
+// different autotune results, even if their Parameters and Configs are the
+// same.
+template <class Group, typename Parameters, typename Config>
+class AutoTuneSingleton {
+ public:
+ typedef AutoTuneMap<Parameters, Config> AutoTuneType;
+ static AutoTuneType* GetInstance() {
+ static AutoTuneType* instance = new AutoTuneType(Group::name());
+ return instance;
+ }
+};
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
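
To make the promote/demote bookkeeping above concrete, here is a minimal Python sketch of the same scheme (illustrative names only, not TensorFlow code): an entry is served from the cache only once its score reaches the threshold or its autotune budget is exhausted, and competing configs demote each other until one settles.

    MIN_SCORE_THRESHOLD = 3   # stand-in for TF_AUTOTUNE_THRESHOLD
    MAX_AUTOTUNE_COUNT = 45   # stand-in for max(5 * threshold^2, warmup)

    cache = {}  # params -> {"config": ..., "score": int, "count": int}

    def find(params):
        entry = cache.get(params)
        if entry is None or (entry["score"] < MIN_SCORE_THRESHOLD and
                             entry["count"] <= MAX_AUTOTUNE_COUNT):
            return None  # not settled yet; caller should keep autotuning
        return entry["config"]

    def insert(params, config):
        entry = cache.get(params)
        if entry is None:
            cache[params] = {"config": config, "score": 1, "count": 1}
        elif (entry["score"] < MIN_SCORE_THRESHOLD and
              entry["count"] <= MAX_AUTOTUNE_COUNT):
            entry["count"] += 1
            if entry["config"] != config:
                entry["score"] -= 1           # demote a contested winner
                if entry["score"] <= 0:
                    del cache[params]         # erase and start over
            else:
                entry["score"] += 1           # promote a repeated winner

    # Three consistent wins settle the entry, after which find() serves it.
    for _ in range(3):
        insert("conv_shape_A", "algo_7")
    assert find("conv_shape_A") == "algo_7"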
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 8003f7ff67..62c5ecfe81 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -23,27 +23,15 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
-
+#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
-#if GOOGLE_CUDA
-
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
- return typed;
-}
-} // namespace
-
-#endif // GOOGLE_CUDA
-
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
@@ -123,10 +111,16 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
template <typename Device, typename T>
struct LaunchMatMulBase {
+#if GOOGLE_CUDA
+ typedef perftools::gputools::blas::AlgorithmType AlgorithmType;
+#else
+ typedef int64 AlgorithmType;
+#endif // GOOGLE_CUDA
+
static void launch(
- OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
+ OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
- Tensor* out) {
+ std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
// An explicit vector-matrix multiply is much better optimized than an
// implicit one and this is a bottleneck during non-batched inference.
@@ -140,6 +134,10 @@ struct LaunchMatMulBase {
}
#endif // TENSORFLOW_USE_SYCL
}
+
+ static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
+ std::vector<int64>* algorithms,
+ bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
@@ -159,24 +157,39 @@ struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#if GOOGLE_CUDA
namespace {
+
template <typename T>
struct LaunchBlasGemv {
- static void Compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
- bool trans, uint64 m, uint64 n,
- const perftools::gputools::DeviceMemory<T>& a,
- const perftools::gputools::DeviceMemory<T>& b,
- perftools::gputools::DeviceMemory<T>* c) {
+ static void Compute(
+ OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
+ uint64 m, uint64 n, const perftools::gputools::DeviceMemory<T>& a,
+ const perftools::gputools::DeviceMemory<T>& b,
+ perftools::gputools::DeviceMemory<T>* c,
+ perftools::gputools::blas::ProfileResult* output_profile) {
const auto blas_trans =
trans ? perftools::gputools::blas::Transpose::kTranspose
: perftools::gputools::blas::Transpose::kNoTranspose;
- bool blas_launch_status =
- stream
- ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
- static_cast<T>(0.0), c, 1)
- .ok();
- if (!blas_launch_status) {
- ctx->SetStatus(
- errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
+ if (output_profile == nullptr) {
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
+ static_cast<T>(0.0), c, 1)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(
+ errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
+ }
+ } else {
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
+ a, m, b, 1, static_cast<T>(0.0), c, 1,
+ output_profile)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "Blas GEMV with profiling launch failed: m=", m, ", n=", n));
+ }
}
}
@@ -188,7 +201,8 @@ void LaunchBlasGemv<Eigen::half>::Compute(
OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
const perftools::gputools::DeviceMemory<Eigen::half>& b,
- perftools::gputools::DeviceMemory<Eigen::half>* c) {
+ perftools::gputools::DeviceMemory<Eigen::half>* c,
+ perftools::gputools::blas::ProfileResult* output_profile) {
ctx->SetStatus(errors::Internal(
"Blas GEMV launch failed: GEMV is not implemented for float16."));
}
@@ -200,15 +214,55 @@ bool LaunchBlasGemv<Eigen::half>::IsSupported() {
} // namespace
+bool GetCublasAutotuneComputationType(
+ const DataType& dtype,
+ perftools::gputools::blas::ComputationType* compute_type) {
+ using perftools::gputools::blas::ComputationType;
+ bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
+ switch (dtype) {
+ case DT_HALF:
+ case DT_BFLOAT16:
+ if (use_f32_for_f16_computation) {
+ *compute_type = ComputationType::kF32;
+ } else {
+ *compute_type = ComputationType::kF16;
+ }
+ return false;
+ case DT_FLOAT:
+ *compute_type = ComputationType::kF32;
+ return true;
+ case DT_DOUBLE:
+ *compute_type = ComputationType::kF64;
+ return true;
+ default:
+ // Unsupported compute_type, return false.
+ return false;
+ }
+}
+
+// A dummy type to group matmul autotune results together.
+struct MatmulAutoTuneGroup {
+ static string name() { return "Matmul"; }
+};
+typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
+ perftools::gputools::blas::AlgorithmConfig>
+ AutoTuneMatmul;
+
template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
static void launch(
- OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
+ OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
- Tensor* out) {
- perftools::gputools::blas::Transpose trans[] = {
- perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose};
+ std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
+ using perftools::gputools::blas::AlgorithmConfig;
+ using perftools::gputools::blas::ComputationType;
+ using perftools::gputools::blas::ProfileResult;
+ using perftools::gputools::blas::Transpose;
+ using perftools::gputools::blas::kDefaultAlgorithm;
+ using perftools::gputools::blas::kDefaultBlasGemm;
+ using perftools::gputools::blas::kDefaultBlasGemv;
+ using perftools::gputools::blas::kNoAlgorithm;
+ Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
const uint64 m = a.dim_size(1 - dim_pair[0].first);
const uint64 k = a.dim_size(dim_pair[0].first);
const uint64 n = b.dim_size(1 - dim_pair[0].second);
@@ -220,35 +274,156 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
- auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
- auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
- auto c_ptr = AsDeviceMemory(out->template flat<T>().data());
- // Cublas does
- // C = A x B
- // where A, B and C are assumed to be in column major.
- // We want the output to be in row-major, so we can compute
- // C' = B' x A' (' stands for transpose)
- if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
- // This is a matrix*vector multiply so use GEMV to compute A * b.
- // Here we are multiplying in the natural order, so we have to flip
- // the transposition flag to compensate for the tensor being stored
- // row-major.
- LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a, transpose_a ? m : k,
- transpose_a ? k : m, a_ptr, b_ptr, &c_ptr);
- } else {
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
- b_ptr, transpose_b ? k : n, a_ptr,
- transpose_a ? m : k, 0.0f, &c_ptr, n)
- .ok();
- if (!blas_launch_status) {
- ctx->SetStatus(errors::Internal(
- "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
- a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
- "), m=", m, ", n=", n, ", k=", k));
+ auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
+ a.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
+ b.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
+ out->template flat<T>().size());
+ auto alpha = static_cast<T>(1.0);
+ auto beta = static_cast<T>(0.0);
+
+ int device_id = stream->parent()->device_ordinal();
+ DataType dtype = a.dtype();
+ MatmulParameters matmul_parameters = {
+ transpose_a, transpose_b, m, n, k, dtype, device_id,
+ };
+ AlgorithmConfig algorithm_config(kNoAlgorithm);
+
+ ComputationType computation_type;
+ bool compute_type_supported =
+ GetCublasAutotuneComputationType(dtype, &computation_type);
+ if (use_autotune && compute_type_supported && !algorithms->empty()) {
+ ProfileResult best_result;
+ // TODO(yangzihao): Unify this code with conv autotuning.
+ if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
+ &algorithm_config)) {
+ ProfileResult profile_result;
+ for (auto profile_algorithm : (*algorithms)) {
+ // Cublas does
+ // C = A x B
+ // where A, B and C are assumed to be in column major.
+ // We want the output to be in row-major, so we can compute
+ // C' = B' x A' (' stands for transpose)
+ bool cublas_launch_status =
+ stream
+ ->ThenBlasGemmWithAlgorithm(
+ blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
+ transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
+ &c_ptr, n, computation_type, profile_algorithm,
+ &profile_result)
+ .ok();
+ if (cublas_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ }
+ // Try BlasGemmWithProfiling
+ bool cublas_launch_status =
+ stream
+ ->ThenBlasGemmWithProfiling(
+ blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
+ transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
+ &c_ptr, n, &profile_result)
+ .ok();
+ if (cublas_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ // Try BlasGemvWithProfiling
+ if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
+ LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
+ transpose_a ? m : k, transpose_a ? k : m,
+ a_ptr, b_ptr, &c_ptr, &profile_result);
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ }
+ }
+ }
+ // We make sure that each matmul parameter set only gets one pass of
+ // autotune. If a best result is found, we assign it to algorithm_type
+ // and insert it into the autotune map. If all internal kernels of
+ // cublasGemmEx() return invalid results, we add kNoAlgorithm to the
+ // autotune map.
+ if (best_result.is_valid()) {
+ algorithm_config.set_algorithm(best_result.algorithm());
+ }
+ AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
+ algorithm_config);
+ if (algorithm_config.algorithm() != kNoAlgorithm &&
+ algorithm_config.algorithm() != kDefaultBlasGemm &&
+ algorithm_config.algorithm() != kDefaultBlasGemv) {
+ bool cublas_launch_status =
+ stream
+ ->ThenBlasGemmWithAlgorithm(
+ blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
+ transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
+ &c_ptr, n, computation_type, algorithm_config.algorithm(),
+ nullptr)
+ .ok();
+ if (!cublas_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "Blas GEMM with algorithm launch failed : a.shape=(",
+ a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
+ ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
+ }
}
}
+ // For the following cases, we use the normal BlasGemm():
+ // 1) We didn't set the use_autotune flag;
+ // 2) compute type does not support autotune;
+ // 3) no algorithm is found;
+ // 4) all internal kernels in autotune return invalid results.
+ if (!use_autotune || !compute_type_supported || algorithms->empty() ||
+ algorithm_config.algorithm() == kNoAlgorithm ||
+ algorithm_config.algorithm() == kDefaultBlasGemm ||
+ algorithm_config.algorithm() == kDefaultBlasGemv) {
+ if (algorithm_config.algorithm() == kDefaultBlasGemv) {
+ // This is a matrix*vector multiply so use GEMV to compute A * b.
+ // Here we are multiplying in the natural order, so we have to flip
+ // the transposition flag to compensate for the tensor being stored
+ // row-major.
+ // TODO(yangzihao): Add Gemv as an autotuning option too.
+ LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
+ transpose_a ? m : k, transpose_a ? k : m,
+ a_ptr, b_ptr, &c_ptr, nullptr);
+ } else {
+ // Use C' = B' x A' (' stands for transpose)
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
+ 1.0f, b_ptr, transpose_b ? k : n, a_ptr,
+ transpose_a ? m : k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
+ a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
+ "), m=", m, ", n=", n, ", k=", k));
+ }
+ }
+ }
+ }
+
+ static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
+ std::vector<int64>* algorithms,
+ bool* algorithm_set_flag) {
+ if (*algorithm_set_flag == false) {
+ auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
+ stream->parent()->GetBlasGemmAlgorithms(algorithms);
+ *algorithm_set_flag = true;
+ }
}
};
@@ -257,9 +432,14 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
public:
- explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ explicit MatMulOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), algorithms_set_already_(false) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
+
+ LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
+ ctx, &algorithms_, &algorithms_set_already_);
+ use_autotune_ = MatmulAutotuneEnable();
}
void Compute(OpKernelContext* ctx) override {
@@ -302,10 +482,14 @@ class MatMulOp : public OpKernel {
return;
}
- LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
+ LaunchMatMul<Device, T, USE_CUBLAS>::launch(
+ ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
}
private:
+ std::vector<int64> algorithms_;
+ bool algorithms_set_already_;
+ bool use_autotune_;
bool transpose_a_;
bool transpose_b_;
};
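
The launch() path above interleaves profiling and execution; stripped of the cuBLAS plumbing, its control flow reduces to the following Python sketch. The helper names are hypothetical: profile_fn stands in for ThenBlasGemmWithAlgorithm with a ProfileResult, and the plain path for the ThenBlasGemm/ThenBlasGemv fallback.

    NO_ALGORITHM = None  # stands in for blas::kNoAlgorithm

    def launch_matmul(params, algorithms, use_autotune, cache,
                      profile_fn, run_with_algorithm, run_plain_gemm):
        algo = cache.get(params, NO_ALGORITHM)
        if use_autotune and algorithms and params not in cache:
            best_algo, best_ms = NO_ALGORITHM, float("inf")
            for candidate in algorithms:      # one autotune pass per shape
                ms = profile_fn(candidate)    # None means an invalid result
                if ms is not None and ms < best_ms:
                    best_algo, best_ms = candidate, ms
            algo = best_algo
            cache[params] = algo              # may record NO_ALGORITHM
        if algo is not NO_ALGORITHM:
            run_with_algorithm(algo)          # GEMM with the tuned algorithm
        else:
            run_plain_gemm()                  # fallback: normal BlasGemm()

    # Example with a fake profiler where algorithm 2 is the fastest.
    cache = {}
    timings = {0: 1.9, 1: None, 2: 0.7}
    launch_matmul(("512x512xf32", 0), [0, 1, 2], True, cache,
                  timings.get,
                  lambda a: print("run with algorithm", a),
                  lambda: print("run plain gemm"))
    assert cache[("512x512xf32", 0)] == 2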
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h
index 5a8db6da19..6398da2fb9 100644
--- a/tensorflow/core/kernels/matmul_op.h
+++ b/tensorflow/core/kernels/matmul_op.h
@@ -17,7 +17,9 @@ limitations under the License.
#define TENSORFLOW_KERNELS_MATMUL_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace functor {
@@ -50,6 +52,68 @@ struct MatMulFunctor {
};
} // end namespace functor
+
+#if GOOGLE_CUDA
+// Encapsulate all the shape information that is used in matmul operations.
+class MatmulParameters {
+ public:
+ MatmulParameters(bool transa, bool transb, uint64 m, uint64 n, uint64 k,
+ DataType dtype, int device_id)
+ : transa_(transa),
+ transb_(transb),
+ m_(m),
+ n_(n),
+ k_(k),
+ dtype_(dtype),
+ device_id_(device_id) {
+ hash_code_ = transa;
+ hash_code_ = Hash64Combine(hash_code_, transb);
+ hash_code_ = Hash64Combine(hash_code_, m);
+ hash_code_ = Hash64Combine(hash_code_, n);
+ hash_code_ = Hash64Combine(hash_code_, k);
+ hash_code_ = Hash64Combine(hash_code_, dtype);
+ hash_code_ = Hash64Combine(hash_code_, device_id);
+ }
+ bool operator==(const MatmulParameters& other) const {
+ return this->get_data_as_tuple() == other.get_data_as_tuple();
+ }
+
+ bool operator!=(const MatmulParameters& other) const {
+ return !(*this == other);
+ }
+ uint64 hash() const { return hash_code_; }
+
+ string ToString() const {
+ // clang-format off
+ return strings::StrCat(
+ transa_, ", ", transb_, ", ",
+        m_, ", ", n_, ", ", k_, ", ",
+        dtype_, ", ", device_id_);
+ // clang-format on
+ }
+
+ private:
+ typedef std::tuple<bool, bool, int64, int64, int64, DataType, int>
+ ParameterDataType;
+
+ ParameterDataType get_data_as_tuple() const {
+ return std::make_tuple(transa_, transb_, m_, n_, k_, dtype_, device_id_);
+ }
+
+ bool transa_;
+ bool transb_;
+ uint64 m_;
+ uint64 n_;
+ uint64 k_;
+ DataType dtype_;
+ int device_id_;
+ uint64 hash_code_;
+};
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#endif // GOOGLE_CUDA
+
} // end namespace tensorflow
#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_
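
MatmulParameters exists purely to serve as the key of the autotune map: operator== compares the full field tuple, while hash() folds the fields together once at construction. A small Python analogue, with an illustrative mixing step standing in for Hash64Combine:

    class MatmulParams(object):
        """Cache key: (transa, transb, m, n, k, dtype, device_id)."""

        def __init__(self, transa, transb, m, n, k, dtype, device_id):
            self._fields = (transa, transb, m, n, k, dtype, device_id)

        def __eq__(self, other):              # get_data_as_tuple() compare
            return self._fields == other._fields

        def __ne__(self, other):
            return not self == other

        def __hash__(self):                   # stand-in for Hash64Combine
            h = 0
            for part in self._fields:
                h = (h * 1099511628211 + hash(part)) & (2 ** 64 - 1)
            return h

    autotune_map = {}
    autotune_map[MatmulParams(False, True, 512, 1, 512, "f32", 0)] = "algo_7"
    # The same shape on the same device hits the same cache slot.
    assert autotune_map[MatmulParams(False, True, 512, 1, 512, "f32", 0)] == "algo_7"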
diff --git a/tensorflow/core/util/matmul_autotune.cc b/tensorflow/core/util/matmul_autotune.cc
new file mode 100644
index 0000000000..741a78a193
--- /dev/null
+++ b/tensorflow/core/util/matmul_autotune.cc
@@ -0,0 +1,51 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/matmul_autotune.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace tensorflow {
+bool MatmulAutotuneEnable() {
+ bool value;
+ Status status =
+ ReadBoolFromEnvVar("TF_MATMUL_AUTOTUNE_ENABLE", false, &value);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
+ return value;
+}
+
+bool MatmulDoFP32ComputationFP16Input() {
+ bool value;
+ // Feedback from NVIDIA: "true" 16-bit floating point computation is
+ // absent from compute capability 5.2; native fp16 computation was
+ // introduced in compute capability 5.3 and higher. So for compatibility,
+ // set this flag to true by default for now.
+ // TODO(yangzihao): In the future, we need to return three possibilities:
+ // user-set-true, user-set-false, and user-no-setting, and check
+ // compatibility at the call sites. Note that user-set-false with compute
+ // capability <= 5.2 will cause an error in the later cublasGemmEx() call.
+ Status status =
+ ReadBoolFromEnvVar("TF_FP16_MATMUL_USE_FP32_COMPUTE", true, &value);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
+ return value;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/matmul_autotune.h b/tensorflow/core/util/matmul_autotune.h
new file mode 100644
index 0000000000..5366623883
--- /dev/null
+++ b/tensorflow/core/util/matmul_autotune.h
@@ -0,0 +1,28 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The utility to check matmul autotune related flags.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
+
+namespace tensorflow {
+
+bool MatmulAutotuneEnable();
+bool MatmulDoFP32ComputationFP16Input();
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
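
Both helpers read plain environment variables, so the flags must be set before the first MatMul kernel is constructed. A minimal usage snippet (the variable names are the ones defined in matmul_autotune.cc above; setting them from Python before building the graph is the pattern the matmul_op_test.py comment below suggests):

    import os

    # Opt in to cuBLAS GEMM autotuning (off by default).
    os.environ["TF_MATMUL_AUTOTUNE_ENABLE"] = "1"
    # Keep fp32 accumulation for fp16 inputs (the default, shown explicitly).
    os.environ["TF_FP16_MATMUL_USE_FP32_COMPUTE"] = "1"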
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7882e088d0..65ca4be547 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3859,6 +3859,48 @@ cuda_py_test(
)
cuda_py_test(
+ name = "matmul_benchmark",
+ size = "medium",
+ srcs = ["ops/matmul_benchmark.py"],
+ additional_deps = [
+ ":math_ops",
+ ":random_ops",
+ ":client",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":framework_for_generated_wrappers",
+ ":framework_test_lib",
+ ":platform",
+ ":platform_benchmark",
+ ":variables",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ ],
+ main = "ops/matmul_benchmark.py",
+)
+
+cuda_py_test(
+ name = "matmul_benchmark_test",
+ size = "medium",
+ srcs = ["ops/matmul_benchmark_test.py"],
+ additional_deps = [
+ ":math_ops",
+ ":random_ops",
+ ":client",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":framework_for_generated_wrappers",
+ ":platform",
+ ":platform_benchmark",
+ ":matmul_benchmark",
+ ":variables",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ ],
+ main = "ops/matmul_benchmark_test.py",
+)
+
+cuda_py_test(
name = "session_benchmark",
srcs = ["client/session_benchmark.py"],
additional_deps = [
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 042f462357..b167278984 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -31,6 +31,9 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test as test_lib
+# TODO(yangzihao): Currently matmul autotuning is disabled by default. Use
+# os.environ["TF_MATMUL_AUTOTUNE_ENABLE"] = "1" to enable it.
+
def _AddTest(test, op_name, testcase_name, fn):
test_name = "_".join(["test", op_name, testcase_name])
diff --git a/tensorflow/python/ops/matmul_benchmark.py b/tensorflow/python/ops/matmul_benchmark.py
new file mode 100644
index 0000000000..55c575162a
--- /dev/null
+++ b/tensorflow/python/ops/matmul_benchmark.py
@@ -0,0 +1,143 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Matmul operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def build_graph(device, n, m, k, transpose_a, transpose_b, dtype):
+ """Build a graph containing a sequence of matmul operations.
+
+ Args:
+ device: String, the device to run on.
+ n: tensor A's first dimension size.
+ m: tensor A's second dimension size.
+ k: tensor B's second dimension size.
+    transpose_a: boolean, whether tensor A is transposed.
+    transpose_b: boolean, whether tensor B is transposed.
+ dtype: numpy data type of the input tensor.
+
+ Returns:
+    An op that runs the matmul, suitable for session.run().
+ """
+ with ops.device('/%s:0' % device):
+ if not transpose_a:
+ x = variables.Variable(random_ops.random_uniform([n, m], dtype=dtype))
+ else:
+ x = variables.Variable(random_ops.random_uniform([m, n], dtype=dtype))
+ if not transpose_b:
+ y = variables.Variable(random_ops.random_uniform([m, k], dtype=dtype))
+ else:
+ y = variables.Variable(random_ops.random_uniform([k, m], dtype=dtype))
+
+ z = math_ops.matmul(x, y, transpose_a=transpose_a, transpose_b=transpose_b)
+ return control_flow_ops.group(z)
+
+
+class MatmulBenchmark(test.Benchmark):
+ """Benchmark matmul!"""
+
+ def run_graph(self, device, n, m, k, transpose_a, transpose_b, num_iters,
+ dtype):
+ """Run the graph and print its execution time.
+
+ Args:
+ device: String, the device to run on.
+ n: tensor A's first dimension size.
+ m: tensor A's second dimension size.
+ k: tensor B's second dimension size.
+      transpose_a: boolean, whether tensor A is transposed.
+      transpose_b: boolean, whether tensor B is transposed.
+ num_iters: number of iterations to run the benchmark.
+ dtype: numpy data type of the input tensor.
+
+ Returns:
+ The duration of the run in seconds.
+ """
+ graph = ops.Graph()
+ with graph.as_default():
+ output = build_graph(device, n, m, k, transpose_a, transpose_b, dtype)
+ with session_lib.Session(graph=graph) as session:
+ variables.global_variables_initializer().run()
+ for _ in range(500):
+ session.run(output)
+ start_time = time.time()
+ for _ in range(num_iters):
+ session.run(output)
+ duration = (time.time() - start_time)
+ num_items = n * m * k * 2
+ throughput = num_items * num_iters / duration / 1e9
+ print('%s %s input_info:%s %d %.4fsec, %.4fGitems/s.' %
+ (device, str(dtype), str(n) + 'x' + str(m) + 'x' + str(k) + ',ta:'
+ + str(transpose_a) + '.tb:' + str(transpose_b), num_iters,
+ duration, throughput))
+
+ name_template = ('matmul_{device}_{dtype}_input_info_{inputinfo}')
+
+ self.report_benchmark(
+ name=name_template.format(
+ device=device,
+ dtype=str(dtype).replace(' ', ''),
+ inputinfo=str(n) + 'x' + str(m) + 'x' + str(k) + ',ta:' +
+ str(transpose_a) + '.tb:' + str(transpose_b)).replace(' ', ''),
+ iters=num_iters,
+ wall_time=duration)
+ return duration
+
+ def run_test_gpu(self, n, m, k, transpose_a, transpose_b, dtype, num_iters):
+ self.run_graph('gpu', n, m, k, transpose_a, transpose_b, num_iters, dtype)
+
+ def test_round(self, num_iters):
+ dtypes = [np.float32, np.float64]
+ for dtype in dtypes:
+ for n, m, (transpose_a, transpose_b) in itertools.product(
+ [512, 1024], [1, 8, 16, 128], [(False, False), (True, False),
+ (False, True)]):
+ k = n
+ self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)
+
+ for n, m, k, (transpose_a, transpose_b) in itertools.product(
+ [200], [1, 8, 20], [10000], [(False, False), (True, False), (False,
+ True)]):
+ self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)
+
+ for (n, m, k), (transpose_a, transpose_b) in itertools.product(
+ [(200, 20, 20000), (1, 10000, 200)], [(False, False), (True, False),
+ (False, True)]):
+ self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)
+
+ def benchmark_matmul(self):
+ num_iters = 200
+ for _ in range(10):
+ self.test_round(num_iters)
+
+
+if __name__ == '__main__':
+ test.main()
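
For reference, the throughput that run_graph() reports counts each multiply and add of the n x m by m x k product as one item, i.e. 2 * n * m * k items per iteration. A worked instance of that arithmetic, with numbers chosen purely for illustration:

    n, m, k = 1024, 128, 1024
    num_iters = 200
    duration = 0.05                      # seconds, a hypothetical measurement
    num_items = n * m * k * 2            # 268,435,456 items per iteration
    throughput = num_items * num_iters / duration / 1e9
    print("%.4f Gitems/s" % throughput)  # ~1073.7 Gitems/s for these numbers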
diff --git a/tensorflow/python/ops/matmul_benchmark_test.py b/tensorflow/python/ops/matmul_benchmark_test.py
new file mode 100644
index 0000000000..a7914dba78
--- /dev/null
+++ b/tensorflow/python/ops/matmul_benchmark_test.py
@@ -0,0 +1,122 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for matmul_benchmark.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import numpy as np
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import matmul_benchmark
+from tensorflow.python.platform import test as googletest
+from tensorflow.python.platform import tf_logging
+
+
+def BuildGraphTest(n, m, k, transpose_a, transpose_b, dtype):
+
+ def Test(self):
+ if not googletest.is_gpu_available():
+ tf_logging.info("Skipping BuildGraphTest %s", (n, m, k, transpose_a,
+ transpose_b))
+ return
+ tf_logging.info("Testing BuildGraphTest %s", (n, m, k, transpose_a,
+ transpose_b))
+ self._VerifyBuildGraph(n, m, k, transpose_a, transpose_b, dtype)
+
+ return Test
+
+
+def RunGraphTest(n, m, k, transpose_a, transpose_b, dtype):
+
+ def Test(self):
+ if not googletest.is_gpu_available():
+ tf_logging.info("Skipping RunGraphTest %s", (n, m, k, transpose_a,
+ transpose_b))
+ return
+ tf_logging.info("Testing RunGraphTest %s", (n, m, k, transpose_a,
+ transpose_b))
+ self._VerifyRunGraph(n, m, k, transpose_a, transpose_b, dtype)
+
+ return Test
+
+
+class MatmulBenchmarkTest(googletest.TestCase):
+
+ def _StripNode(self, nd):
+ snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
+ if nd.device:
+ snode.device = nd.device
+ return snode
+
+ def _StripGraph(self, gd):
+ return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])
+
+ def _VerifyBuildGraph(self, n, m, k, transpose_a, transpose_b, dtype):
+ graph = ops.Graph()
+ with graph.as_default():
+ matmul_benchmark.build_graph("gpu", n, m, k, transpose_a, transpose_b,
+ dtype)
+ gd = graph.as_graph_def()
+ self.assertProtoEquals("""
+ node { name: "random_uniform/shape" op: "Const" device: "/device:GPU:0" }
+ node { name: "random_uniform/min" op: "Const" device: "/device:GPU:0" }
+ node { name: "random_uniform/max" op: "Const" device: "/device:GPU:0" }
+ node { name: "random_uniform/RandomUniform" op: "RandomUniform" input: "random_uniform/shape" device: "/device:GPU:0" }
+ node { name: "random_uniform/sub" op: "Sub" input: "random_uniform/max" input: "random_uniform/min" device: "/device:GPU:0" }
+ node { name: "random_uniform/mul" op: "Mul" input: "random_uniform/RandomUniform" input: "random_uniform/sub" device: "/device:GPU:0" }
+ node { name: "random_uniform" op: "Add" input: "random_uniform/mul" input: "random_uniform/min" device: "/device:GPU:0" }
+ node { name: "Variable" op: "VariableV2" device: "/device:GPU:0" }
+ node { name: "Variable/Assign" op: "Assign" input: "Variable" input: "random_uniform" device: "/device:GPU:0" }
+ node { name: "Variable/read" op: "Identity" input: "Variable" device: "/device:GPU:0" }
+ node { name: "random_uniform_1/shape" op: "Const" device: "/device:GPU:0" }
+ node { name: "random_uniform_1/min" op: "Const" device: "/device:GPU:0" }
+ node { name: "random_uniform_1/max" op: "Const" device: "/device:GPU:0" }
+ node { name: "random_uniform_1/RandomUniform" op: "RandomUniform" input: "random_uniform_1/shape" device: "/device:GPU:0" }
+ node { name: "random_uniform_1/sub" op: "Sub" input: "random_uniform_1/max" input: "random_uniform_1/min" device: "/device:GPU:0" }
+ node { name: "random_uniform_1/mul" op: "Mul" input: "random_uniform_1/RandomUniform" input: "random_uniform_1/sub" device: "/device:GPU:0" }
+ node { name: "random_uniform_1" op: "Add" input: "random_uniform_1/mul" input: "random_uniform_1/min" device: "/device:GPU:0" }
+ node { name: "Variable_1" op: "VariableV2" device: "/device:GPU:0" }
+ node { name: "Variable_1/Assign" op: "Assign" input: "Variable_1" input: "random_uniform_1" device: "/device:GPU:0" }
+ node { name: "Variable_1/read" op: "Identity" input: "Variable_1" device: "/device:GPU:0" }
+ node { name: "MatMul" op: "MatMul" input: "Variable/read" input: "Variable_1/read" device: "/device:GPU:0" }
+ node { name: "group_deps" op: "NoOp" input: "^MatMul" device: "/device:GPU:0" }
+ """, self._StripGraph(gd))
+
+ def _VerifyRunGraph(self, n, m, k, transpose_a, transpose_b, dtype):
+ benchmark_instance = matmul_benchmark.MatmulBenchmark()
+ duration = benchmark_instance.run_graph("gpu", n, m, k, transpose_a,
+ transpose_b, 1, dtype)
+ self.assertTrue(duration > 1e-6)
+
+
+if __name__ == "__main__":
+ dtypes = [np.float32, np.float64]
+ index = 0
+ for _dtype in dtypes:
+ for _n, _m, (_transpose_a, _transpose_b) in itertools.product(
+ [512, 1024], [1, 8, 16, 128], [(False, False), (True, False), (False,
+ True)]):
+ _k = _n
+ setattr(MatmulBenchmarkTest, "testBuildGraph_" + str(index),
+ BuildGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
+ setattr(MatmulBenchmarkTest, "testRunGraph_" + str(index),
+ RunGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
+ index += 1
+ googletest.main()
diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc
index a59a1dda71..da09d84921 100644
--- a/tensorflow/stream_executor/blas.cc
+++ b/tensorflow/stream_executor/blas.cc
@@ -67,6 +67,10 @@ string SideString(Side s) {
}
}
+// -- AlgorithmConfig
+
+string AlgorithmConfig::ToString() const { return port::StrCat(algorithm_); }
+
string ComputationTypeString(ComputationType ty) {
switch (ty) {
case ComputationType::kF16:
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index cfff3649c8..eb1b19c5d9 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -44,7 +44,6 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
-#include "tensorflow/stream_executor/platform/port.h"
namespace Eigen {
struct half;
@@ -108,6 +107,10 @@ string ComputationTypeString(ComputationType ty);
// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;
+constexpr AlgorithmType kDefaultAlgorithm = -1;
+constexpr AlgorithmType kDefaultBlasGemm = -2;
+constexpr AlgorithmType kDefaultBlasGemv = -3;
+constexpr AlgorithmType kNoAlgorithm = -4;
// blas uses -1 to represent the default algorithm. This happens to match up
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
@@ -134,10 +137,28 @@ class ProfileResult {
private:
bool is_valid_ = false;
- AlgorithmType algorithm_ = 0;
+ AlgorithmType algorithm_ = kDefaultAlgorithm;
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
};
+class AlgorithmConfig {
+ public:
+ AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
+ explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
+ AlgorithmType algorithm() const { return algorithm_; }
+ void set_algorithm(AlgorithmType val) { algorithm_ = val; }
+ bool operator==(const AlgorithmConfig &other) const {
+ return this->algorithm_ == other.algorithm_;
+ }
+ bool operator!=(const AlgorithmConfig &other) const {
+ return !(*this == other);
+ }
+ string ToString() const;
+
+ private:
+ AlgorithmType algorithm_;
+};
+
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@@ -453,6 +474,29 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) = 0;
+ virtual bool DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
+ int incx, float beta, DeviceMemory<float> *y, int incy,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
+ int incx, double beta, DeviceMemory<double> *y, int incy,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
+ int lda, const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
+ int lda, const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
+ int incy, ProfileResult *output_profile_result) = 0;
+
// Performs a rank-1 update of a general matrix.
//
// a <- alpha * x * y' + a,
@@ -935,8 +979,39 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
- // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that
- // any or all of these algorithms may still be
+ virtual bool DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
+ DeviceMemory<Eigen::half> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ ProfileResult *output_profile_result) = 0;
+
+ // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
virtual bool GetBlasGemmAlgorithms(
std::vector<AlgorithmType> *out_algorithms) = 0;
@@ -1473,6 +1548,28 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &x, int incx, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \
+ int incx, float beta, DeviceMemory<float> *y, int incy, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \
+ int incx, double beta, DeviceMemory<double> *y, int incy, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \
+ int lda, const DeviceMemory<std::complex<float>> &x, int incx, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \
+ int incy, blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemvWithProfiling( \
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
+ int lda, const DeviceMemory<std::complex<double>> &x, int incx, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \
+ int incy, blas::ProfileResult *output_profile_result) override; \
bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \
const DeviceMemory<float> &x, int incx, \
const DeviceMemory<float> &y, int incy, \
@@ -1751,6 +1848,39 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &b, int ldb, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<Eigen::half> &a, int lda, \
+ const DeviceMemory<Eigen::half> &b, int ldb, float beta, \
+ DeviceMemory<Eigen::half> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+ int lda, const DeviceMemory<float> &b, int ldb, float beta, \
+ DeviceMemory<float> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
+ int ldb, double beta, DeviceMemory<double> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithProfiling( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
+ int ldc, blas::ProfileResult *output_profile_result) override; \
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
override; \
bool DoBlasGemmWithAlgorithm( \
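For orientation, a minimal sketch of how a backend consumes the macro block above; the class name is hypothetical and the snippet is illustrative rather than part of this change. Expanding the overrides macro inside a BlasSupport subclass declares every DoBlas* override at once, including the new *WithProfiling entry points, whose definitions then live in the backend's .cc file (as CUDABlas does below):

    // Hypothetical backend: the macro expansion declares all overrides,
    // including the DoBlasGemvWithProfiling / DoBlasGemmWithProfiling
    // methods added in this diff.
    class MyGpuBlas : public perftools::gputools::blas::BlasSupport {
     public:
      TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
    };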
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 2817364e97..cb2b06d47c 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -1857,6 +1857,180 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
CUDAComplex(CUDAMemoryMutable(c)), ldc);
}
+bool CUDABlas::DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
+ int incx, float beta, DeviceMemory<float> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+ incx, beta, y, incy,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
+ int incx, double beta, DeviceMemory<double> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+ incx, beta, y, incy,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
+ int lda, const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+ incx, beta, y, incy,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemvWithProfiling(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
+ int lda, const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
+ incx, beta, y, incy,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
+ DeviceMemory<Eigen::half> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+ lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+ lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+ lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+ lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+bool CUDABlas::DoBlasGemmWithProfiling(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
+ lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+template <typename T>
+bool CUDABlas::DoBlasGemvWithProfilingImpl(
+ Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
+ const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
+ const T &beta, DeviceMemory<T> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ struct TimerDeleter {
+ void operator()(CUDATimer *t) {
+ t->Destroy();
+ delete t;
+ }
+ };
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
+ if (output_profile_result != nullptr) {
+ timer.reset(new CUDATimer(parent_));
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return false;
+ }
+ }
+
+ // Call the underlying DoBlasGemv.
+ bool result =
+ DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
+
+ if (timer != nullptr && result) {
+ // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
+ // state.
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return false;
+ }
+ output_profile_result->set_is_valid(true);
+ output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
+ }
+ return result;
+}
+
+template <typename T, typename ParamType>
+bool CUDABlas::DoBlasGemmWithProfilingImpl(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
+ int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
+ DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
+ struct TimerDeleter {
+ void operator()(CUDATimer *t) {
+ t->Destroy();
+ delete t;
+ }
+ };
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
+ if (output_profile_result != nullptr) {
+ timer.reset(new CUDATimer(parent_));
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return false;
+ }
+ }
+
+ // Call the underlying DoBlasGemm.
+ bool result = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b,
+ ldb, beta, c, ldc);
+
+ if (timer != nullptr && result) {
+ // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
+ // state.
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return false;
+ }
+ output_profile_result->set_is_valid(true);
+ output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
+ }
+ return result;
+}
+
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
@@ -1920,6 +2094,9 @@ bool CUDABlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
// were first introduced in CUDA 8.
+// Note that even when the CUDA version or compute capability is insufficient,
+// we still return true; out_algorithms is simply left empty. Callers must
+// therefore check for an empty vector rather than relying on the return value.
#if CUDA_VERSION >= 8000
for (cublasGemmAlgo_t algo :
{CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
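Given the note above, a hedged caller-side sketch of the new contract (illustrative, not part of this commit; `blas` stands for any BlasSupport pointer): success is now signaled unconditionally, so an empty out_algorithms vector, not a false return, is what indicates that cublasGemmEx-style algorithm selection is unavailable.

    // Consuming GetBlasGemmAlgorithms() under the always-return-true contract.
    std::vector<perftools::gputools::blas::AlgorithmType> algorithms;
    CHECK(blas->GetBlasGemmAlgorithms(&algorithms));  // Always true now.
    if (algorithms.empty()) {
      // CUDA version or compute capability too old for cublasGemmEx:
      // fall back to the plain DoBlasGemm path instead of autotuning.
    }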
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 4a8641b300..80cda97117 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -127,6 +127,23 @@ class CUDABlas : public blas::BlasSupport {
blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
+ // Helper function for implementing DoBlasGemmWithProfiling.
+ template <typename T, typename ParamType>
+ bool DoBlasGemmWithProfilingImpl(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
+ int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
+ DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
+
+ // Helper function for implementing DoBlasGemvWithProfiling.
+ template <typename T>
+ bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
+ uint64 m, uint64 n, const T &alpha,
+ const DeviceMemory<T> &a, int lda,
+ const DeviceMemory<T> &x, int incx,
+ const T &beta, DeviceMemory<T> *y, int incy,
+ blas::ProfileResult *output_profile_result);
+
// mutex that guards the cuBLAS handle for this device.
mutex mu_;
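One note on the GEMM helper's two template parameters, with a hedged illustration (the buffers and dimensions below are hypothetical): T and ParamType are distinct because the Eigen::half overloads in this diff keep alpha and beta in float while the matrices are half precision, so template argument deduction at that call site yields T = Eigen::half with ParamType = float:

    // At the Eigen::half overload, deduction gives T = Eigen::half and
    // ParamType = float, keeping the scalars in float per the signature.
    DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k,
                                /*alpha=*/1.0f, a_half, lda, b_half, ldb,
                                /*beta=*/0.0f, &c_half, ldc,
                                output_profile_result);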
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 5996195173..c9b36ba7ab 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -3458,6 +3458,184 @@ struct ThenBlasWithProfileImpl {
};
} // anonymous namespace
+Stream &Stream::ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
+ int incx, float beta, DeviceMemory<float> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
+ const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
+ alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
+ int incx, double beta, DeviceMemory<double> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
+ const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, double,
+ DeviceMemory<double> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
+ alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>,
+ DeviceMemory<std::complex<float>> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
+ alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
+ alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb, float beta,
+ DeviceMemory<Eigen::half> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64, float, const DeviceMemory<Eigen::half> &, int,
+ const DeviceMemory<Eigen::half> &, int, float,
+ DeviceMemory<Eigen::half> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64, float, const DeviceMemory<float> &, int,
+ const DeviceMemory<float> &, int, float,
+ DeviceMemory<float> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64, double, const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, double,
+ DeviceMemory<double> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
+ DeviceMemory<std::complex<float>> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ output_profile_result);
+}
+
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
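To see how the new Stream entry points are meant to be driven, a hedged usage sketch (buffer setup elided; the stream, device buffers, and dimensions are assumed to exist): ThenBlasGemmWithProfiling mirrors ThenBlasGemm but additionally fills a blas::ProfileResult, whose validity flag and elapsed time come from the CUDATimer logic in cuda_blas.cc above.

    // Profiling a single float GEMM on an existing Stream.
    namespace se = perftools::gputools;
    se::blas::ProfileResult profile;
    stream->ThenBlasGemmWithProfiling(
        se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
        m, n, k, /*alpha=*/1.0f, a, lda, b, ldb, /*beta=*/0.0f, &c, ldc,
        &profile);
    if (profile.is_valid()) {
      VLOG(1) << "GEMM took " << profile.elapsed_time_in_ms() << " ms.";
    }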
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 3c8b7ee894..e218873839 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -934,6 +934,31 @@ class Stream {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy);
+ Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
+ float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x,
+ int incx, float beta,
+ DeviceMemory<float> *y, int incy,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
+ double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x,
+ int incx, double beta,
+ DeviceMemory<double> *y, int incy,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemvWithProfiling(
+ blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
+ int incy, blas::ProfileResult *output_profile_result);
+
// See BlasSupport::DoBlasGer.
Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
const DeviceMemory<float> &x, int incx,
@@ -1249,6 +1274,44 @@ class Stream {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc);
+ Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ float beta, DeviceMemory<Eigen::half> *c,
+ int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c,
+ int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithProfiling(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ProfileResult *output_profile_result);
+
// See BlasSupport::DoBlasGemmWithAlgorithm.
Stream &ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,