diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-24 04:35:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-24 04:38:23 -0700 |
commit | 9f38ab74161a0e8dd0b35b47f23ddeda7b286af3 (patch) | |
tree | df8e6c2829241dc44ca0c5811a007aa9703ed4d0 /tensorflow/stream_executor/stream.cc | |
parent | e74b98ba6348d869fee50b95b7795885fdedecee (diff) |
Add variants of DoBlasGemmWithAlgorithm with alpha being on device.
This is in preparation of allowing XLA to fuse (A dot b) * alpha where alpha
can be on device instead of just a constant.
PiperOrigin-RevId: 194068597
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 114 |
1 files changed, 68 insertions, 46 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index f59d9a13ac..093f0c9306 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/host_buffer.h" +#include "tensorflow/stream_executor/host_or_device_scalar.h" #include "tensorflow/stream_executor/lib/stacktrace.h" #include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/stream_executor/platform.h" @@ -133,6 +134,14 @@ string ToVlogString(float f) { return port::StrCat(f); } string ToVlogString(double d) { return port::StrCat(d); } +template <typename T> +string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) { + if (memory_or_constant.is_pointer()) { + return ToVlogString(memory_or_constant.pointer()); + } + return ToVlogString(memory_or_constant.value()); +} + template <class T> string ToVlogString(port::ArraySlice<T> elements) { string str = port::StrCat( @@ -3882,22 +3891,23 @@ Stream &Stream::ThenBlasGemmWithProfiling( Stream &Stream::ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, - uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, - int lda, const DeviceMemory<Eigen::half> &b, int ldb, - const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc, - blas::ComputationType computation_type, blas::AlgorithmType algorithm, - blas::ProfileResult *output_profile_result) { + uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha, + const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, + const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c, + int ldc, blas::ComputationType computation_type, + blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), PARAM(algorithm)); - ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, - uint64, const Eigen::half &, - const DeviceMemory<Eigen::half> &, int, - const DeviceMemory<Eigen::half> &, int, - const Eigen::half &, DeviceMemory<Eigen::half> *, int, - blas::ComputationType, blas::AlgorithmType> + ThenBlasWithProfileImpl< + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + const HostOrDeviceScalar<Eigen::half> &, + const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &, + int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *, + int, blas::ComputationType, blas::AlgorithmType> impl; return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, @@ -3906,18 +3916,20 @@ Stream &Stream::ThenBlasGemmWithAlgorithm( Stream &Stream::ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, - uint64 k, int alpha, const DeviceMemory<int8> &a, int lda, - const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c, - int ldc, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { + uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a, + int lda, const DeviceMemory<int8> &b, int ldb, + const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc, + blas::ComputationType computation_type, blas::AlgorithmType algorithm, + blas::ProfileResult *output_profile_result) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), PARAM(algorithm)); ThenBlasWithProfileImpl< - blas::Transpose, blas::Transpose, uint64, uint64, uint64, int, - const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int, + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int, + const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &, DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType> impl; return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, @@ -3927,8 +3939,9 @@ Stream &Stream::ThenBlasGemmWithAlgorithm( Stream &Stream::ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, - uint64 k, float alpha, const DeviceMemory<float> &a, int lda, - const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, + uint64 k, const HostOrDeviceScalar<float> &alpha, + const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, + int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), @@ -3937,8 +3950,9 @@ Stream &Stream::ThenBlasGemmWithAlgorithm( PARAM(algorithm)); ThenBlasWithProfileImpl< - blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, - const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float, + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int, + const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &, DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType> impl; return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, @@ -3948,32 +3962,35 @@ Stream &Stream::ThenBlasGemmWithAlgorithm( Stream &Stream::ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, - uint64 k, double alpha, const DeviceMemory<double> &a, int lda, - const DeviceMemory<double> &b, int ldb, double beta, - DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type, + uint64 k, const HostOrDeviceScalar<double> &alpha, + const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, + int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c, + int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), PARAM(algorithm)); - ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, - uint64, double, const DeviceMemory<double> &, int, - const DeviceMemory<double> &, int, double, - DeviceMemory<double> *, int, blas::ComputationType, - blas::AlgorithmType> + ThenBlasWithProfileImpl< + blas::Transpose, blas::Transpose, uint64, uint64, uint64, + const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int, + const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &, + DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType> impl; return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, - m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, + m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb, + HostOrDeviceScalar<double>(beta), c, ldc, computation_type, algorithm, output_profile_result); } Stream &Stream::ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, - uint64 k, std::complex<float> alpha, + uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha, const DeviceMemory<std::complex<float>> &a, int lda, const DeviceMemory<std::complex<float>> &b, int ldb, - std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, + const HostOrDeviceScalar<std::complex<float>> &beta, + DeviceMemory<std::complex<float>> *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), @@ -3981,12 +3998,14 @@ Stream &Stream::ThenBlasGemmWithAlgorithm( PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), PARAM(algorithm)); - ThenBlasWithProfileImpl< - blas::Transpose, blas::Transpose, uint64, uint64, uint64, - std::complex<float>, const DeviceMemory<std::complex<float>> &, int, - const DeviceMemory<std::complex<float>> &, int, std::complex<float>, - DeviceMemory<std::complex<float>> *, int, blas::ComputationType, - blas::AlgorithmType> + ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, + uint64, + const HostOrDeviceScalar<std::complex<float>> &, + const DeviceMemory<std::complex<float>> &, int, + const DeviceMemory<std::complex<float>> &, int, + const HostOrDeviceScalar<std::complex<float>> &, + DeviceMemory<std::complex<float>> *, int, + blas::ComputationType, blas::AlgorithmType> impl; return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, @@ -3995,10 +4014,11 @@ Stream &Stream::ThenBlasGemmWithAlgorithm( Stream &Stream::ThenBlasGemmWithAlgorithm( blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, - uint64 k, std::complex<double> alpha, + uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha, const DeviceMemory<std::complex<double>> &a, int lda, const DeviceMemory<std::complex<double>> &b, int ldb, - std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, + const HostOrDeviceScalar<std::complex<double>> &beta, + DeviceMemory<std::complex<double>> *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), @@ -4006,12 +4026,14 @@ Stream &Stream::ThenBlasGemmWithAlgorithm( PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), PARAM(algorithm)); - ThenBlasWithProfileImpl< - blas::Transpose, blas::Transpose, uint64, uint64, uint64, - std::complex<double>, const DeviceMemory<std::complex<double>> &, int, - const DeviceMemory<std::complex<double>> &, int, std::complex<double>, - DeviceMemory<std::complex<double>> *, int, blas::ComputationType, - blas::AlgorithmType> + ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, + uint64, + const HostOrDeviceScalar<std::complex<double>> &, + const DeviceMemory<std::complex<double>> &, int, + const DeviceMemory<std::complex<double>> &, int, + const HostOrDeviceScalar<std::complex<double>> &, + DeviceMemory<std::complex<double>> *, int, + blas::ComputationType, blas::AlgorithmType> impl; return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, |