author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-17 16:19:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-10-17 16:25:58 -0700
commit    47e4d4b6b5742350233a8fd83cd81269792ed286 (patch)
tree      e13ba390de56684359e9771a98dd80690dfd1121
parent    95c7f5344f8da74a839c459c6415855bffe4f004 (diff)
Use optimized functor for conjugate transpose in MatrixInverseOp.
Introduce convenience functions DoMatrixTranspose and DoConjugateMatrixTranspose. Misc. minor cleanup of templates in transpose_functor*.

PiperOrigin-RevId: 172532252
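Below is a minimal usage sketch (not part of this commit) of how a GPU linear-algebra kernel calls the new convenience wrappers instead of building a permutation vector by hand. The helper name TransposeInput and its surrounding boilerplate are hypothetical; the DoMatrixTranspose/DoConjugateMatrixTranspose calls mirror the call sites changed in this diff, and their signatures follow the declarations added to transpose_functor.h.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/transpose_functor.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// Hypothetical helper: copies a batch of matrices from 'input' into the
// pre-allocated 'input_copy' (same shape and dtype), transposing the two
// innermost dimensions and conjugating the elements when adjoint is true.
template <class Scalar>
void TransposeInput(OpKernelContext* context, bool adjoint,
                    const Tensor& input, Tensor* input_copy,
                    AsyncOpKernel::DoneCallback done) {
  const GPUDevice& device = context->eigen_device<GPUDevice>();
  if (adjoint) {
    // Per-matrix adjoint of the flattened batch: swap the last two
    // dimensions and conjugate each element.
    OP_REQUIRES_OK_ASYNC(
        context, DoConjugateMatrixTranspose(device, input, input_copy), done);
  } else {
    // Plain batched matrix transpose: only the last two dimensions swap.
    OP_REQUIRES_OK_ASYNC(
        context, DoMatrixTranspose(device, input, input_copy), done);
  }
}

}  // namespace tensorflow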
-rw-r--r--  tensorflow/core/kernels/cuda_solvers.h                  8
-rw-r--r--  tensorflow/core/kernels/cuda_solvers_gpu.cu.cc         18
-rw-r--r--  tensorflow/core/kernels/matrix_inverse_op.cc           12
-rw-r--r--  tensorflow/core/kernels/matrix_solve_op.cc              9
-rw-r--r--  tensorflow/core/kernels/qr_op_impl.h                    9
-rw-r--r--  tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc   5
-rw-r--r--  tensorflow/core/kernels/svd_op_gpu.cu.cc               25
-rw-r--r--  tensorflow/core/kernels/transpose_functor.h           150
-rw-r--r--  tensorflow/core/kernels/transpose_functor_cpu.cc       72
-rw-r--r--  tensorflow/core/kernels/transpose_functor_gpu.cu.cc    52
10 files changed, 174 insertions(+), 186 deletions(-)
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 60c4a0bfb4..eb720b191f 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -409,14 +409,6 @@ class DeviceLapackInfo : public ScratchSpace<int> {
};
namespace functor {
-// Helper functor to transpose and conjugate all matrices in a flattened batch.
-template <typename Device, typename Scalar>
-struct AdjointBatchFunctor {
- // We assume that the tensor sizes are correct.
- void operator()(const Device& device,
- typename TTypes<Scalar, 3>::ConstTensor input,
- typename TTypes<Scalar, 3>::Tensor output);
-};
// Helper functor to compute the product of diagonal elements in all matrices
// in a flattened batch.
diff --git a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
index 79961c01ca..4171f9d68e 100644
--- a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
+++ b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
@@ -29,24 +29,6 @@ namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-// TODO(rmlarsen): Add a faster custom kernel similar to
-// SwapDimension1And2InTensor3 in tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
-template <typename Scalar>
-struct AdjointBatchFunctor<GPUDevice, Scalar> {
- void operator()(const GPUDevice& device,
- typename TTypes<Scalar, 3>::ConstTensor input,
- typename TTypes<Scalar, 3>::Tensor output) {
- const Eigen::array<int, 3> perm({0, 2, 1});
- To32Bit(output).device(device) = To32Bit(input).shuffle(perm).conjugate();
- }
-};
-
-// Instantiate implementations for the 4 numeric types.
-template struct AdjointBatchFunctor<GPUDevice, float>;
-template struct AdjointBatchFunctor<GPUDevice, double>;
-template struct AdjointBatchFunctor<GPUDevice, complex64>;
-template struct AdjointBatchFunctor<GPUDevice, complex128>;
-
namespace {
// Hacks around missing support for complex arithmetic in nvcc.
diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc
index 832e508bb7..64edfe470d 100644
--- a/tensorflow/core/kernels/matrix_inverse_op.cc
+++ b/tensorflow/core/kernels/matrix_inverse_op.cc
@@ -33,6 +33,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/kernels/transpose_functor.h"
#endif
namespace tensorflow {
@@ -135,15 +136,15 @@ class MatrixInverseOpGpu : public AsyncOpKernel {
input.shape(), &input_copy),
done);
auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
- auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
const GPUDevice& device = context->eigen_device<GPUDevice>();
if (!adjoint_) {
device.memcpy(input_copy.flat<Scalar>().data(),
input.flat<Scalar>().data(),
input.NumElements() * sizeof(Scalar));
} else {
- functor::AdjointBatchFunctor<GPUDevice, Scalar> functor;
- functor(device, input_reshaped, input_copy_reshaped);
+ OP_REQUIRES_OK_ASYNC(
+ context, DoConjugateMatrixTranspose(device, input, &input_copy),
+ done);
}
const int64 batch_size = input_copy_reshaped.dimension(0);
@@ -238,10 +239,7 @@ class MatrixInverseOpGpu : public AsyncOpKernel {
done);
}
}
- // Callback for checking info after kernels finish. Also capture the
- // temporary Tensors/ScratchSpace so they don't get deallocated before the
- // kernels run. TODO(rmlarsen): Use move capture once C++14 becomes
- // available.
+ // Callback for checking info after kernels finish.
auto info_checker = [context, done](
const Status& status,
const std::vector<HostLapackInfo>& host_infos) {
diff --git a/tensorflow/core/kernels/matrix_solve_op.cc b/tensorflow/core/kernels/matrix_solve_op.cc
index 862033e9fa..2e4098dfab 100644
--- a/tensorflow/core/kernels/matrix_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_op.cc
@@ -181,9 +181,6 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
// false, try to reuse the input buffer if this op owns it exclusively.
Tensor input_copy;
const GPUDevice& device = context->eigen_device<GPUDevice>();
- std::vector<int> perm(ndims);
- std::iota(perm.begin(), perm.end(), 0);
- std::swap(perm[ndims - 2], perm[ndims - 1]);
if (adjoint_) {
// For the adjoint case, it is simpler to always make a transposed copy up
// front.
@@ -193,7 +190,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
input.shape(), &input_copy),
done);
OP_REQUIRES_OK_ASYNC(context,
- DoTranspose(device, input, perm, &input_copy), done);
+ DoMatrixTranspose(device, input, &input_copy), done);
} else {
OP_REQUIRES_OK_ASYNC(
context,
@@ -267,7 +264,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
done);
if (nrhs > 1) {
OP_REQUIRES_OK_ASYNC(
- context, DoTranspose(device, rhs, perm, &transposed_rhs), done);
+ context, DoMatrixTranspose(device, rhs, &transposed_rhs), done);
} else {
device.memcpy(transposed_rhs.flat<Scalar>().data(),
rhs.flat<Scalar>().data(),
@@ -327,7 +324,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
// 4. Transpose X to get the final result in row-major form.
if (nrhs > 1) {
OP_REQUIRES_OK_ASYNC(
- context, DoTranspose(device, transposed_rhs, perm, output), done);
+ context, DoMatrixTranspose(device, transposed_rhs, output), done);
} else {
device.memcpy(output->flat<Scalar>().data(),
transposed_rhs.flat<Scalar>().data(),
diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h
index e263eb22f1..c51d601437 100644
--- a/tensorflow/core/kernels/qr_op_impl.h
+++ b/tensorflow/core/kernels/qr_op_impl.h
@@ -190,12 +190,9 @@ class QrOpGpu : public AsyncOpKernel {
// Transpose input, since cuSolver uses column-major, while TensorFlow uses
// row-major storage.
- std::vector<int> perm(ndims);
- std::iota(perm.begin(), perm.end(), 0);
- std::swap(perm[ndims - 2], perm[ndims - 1]);
const GPUDevice& device = context->eigen_device<GPUDevice>();
OP_REQUIRES_OK_ASYNC(
- context, DoTranspose(device, input, perm, &input_transposed), done);
+ context, DoMatrixTranspose(device, input, &input_transposed), done);
// Compute QR decomposition in-place in input_transposed.
std::vector<DeviceLapackInfo> dev_info;
@@ -218,7 +215,7 @@ class QrOpGpu : public AsyncOpKernel {
// and copy it to the output buffer.
if (full_matrices_ || m == n) {
OP_REQUIRES_OK_ASYNC(
- context, DoTranspose(device, input_transposed, perm, r), done);
+ context, DoMatrixTranspose(device, input_transposed, r), done);
} else {
const Scalar alpha(1);
const Scalar beta(0);
@@ -280,7 +277,7 @@ class QrOpGpu : public AsyncOpKernel {
done);
}
OP_REQUIRES_OK_ASYNC(
- context, DoTranspose(device, input_transposed, perm, q), done);
+ context, DoMatrixTranspose(device, input_transposed, q), done);
}
// Asynchronously check return status from cuSolver kernels.
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc b/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc
index b0b4f89a27..3a84df07a9 100644
--- a/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc
+++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc
@@ -148,11 +148,8 @@ class SelfAdjointEigV2OpGpu : public AsyncOpKernel {
if (compute_v_) {
// Transpose eigenvectors now stored in input_copy in column-major form to
// output in row-major form.
- std::vector<int> perm(ndims);
- std::iota(perm.begin(), perm.end(), 0);
- std::swap(perm[ndims - 2], perm[ndims - 1]);
OP_REQUIRES_OK_ASYNC(
- context, DoTranspose(device, input_copy, perm, eigenvectors), done);
+ context, DoMatrixTranspose(device, input_copy, eigenvectors), done);
}
// Asynchronously check return status from cuSolver kernels.
diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc
index 1603a8aeda..dedc2da60b 100644
--- a/tensorflow/core/kernels/svd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -190,8 +190,8 @@ class SvdOpGpu : public AsyncOpKernel {
// TODO: can the two cases (MgeqN and MlessN) be simplified,
// common boilerplate be reduced, or even combined in one method?
void PerformSVD_MgeqN(OpKernelContext* context, DoneCallback done, int64 m,
- int64 n, int64 p, const gtl::ArraySlice<int32>& perm,
- const Tensor& M, Tensor* S, Tensor* U, Tensor* V) {
+ int64 n, int64 p, const Tensor& M, Tensor* S, Tensor* U,
+ Tensor* V) {
TensorShape shapeRaw = M.shape();
shapeRaw.RemoveLastDims(2);
@@ -207,7 +207,7 @@ class SvdOpGpu : public AsyncOpKernel {
solver->allocate_scoped_tensor(M.dtype(), input_shape, &input_copy),
done);
auto device = context->eigen_device<GPUDevice>();
- OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, M, perm, &input_copy),
+ OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, M, &input_copy),
done);
// I need to transpose U at the end
@@ -250,7 +250,7 @@ class SvdOpGpu : public AsyncOpKernel {
// Transpose U
if (compute_uv_) {
- OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, u_copy, perm, U), done);
+ OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, u_copy, U), done);
}
// now check if the SVD operation succeeded or not
@@ -259,8 +259,8 @@ class SvdOpGpu : public AsyncOpKernel {
// The SVD if m < n
void PerformSVD_MlessN(OpKernelContext* context, DoneCallback done, int64 m,
- int64 n, int64 p, const gtl::ArraySlice<int32>& perm,
- const Tensor& M, Tensor* S, Tensor* U, Tensor* V) {
+ int64 n, int64 p, const Tensor& M, Tensor* S,
+ Tensor* U, Tensor* V) {
// Perform the SVD on M'
// Reuse the input buffer or make a copy for the SVD depending on whether
@@ -325,7 +325,7 @@ class SvdOpGpu : public AsyncOpKernel {
// Transpose V
if (compute_uv_) {
auto device = context->eigen_device<GPUDevice>();
- OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, v_copy, perm, V), done);
+ OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, v_copy, V), done);
}
// now check if the SVD operation succeeded or not
@@ -389,19 +389,12 @@ class SvdOpGpu : public AsyncOpKernel {
return;
}
- // Prepare permutation
- std::vector<int32> perm;
- for (size_t i = 0; i < ndims - 2; ++i) perm.push_back(i);
- perm.push_back(ndims - 1); // transpose last two dimensions
- perm.push_back(ndims - 2);
- gtl::ArraySlice<int32> permAS(perm);
-
// call implementations
if (m >= n) {
- PerformSVD_MgeqN(context, done, m, n, p, permAS, input, outputS, outputU,
+ PerformSVD_MgeqN(context, done, m, n, p, input, outputS, outputU,
outputV);
} else {
- PerformSVD_MlessN(context, done, m, n, p, permAS, input, outputS, outputU,
+ PerformSVD_MlessN(context, done, m, n, p, input, outputS, outputU,
outputV);
}
}
diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h
index 87569f0275..a2eb0263e8 100644
--- a/tensorflow/core/kernels/transpose_functor.h
+++ b/tensorflow/core/kernels/transpose_functor.h
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
-
// Transpose tensor 'in' into tensor 'out' according to dimension
// permutation 'perm'.
//
@@ -46,6 +45,17 @@ template <typename Device>
Status DoConjugateTranspose(const Device& device, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out);
+// Convenience versions of DoTranspose that only swap the last (inner) two
+// dimensions.
+template <typename Device>
+Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);
+
+// Convenience versions of DoConjugateTranspose that only swap the last (inner)
+// two dimensions.
+template <typename Device>
+Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
+ Tensor* out);
+
// Primary device specific functor to be specialized for each device and type.
template <typename Device, typename T, bool conjugate = false>
struct Transpose {
@@ -131,11 +141,6 @@ inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
return true;
}
-// Device-specific naive implementation for transpose.
-template <typename Device, typename T, bool conjugate>
-void TransposeSimple(const Device& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out);
-
// Uses Eigen to transpose.
template <typename Device, typename T, int NDIMS>
void TransposeUsingEigen(const Device& d, const Tensor& in,
@@ -157,69 +162,78 @@ void TransposeUsingEigen(const Device& d, const Tensor& in,
}
template <typename Device>
-struct DoTransposeImpl {
- static Status run(const Device& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, bool conjugate,
- Tensor* out) {
- CHECK_GE(in.dims(), 2);
- CHECK_EQ(in.dims(), out->dims());
- CHECK_EQ(in.dims(), perm.size());
- CHECK_EQ(in.dtype(), out->dtype());
- switch (in.dtype()) {
- case DT_BOOL:
- case DT_INT8:
- case DT_QINT8:
- case DT_QUINT8:
- case DT_UINT8:
- Transpose<Device, uint8>::run(d, in, perm, out);
- break;
-
- case DT_BFLOAT16:
- case DT_HALF:
- case DT_INT16:
- case DT_QINT16:
- case DT_QUINT16:
- case DT_UINT16:
- Transpose<Device, uint16>::run(d, in, perm, out);
- break;
-
- case DT_FLOAT:
- case DT_INT32:
- case DT_QINT32:
- Transpose<Device, uint32>::run(d, in, perm, out);
- break;
-
- case DT_DOUBLE:
- case DT_INT64:
- Transpose<Device, uint64>::run(d, in, perm, out);
- break;
-
- case DT_COMPLEX64:
- if (conjugate) {
- Transpose<Device, complex64, true>::run(d, in, perm, out);
- } else {
- Transpose<Device, complex64, false>::run(d, in, perm, out);
- }
- break;
-
- case DT_COMPLEX128:
- if (conjugate) {
- Transpose<Device, complex128, true>::run(d, in, perm, out);
- } else {
- Transpose<Device, complex128, false>::run(d, in, perm, out);
- }
- break;
-
- case DT_STRING:
- Transpose<Device, string>::run(d, in, perm, out);
- break;
-
- default:
- return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
- }
- return Status::OK();
+Status DoTransposeImpl(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, bool conjugate,
+ Tensor* out) {
+ CHECK_GE(in.dims(), 2);
+ CHECK_EQ(in.dims(), out->dims());
+ CHECK_EQ(in.dims(), perm.size());
+ CHECK_EQ(in.dtype(), out->dtype());
+ switch (in.dtype()) {
+ case DT_BOOL:
+ case DT_INT8:
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_UINT8:
+ Transpose<Device, uint8>::run(d, in, perm, out);
+ break;
+
+ case DT_BFLOAT16:
+ case DT_HALF:
+ case DT_INT16:
+ case DT_QINT16:
+ case DT_QUINT16:
+ case DT_UINT16:
+ Transpose<Device, uint16>::run(d, in, perm, out);
+ break;
+
+ case DT_FLOAT:
+ case DT_INT32:
+ case DT_QINT32:
+ Transpose<Device, uint32>::run(d, in, perm, out);
+ break;
+
+ case DT_DOUBLE:
+ case DT_INT64:
+ Transpose<Device, uint64>::run(d, in, perm, out);
+ break;
+
+ case DT_COMPLEX64:
+ if (conjugate) {
+ Transpose<Device, complex64, true>::run(d, in, perm, out);
+ } else {
+ Transpose<Device, complex64, false>::run(d, in, perm, out);
+ }
+ break;
+
+ case DT_COMPLEX128:
+ if (conjugate) {
+ Transpose<Device, complex128, true>::run(d, in, perm, out);
+ } else {
+ Transpose<Device, complex128, false>::run(d, in, perm, out);
+ }
+ break;
+
+ case DT_STRING:
+ Transpose<Device, string>::run(d, in, perm, out);
+ break;
+
+ default:
+ return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
}
-};
+ return Status::OK();
+}
+
+template <typename Device>
+inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in,
+ bool conjugate, Tensor* out) {
+ const int ndims = in.dims();
+ if (ndims == 0) return Status::OK();
+ TransposePermsVec perm(ndims);
+ std::iota(perm.begin(), perm.end(), 0);
+ std::swap(perm[ndims - 2], perm[ndims - 1]);
+ return DoTransposeImpl(device, in, perm, conjugate, out);
+}
#ifdef TENSORFLOW_USE_SYCL
// For SYCL lets always go through Eigen
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index b2de012be1..41b73fdaf4 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -29,17 +29,18 @@ limitations under the License.
typedef Eigen::ThreadPoolDevice CPUDevice;
namespace tensorflow {
-namespace internal {
+namespace {
-template <typename Device, typename T, bool conjugate>
-void TransposeSimple(const Device& device, const Tensor& in,
+template <typename T, bool conjugate>
+void TransposeSimple(const CPUDevice& device, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
const int ndims = in.dims();
gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
- auto transpose_fn = [=](int64 begin, int64 end) {
+ auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64 begin,
+ int64 end) {
for (int64 o_idx = begin; o_idx < end; ++o_idx) {
int64 i_idx = 0;
int64 t = o_idx;
@@ -64,7 +65,7 @@ void TransposeSimple(const Device& device, const Tensor& in,
device.parallelFor(in.NumElements(), cost, std::move(transpose_fn));
}
-} // end namespace internal
+} // namespace
template <typename T, bool conjugate>
struct Transpose<CPUDevice, T, conjugate> {
@@ -88,32 +89,47 @@ struct Transpose<CPUDevice, T, conjugate> {
out);
break;
default:
- internal::TransposeSimple<CPUDevice, T, conjugate>(d, in, perm, out);
+ TransposeSimple<T, conjugate>(d, in, perm, out);
break;
}
}
};
-template <>
-Status DoTranspose(const CPUDevice& device, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- return internal::DoTransposeImpl<CPUDevice>::run(device, in, perm,
- false /* conjugate */, out);
-}
+#define INSTANTIATE(DEVICE) \
+ template <> \
+ Status DoTranspose(const DEVICE& device, const Tensor& in, \
+ const gtl::ArraySlice<int32> perm, Tensor* out) { \
+ return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, \
+ out); \
+ } \
+ template <> \
+ Status DoConjugateTranspose(const DEVICE& device, const Tensor& in, \
+ const gtl::ArraySlice<int32> perm, \
+ Tensor* out) { \
+ return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, \
+ out); \
+ } \
+ template <> \
+ Status DoMatrixTranspose(const DEVICE& device, const Tensor& in, \
+ Tensor* out) { \
+ return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, \
+ out); \
+ } \
+ template <> \
+ Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in, \
+ Tensor* out) { \
+ return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, \
+ out); \
+ }
-template <>
-Status DoConjugateTranspose(const CPUDevice& device, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- return internal::DoTransposeImpl<CPUDevice>::run(device, in, perm,
- true /* conjugate */, out);
-}
+INSTANTIATE(CPUDevice)
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
namespace internal {
-template <typename Device, typename T>
-void TransposeSYCL(const Device& d, const Tensor& in,
+template <typename T>
+void TransposeSYCL(const SYCLDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, bool conjugate,
Tensor* out) {
switch (in.dims()) {
@@ -165,19 +181,11 @@ struct Transpose<SYCLDevice, string, conjugate> {
}
};
-template <>
-Status DoTranspose(const SYCLDevice& device, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- return internal::DoTransposeImpl<SYCLDevice>::run(device, in, perm,
- false /* conjugate */, out);
-}
+// Explicit instantiation.
+template struct Transpose<SYCLDevice, string, false>;
-template <>
-Status DoConjugateTranspose(const SYCLDevice& device, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- return internal::DoTransposeImpl<SYCLDevice>::run(device, in, perm,
- true /* conjugate */, out);
-}
+INSTANTIATE(SYCLDevice)
+#undef INSTANTIATE
#endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index 364baf9a51..493dac9a7c 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -53,8 +53,8 @@ __global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
}
}
-template <typename Device, typename T, bool conjugate>
-void TransposeSimple(const Device& d, const Tensor& in,
+template <typename T, bool conjugate>
+void TransposeSimple(const GPUDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
// Ensures we can use 32-bit index.
const int64 nelem = in.NumElements();
@@ -165,23 +165,9 @@ struct TransposeUsingTile<complex128, conjugate> {
}
};
-} // end namespace internal
-
-template <>
-Status DoTranspose(const GPUDevice& device, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- return internal::DoTransposeImpl<GPUDevice>::run(device, in, perm,
- false /* conjugate */, out);
-}
-
-template <>
-Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- return internal::DoTransposeImpl<GPUDevice>::run(device, in, perm,
- true /* conjugate */, out);
-}
+} // namespace internal
-// Transpose kernel specialized for CPU Device.
+// Transpose kernel specialized for GPU Device.
template <typename T, bool conjugate>
struct Transpose<GPUDevice, T, conjugate> {
static void run(const GPUDevice& d, const Tensor& in,
@@ -216,19 +202,43 @@ struct Transpose<GPUDevice, T, conjugate> {
}
break;
default:
- internal::TransposeSimple<GPUDevice, T, conjugate>(d, in, perm, out);
+ internal::TransposeSimple<T, conjugate>(d, in, perm, out);
break;
}
}
};
-template <>
-struct Transpose<GPUDevice, string> {
+template <bool conjugate>
+struct Transpose<GPUDevice, string, conjugate> {
static void run(const GPUDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
LOG(FATAL) << "Transpose of DT_STRING tensor not supported on GPU.";
}
};
+// Explicit instantiation.
+template struct Transpose<GPUDevice, string, false>;
+
+template <>
+Status DoTranspose(const GPUDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, out);
+}
+template <>
+Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, out);
+}
+template <>
+Status DoMatrixTranspose(const GPUDevice& device, const Tensor& in,
+ Tensor* out) {
+ return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, out);
+}
+template <>
+Status DoConjugateMatrixTranspose(const GPUDevice& device, const Tensor& in,
+ Tensor* out) {
+ return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, out);
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA