aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-05-24 19:12:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 19:15:01 -0700
commitb59833c3fd91511b33255369016868e4ae6cda2e (patch)
treeecbd70cfd3abb5d934f6eb4b7280a35e8589f5cf /tensorflow/stream_executor/stream.cc
parent2b99d9cbc7166efedaff9eee11744348da30fc8a (diff)
Merge changes from github.
Revert #18413. Too many internal test failures due to the name scope change caused by this change. Revert #18192. Cannot use re2::StringPiece internally. Need alternative for set call. Will pull and clean this up in a separate change. PiperOrigin-RevId: 197991247
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 2bc9b6b798..4a98cfe164 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4482,6 +4482,40 @@ Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
Stream &Stream::ThenBlasGemmBatched(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+ int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+ const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
+ float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
+ int, int, ScratchAllocator *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+ scratch_allocator);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,