about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
author: Justin Lebar <jlebar@google.com> 2017-03-02 17:49:45 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-03-02 18:08:01 -0800
commit 01194694948eb883e99af597d9dbbf3fc9f5c9e2 (patch)
tree ab3517cf656259681283a90c6682c5b320ac36e3 /tensorflow/stream_executor/stream.cc
parent e065b3093f4fec5a5f79ad9de81f6baab361962e (diff)
[XLA] [StreamExecutor] Tune GEMMs when possible.
cublas 8 adds the cublasGemmEx function, which lets you specify an explicit "algorithm" for the computation. This functions as an opaque tuning hint to cublas. This patch adds support for cublasGemmEx to StreamExecutor, and wires up XLA's GemmThunk to use the new function. This patch does not add GEMM autotuning support in TensorFlow proper, only XLA. Change: 149068961
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- tensorflow/stream_executor/stream.cc | 163
1 file changed, 158 insertions(+), 5 deletions(-)
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 9b67531689..f28f965f2c 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -75,6 +75,10 @@ string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
string ToVlogString(blas::Side s) { return blas::SideString(s); }
+string ToVlogString(blas::ComputationType ty) {
+ return blas::ComputationTypeString(ty);
+}
+
string ToVlogString(const void *ptr) {
if (ptr == nullptr) {
return "null";
@@ -109,6 +113,8 @@ string ToVlogString(const DeviceMemoryBase *memory) {
return ToVlogString(*memory);
}
+string ToVlogString(const Eigen::half &h) { return port::StrCat(h); }
+
string ToVlogString(int i) { return port::StrCat(i); }
string ToVlogString(uint32 i) { return port::StrCat(i); }
@@ -1520,21 +1526,33 @@ struct ThenBlasImpl {
// arguments except the first one of Stream* type.
Stream &operator()(Stream *stream,
bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
- Args... args);
+ Args... args) {
+ return Run(stream, blas_func, /*record_error=*/true, args...);
+ }
+
+ // Like operator(), but only calls stream->CheckError() if record_error is
+ // true.
+ Stream &Run(Stream *stream,
+ bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
+ bool record_error, Args... args);
};
template <typename... Args>
-Stream &ThenBlasImpl<Args...>::operator()(
+Stream &ThenBlasImpl<Args...>::Run(
Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
- Args... args) {
+ bool record_error, Args... args) {
if (stream->ok()) {
+ bool ok;
if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
- stream->CheckError((blas->*blas_func)(stream, args...));
+ ok = (blas->*blas_func)(stream, args...);
} else {
- stream->CheckError(false);
LOG(WARNING)
<< "attempting to perform BLAS operation using StreamExecutor "
"without BLAS support";
+ ok = false;
+ }
+ if (record_error) {
+ stream->CheckError(ok);
}
}
return *stream;
@@ -3215,6 +3233,141 @@ Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
alpha, a, lda, b, ldb, beta, c, ldc);
}
+namespace {
+// Like ThenBlasImpl, except this expects the last argument of blas_func to be a
+// blas::ProfileResult*. This functor doesn't put the stream into an error
+// state if the op fails and the profile result is non-null. Instead, the
+// error-ness is returned in the profile result itself.
+template <typename... Args>
+struct ThenBlasWithProfileImpl {
+ Stream &operator()(Stream *stream,
+ bool (blas::BlasSupport::*blas_func)(
+ Stream *, Args..., blas::ProfileResult *),
+ Args... args, blas::ProfileResult *profile_result) {
+ ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
+ bool record_error = profile_result == nullptr;
+ return Runner.Run(stream, blas_func, record_error, args..., profile_result);
+ }
+};
+} // anonymous namespace
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, const DeviceMemory<Eigen::half> &b, int ldb,
+ const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64, const Eigen::half &,
+ const DeviceMemory<Eigen::half> &, int,
+ const DeviceMemory<Eigen::half> &, int,
+ const Eigen::half &, DeviceMemory<Eigen::half> *, int,
+ blas::ComputationType, blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float,
+ DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64, double, const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, double,
+ DeviceMemory<double> *, int, blas::ComputationType,
+ blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
+ DeviceMemory<std::complex<float>> *, int, blas::ComputationType,
+ blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int, blas::ComputationType,
+ blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
uint64 n, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a,