aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-24 04:35:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-24 04:38:23 -0700
commit9f38ab74161a0e8dd0b35b47f23ddeda7b286af3 (patch)
treedf8e6c2829241dc44ca0c5811a007aa9703ed4d0 /tensorflow/stream_executor/stream.cc
parente74b98ba6348d869fee50b95b7795885fdedecee (diff)
Add variants of DoBlasGemmWithAlgorithm with alpha being on device.
This is in preparation for allowing XLA to fuse (A dot b) * alpha where alpha can be on device instead of just a constant. PiperOrigin-RevId: 194068597
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc114
1 files changed, 68 insertions, 46 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index f59d9a13ac..093f0c9306 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_buffer.h"
+#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/platform.h"
@@ -133,6 +134,14 @@ string ToVlogString(float f) { return port::StrCat(f); }
string ToVlogString(double d) { return port::StrCat(d); }
+template <typename T>
+string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
+ if (memory_or_constant.is_pointer()) {
+ return ToVlogString(memory_or_constant.pointer());
+ }
+ return ToVlogString(memory_or_constant.value());
+}
+
template <class T>
string ToVlogString(port::ArraySlice<T> elements) {
string str = port::StrCat(
@@ -3882,22 +3891,23 @@ Stream &Stream::ThenBlasGemmWithProfiling(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
- int lda, const DeviceMemory<Eigen::half> &b, int ldb,
- const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
- blas::ComputationType computation_type, blas::AlgorithmType algorithm,
- blas::ProfileResult *output_profile_result) {
+ uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
- uint64, const Eigen::half &,
- const DeviceMemory<Eigen::half> &, int,
- const DeviceMemory<Eigen::half> &, int,
- const Eigen::half &, DeviceMemory<Eigen::half> *, int,
- blas::ComputationType, blas::AlgorithmType>
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<Eigen::half> &,
+ const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
+ int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
+ int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
@@ -3906,18 +3916,20 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
- const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
- int ldc, blas::ComputationType computation_type,
- blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
+ int lda, const DeviceMemory<int8> &b, int ldb,
+ const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
- const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
+ const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
@@ -3927,8 +3939,9 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
- const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ uint64 k, const HostOrDeviceScalar<float> &alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
+ int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
@@ -3937,8 +3950,9 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
PARAM(algorithm));
ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
- const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float,
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
+ const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
@@ -3948,32 +3962,35 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
- const DeviceMemory<double> &b, int ldb, double beta,
- DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ uint64 k, const HostOrDeviceScalar<double> &alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
+ int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
+ int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
- uint64, double, const DeviceMemory<double> &, int,
- const DeviceMemory<double> &, int, double,
- DeviceMemory<double> *, int, blas::ComputationType,
- blas::AlgorithmType>
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
+ DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
- m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
+ HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
algorithm, output_profile_result);
}
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, std::complex<float> alpha,
+ uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<float>> &beta,
+ DeviceMemory<std::complex<float>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
@@ -3981,12 +3998,14 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64,
- std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
- const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
- DeviceMemory<std::complex<float>> *, int, blas::ComputationType,
- blas::AlgorithmType>
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64,
+ const HostOrDeviceScalar<std::complex<float>> &,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ const HostOrDeviceScalar<std::complex<float>> &,
+ DeviceMemory<std::complex<float>> *, int,
+ blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
@@ -3995,10 +4014,11 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, std::complex<double> alpha,
+ uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<double>> &beta,
+ DeviceMemory<std::complex<double>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
@@ -4006,12 +4026,14 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64,
- std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
- const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
- DeviceMemory<std::complex<double>> *, int, blas::ComputationType,
- blas::AlgorithmType>
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64,
+ const HostOrDeviceScalar<std::complex<double>> &,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ const HostOrDeviceScalar<std::complex<double>> &,
+ DeviceMemory<std::complex<double>> *, int,
+ blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,