From 47e4d4b6b5742350233a8fd83cd81269792ed286 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 17 Oct 2017 16:19:08 -0700
Subject: Use optimized functor for conjugate transpose in MatrixInverseOp.

Introduce convenience functions DoMatrixTranspose and
DoConjugateMatrixTranspose.

Misc. minor cleanup of templates in transpose_functor*.

PiperOrigin-RevId: 172532252
---
 tensorflow/core/kernels/cuda_solvers.h             |   8 --
 tensorflow/core/kernels/cuda_solvers_gpu.cu.cc     |  18 ---
 tensorflow/core/kernels/matrix_inverse_op.cc       |  12 +-
 tensorflow/core/kernels/matrix_solve_op.cc         |   9 +-
 tensorflow/core/kernels/qr_op_impl.h               |   9 +-
 .../core/kernels/self_adjoint_eig_v2_op_gpu.cc     |   5 +-
 tensorflow/core/kernels/svd_op_gpu.cu.cc           |  25 ++--
 tensorflow/core/kernels/transpose_functor.h        | 150 +++++++++++----------
 tensorflow/core/kernels/transpose_functor_cpu.cc   |  72 +++++-----
 .../core/kernels/transpose_functor_gpu.cu.cc       |  52 ++++---
 10 files changed, 174 insertions(+), 186 deletions(-)

diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index 60c4a0bfb4..eb720b191f 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -409,14 +409,6 @@ class DeviceLapackInfo : public ScratchSpace {
 };
 
 namespace functor {
-// Helper functor to transpose and conjugate all matrices in a flattened batch.
-template <typename Device, typename Scalar>
-struct AdjointBatchFunctor {
-  // We assume that the tensor sizes are correct.
-  void operator()(const Device& device,
-                  typename TTypes<Scalar, 3>::ConstTensor input,
-                  typename TTypes<Scalar, 3>::Tensor output);
-};
 
 // Helper functor to compute the product of diagonal elements in all matrices
 // in a flattened batch.
diff --git a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
index 79961c01ca..4171f9d68e 100644
--- a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
+++ b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc
@@ -29,24 +29,6 @@ namespace functor {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-// TODO(rmlarsen): Add a faster custom kernel similar to
-// SwapDimension1And2InTensor3 in tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
-template <typename Scalar>
-struct AdjointBatchFunctor<GPUDevice, Scalar> {
-  void operator()(const GPUDevice& device,
-                  typename TTypes<Scalar, 3>::ConstTensor input,
-                  typename TTypes<Scalar, 3>::Tensor output) {
-    const Eigen::array<int, 3> perm({0, 2, 1});
-    To32Bit(output).device(device) = To32Bit(input).shuffle(perm).conjugate();
-  }
-};
-
-// Instantiate implementations for the 4 numeric types.
-template struct AdjointBatchFunctor<GPUDevice, float>;
-template struct AdjointBatchFunctor<GPUDevice, double>;
-template struct AdjointBatchFunctor<GPUDevice, complex64>;
-template struct AdjointBatchFunctor<GPUDevice, complex128>;
-
 namespace {
 
 // Hacks around missing support for complex arithmetic in nvcc.
diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc
index 832e508bb7..64edfe470d 100644
--- a/tensorflow/core/kernels/matrix_inverse_op.cc
+++ b/tensorflow/core/kernels/matrix_inverse_op.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/kernels/cuda_solvers.h"
+#include "tensorflow/core/kernels/transpose_functor.h"
 #endif
 
 namespace tensorflow {
@@ -135,15 +136,15 @@ class MatrixInverseOpGpu : public AsyncOpKernel {
                                               input.shape(), &input_copy),
         done);
     auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
-    auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
     const GPUDevice& device = context->eigen_device<GPUDevice>();
     if (!adjoint_) {
       device.memcpy(input_copy.flat<Scalar>().data(),
                     input.flat<Scalar>().data(),
                     input.NumElements() * sizeof(Scalar));
     } else {
-      functor::AdjointBatchFunctor<GPUDevice, Scalar> functor;
-      functor(device, input_reshaped, input_copy_reshaped);
+      OP_REQUIRES_OK_ASYNC(
+          context, DoConjugateMatrixTranspose(device, input, &input_copy),
+          done);
     }
     const int64 batch_size = input_copy_reshaped.dimension(0);
@@ -238,10 +239,7 @@ class MatrixInverseOpGpu : public AsyncOpKernel {
           done);
     }
   }
-  // Callback for checking info after kernels finish. Also capture the
-  // temporary Tensors/ScratchSpace so they don't get deallocated before the
-  // kernels run. TODO(rmlarsen): Use move capture once C++14 becomes
-  // available.
+  // Callback for checking info after kernels finish.
   auto info_checker = [context, done](
                           const Status& status,
                           const std::vector<HostLapackInfo>& host_infos) {
diff --git a/tensorflow/core/kernels/matrix_solve_op.cc b/tensorflow/core/kernels/matrix_solve_op.cc
index 862033e9fa..2e4098dfab 100644
--- a/tensorflow/core/kernels/matrix_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_op.cc
@@ -181,9 +181,6 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
     // false, try to reuse the input buffer if this op owns it exclusively.
     Tensor input_copy;
     const GPUDevice& device = context->eigen_device<GPUDevice>();
-    std::vector<int32> perm(ndims);
-    std::iota(perm.begin(), perm.end(), 0);
-    std::swap(perm[ndims - 2], perm[ndims - 1]);
     if (adjoint_) {
       // For the adjoint case, it is simpler to always make a transposed copy
       // up front.
@@ -193,7 +190,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
                                                 input.shape(), &input_copy),
           done);
       OP_REQUIRES_OK_ASYNC(context,
-                           DoTranspose(device, input, perm, &input_copy), done);
+                           DoMatrixTranspose(device, input, &input_copy), done);
     } else {
       OP_REQUIRES_OK_ASYNC(
           context,
@@ -267,7 +264,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
         done);
     if (nrhs > 1) {
       OP_REQUIRES_OK_ASYNC(
-          context, DoTranspose(device, rhs, perm, &transposed_rhs), done);
+          context, DoMatrixTranspose(device, rhs, &transposed_rhs), done);
     } else {
       device.memcpy(transposed_rhs.flat<Scalar>().data(),
                     rhs.flat<Scalar>().data(),
@@ -327,7 +324,7 @@ class MatrixSolveOpGpu : public AsyncOpKernel {
     // 4. Transpose X to get the final result in row-major form.
     if (nrhs > 1) {
       OP_REQUIRES_OK_ASYNC(
-          context, DoTranspose(device, transposed_rhs, perm, output), done);
+          context, DoMatrixTranspose(device, transposed_rhs, output), done);
     } else {
       device.memcpy(output->flat<Scalar>().data(),
                     transposed_rhs.flat<Scalar>().data(),
diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h
index e263eb22f1..c51d601437 100644
--- a/tensorflow/core/kernels/qr_op_impl.h
+++ b/tensorflow/core/kernels/qr_op_impl.h
@@ -190,12 +190,9 @@ class QrOpGpu : public AsyncOpKernel {
 
     // Transpose input, since cuSolver uses column-major, while TensorFlow uses
    // row-major storage.
-    std::vector<int32> perm(ndims);
-    std::iota(perm.begin(), perm.end(), 0);
-    std::swap(perm[ndims - 2], perm[ndims - 1]);
     const GPUDevice& device = context->eigen_device<GPUDevice>();
     OP_REQUIRES_OK_ASYNC(
-        context, DoTranspose(device, input, perm, &input_transposed), done);
+        context, DoMatrixTranspose(device, input, &input_transposed), done);
 
     // Compute QR decomposition in-place in input_transposed.
     std::vector<DeviceLapackInfo> dev_info;
@@ -218,7 +215,7 @@ class QrOpGpu : public AsyncOpKernel {
     // and copy it to the output buffer.
     if (full_matrices_ || m == n) {
       OP_REQUIRES_OK_ASYNC(
-          context, DoTranspose(device, input_transposed, perm, r), done);
+          context, DoMatrixTranspose(device, input_transposed, r), done);
     } else {
       const Scalar alpha(1);
       const Scalar beta(0);
@@ -280,7 +277,7 @@ class QrOpGpu : public AsyncOpKernel {
           done);
     }
     OP_REQUIRES_OK_ASYNC(
-        context, DoTranspose(device, input_transposed, perm, q), done);
+        context, DoMatrixTranspose(device, input_transposed, q), done);
   }
 
   // Asynchronously check return status from cuSolver kernels.
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc b/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc
index b0b4f89a27..3a84df07a9 100644
--- a/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc
+++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op_gpu.cc
@@ -148,11 +148,8 @@ class SelfAdjointEigV2OpGpu : public AsyncOpKernel {
     if (compute_v_) {
       // Transpose eigenvectors now stored in input_copy in column-major form
       // to output in row-major form.
-      std::vector<int32> perm(ndims);
-      std::iota(perm.begin(), perm.end(), 0);
-      std::swap(perm[ndims - 2], perm[ndims - 1]);
       OP_REQUIRES_OK_ASYNC(
-          context, DoTranspose(device, input_copy, perm, eigenvectors), done);
+          context, DoMatrixTranspose(device, input_copy, eigenvectors), done);
     }
 
     // Asynchronously check return status from cuSolver kernels.
diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc
index 1603a8aeda..dedc2da60b 100644
--- a/tensorflow/core/kernels/svd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -190,8 +190,8 @@ class SvdOpGpu : public AsyncOpKernel {
 
   // TODO: can the two cases (MgeqN and MlessN) be simplified,
   // common boilerplate be reduced, or even combined in one method?
   void PerformSVD_MgeqN(OpKernelContext* context, DoneCallback done, int64 m,
-                        int64 n, int64 p, const gtl::ArraySlice<int32>& perm,
-                        const Tensor& M, Tensor* S, Tensor* U, Tensor* V) {
+                        int64 n, int64 p, const Tensor& M, Tensor* S, Tensor* U,
+                        Tensor* V) {
     TensorShape shapeRaw = M.shape();
     shapeRaw.RemoveLastDims(2);
@@ -207,7 +207,7 @@ class SvdOpGpu : public AsyncOpKernel {
         solver->allocate_scoped_tensor(M.dtype(), input_shape, &input_copy),
         done);
     auto device = context->eigen_device<GPUDevice>();
-    OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, M, perm, &input_copy),
+    OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, M, &input_copy),
                          done);
 
     // I need to transpose U at the end
@@ -250,7 +250,7 @@ class SvdOpGpu : public AsyncOpKernel {
 
     // Transpose U
     if (compute_uv_) {
-      OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, u_copy, perm, U), done);
+      OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, u_copy, U), done);
     }
 
     // now check if the SVD operation succeeded or not
@@ -259,8 +259,8 @@ class SvdOpGpu : public AsyncOpKernel {
 
   // The SVD if m < n
   void PerformSVD_MlessN(OpKernelContext* context, DoneCallback done, int64 m,
-                         int64 n, int64 p, const gtl::ArraySlice<int32>& perm,
-                         const Tensor& M, Tensor* S, Tensor* U, Tensor* V) {
+                         int64 n, int64 p, const Tensor& M, Tensor* S,
+                         Tensor* U, Tensor* V) {
     // Perform the SVD on M'
 
     // Reuse the input buffer or make a copy for the SVD depending on whether
@@ -325,7 +325,7 @@ class SvdOpGpu : public AsyncOpKernel {
     // Transpose V
     if (compute_uv_) {
       auto device = context->eigen_device<GPUDevice>();
-      OP_REQUIRES_OK_ASYNC(context, DoTranspose(device, v_copy, perm, V), done);
+      OP_REQUIRES_OK_ASYNC(context, DoMatrixTranspose(device, v_copy, V), done);
     }
 
     // now check if the SVD operation succeeded or not
@@ -389,19 +389,12 @@ class SvdOpGpu : public AsyncOpKernel {
       return;
     }
 
-    // Prepare permutation
-    std::vector<int32> perm;
-    for (size_t i = 0; i < ndims - 2; ++i) perm.push_back(i);
-    perm.push_back(ndims - 1);  // transpose last two dimensions
-    perm.push_back(ndims - 2);
-    gtl::ArraySlice<int32> permAS(perm);
-
     // call implementations
     if (m >= n) {
-      PerformSVD_MgeqN(context, done, m, n, p, permAS, input, outputS, outputU,
+      PerformSVD_MgeqN(context, done, m, n, p, input, outputS, outputU,
                        outputV);
     } else {
-      PerformSVD_MlessN(context, done, m, n, p, permAS, input, outputS, outputU,
+      PerformSVD_MlessN(context, done, m, n, p, input, outputS, outputU,
                         outputV);
     }
   }
diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h
index 87569f0275..a2eb0263e8 100644
--- a/tensorflow/core/kernels/transpose_functor.h
+++ b/tensorflow/core/kernels/transpose_functor.h
@@ -23,7 +23,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
-
 // Transpose tensor 'in' into tensor 'out' according to dimension
 // permutation 'perm'.
 //
@@ -46,6 +45,17 @@ template <typename Device>
 Status DoConjugateTranspose(const Device& device, const Tensor& in,
                             const gtl::ArraySlice<int32> perm, Tensor* out);
 
+// Convenience versions of DoTranspose that only swap the last (inner) two
+// dimensions.
+template <typename Device>
+Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);
+
+// Convenience versions of DoConjugateTranspose that only swap the last
+// (inner) two dimensions.
+template <typename Device>
+Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
+                                  Tensor* out);
+
 // Primary device specific functor to be specialized for each device and type.
 template <typename Device, typename T, bool conjugate = false>
 struct Transpose {
@@ -131,11 +141,6 @@ inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
   return true;
 }
 
-// Device-specific naive implementation for transpose.
-template <typename Device, typename T>
-void TransposeSimple(const Device& d, const Tensor& in,
-                     const gtl::ArraySlice<int32> perm, Tensor* out);
-
 // Uses Eigen to transpose.
 template <typename Device, typename T, int NDIMS>
 void TransposeUsingEigen(const Device& d, const Tensor& in,
@@ -157,69 +162,78 @@ void TransposeUsingEigen(const Device& d, const Tensor& in,
 }
 
 template <typename Device>
-struct DoTransposeImpl {
-  static Status run(const Device& d, const Tensor& in,
-                    const gtl::ArraySlice<int32> perm, bool conjugate,
-                    Tensor* out) {
-    CHECK_GE(in.dims(), 2);
-    CHECK_EQ(in.dims(), out->dims());
-    CHECK_EQ(in.dims(), perm.size());
-    CHECK_EQ(in.dtype(), out->dtype());
-    switch (in.dtype()) {
-      case DT_BOOL:
-      case DT_INT8:
-      case DT_QINT8:
-      case DT_QUINT8:
-      case DT_UINT8:
-        Transpose<Device, uint8>::run(d, in, perm, out);
-        break;
-
-      case DT_BFLOAT16:
-      case DT_HALF:
-      case DT_INT16:
-      case DT_QINT16:
-      case DT_QUINT16:
-      case DT_UINT16:
-        Transpose<Device, uint16>::run(d, in, perm, out);
-        break;
-
-      case DT_FLOAT:
-      case DT_INT32:
-      case DT_QINT32:
-        Transpose<Device, uint32>::run(d, in, perm, out);
-        break;
-
-      case DT_DOUBLE:
-      case DT_INT64:
-        Transpose<Device, uint64>::run(d, in, perm, out);
-        break;
-
-      case DT_COMPLEX64:
-        if (conjugate) {
-          Transpose<Device, complex64, true>::run(d, in, perm, out);
-        } else {
-          Transpose<Device, uint64>::run(d, in, perm, out);
-        }
-        break;
-
-      case DT_COMPLEX128:
-        if (conjugate) {
-          Transpose<Device, complex128, true>::run(d, in, perm, out);
-        } else {
-          Transpose<Device, complex128, false>::run(d, in, perm, out);
-        }
-        break;
-
-      case DT_STRING:
-        Transpose<Device, string>::run(d, in, perm, out);
-        break;
-
-      default:
-        return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
-    }
-    return Status::OK();
+Status DoTransposeImpl(const Device& d, const Tensor& in,
+                       const gtl::ArraySlice<int32> perm, bool conjugate,
+                       Tensor* out) {
+  CHECK_GE(in.dims(), 2);
+  CHECK_EQ(in.dims(), out->dims());
+  CHECK_EQ(in.dims(), perm.size());
+  CHECK_EQ(in.dtype(), out->dtype());
+  switch (in.dtype()) {
+    case DT_BOOL:
+    case DT_INT8:
+    case DT_QINT8:
+    case DT_QUINT8:
+    case DT_UINT8:
+      Transpose<Device, uint8>::run(d, in, perm, out);
+      break;
+
+    case DT_BFLOAT16:
+    case DT_HALF:
+    case DT_INT16:
+    case DT_QINT16:
+    case DT_QUINT16:
+    case DT_UINT16:
+      Transpose<Device, uint16>::run(d, in, perm, out);
+      break;
+
+    case DT_FLOAT:
+    case DT_INT32:
+    case DT_QINT32:
+      Transpose<Device, uint32>::run(d, in, perm, out);
+      break;
+
+    case DT_DOUBLE:
+    case DT_INT64:
+      Transpose<Device, uint64>::run(d, in, perm, out);
+      break;
+
+    case DT_COMPLEX64:
+      if (conjugate) {
+        Transpose<Device, complex64, true>::run(d, in, perm, out);
+      } else {
+        Transpose<Device, uint64>::run(d, in, perm, out);
+      }
+      break;
+
+    case DT_COMPLEX128:
+      if (conjugate) {
+        Transpose<Device, complex128, true>::run(d, in, perm, out);
+      } else {
+        Transpose<Device, complex128, false>::run(d, in, perm, out);
+      }
+      break;
+
+    case DT_STRING:
+      Transpose<Device, string>::run(d, in, perm, out);
+      break;
+
+    default:
+      return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
   }
-};
+  return Status::OK();
+}
+
+template <typename Device>
+inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in,
+                                    bool conjugate, Tensor* out) {
+  const int ndims = in.dims();
+  if (ndims == 0) return Status::OK();
+  TransposePermsVec perm(ndims);
+  std::iota(perm.begin(), perm.end(), 0);
+  std::swap(perm[ndims - 2], perm[ndims - 1]);
+  return DoTransposeImpl(device, in, perm, conjugate, out);
+}
 
 #ifdef TENSORFLOW_USE_SYCL
 // For SYCL lets always go through Eigen
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index b2de012be1..41b73fdaf4 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -29,17 +29,18 @@ limitations under the License.
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
 namespace tensorflow {
-namespace internal {
+namespace {
 
-template <typename Device, typename T>
-void TransposeSimple(const Device& device, const Tensor& in,
+template <typename T>
+void TransposeSimple(const CPUDevice& device, const Tensor& in,
                      const gtl::ArraySlice<int32> perm, Tensor* out) {
   const int ndims = in.dims();
   gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
   gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
   const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
   T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
-  auto transpose_fn = [=](int64 begin, int64 end) {
+  auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64 begin,
+                                                            int64 end) {
     for (int64 o_idx = begin; o_idx < end; ++o_idx) {
       int64 i_idx = 0;
       int64 t = o_idx;
@@ -64,7 +65,7 @@ void TransposeSimple(const Device& device, const Tensor& in,
   device.parallelFor(in.NumElements(), cost, std::move(transpose_fn));
 }
 
-}  // end namespace internal
+}  // namespace
 
 template <typename T, bool conjugate>
 struct Transpose<CPUDevice, T, conjugate> {
@@ -88,32 +89,47 @@ struct Transpose<CPUDevice, T, conjugate> {
                                          out);
         break;
       default:
-        internal::TransposeSimple<CPUDevice, T>(d, in, perm, out);
+        TransposeSimple<T>(d, in, perm, out);
         break;
     }
   }
 };
 
-template <>
-Status DoTranspose(const CPUDevice& device, const Tensor& in,
-                   const gtl::ArraySlice<int32> perm, Tensor* out) {
-  return internal::DoTransposeImpl<CPUDevice>::run(device, in, perm,
-                                                   false /* conjugate */, out);
-}
+#define INSTANTIATE(DEVICE)                                                  \
+  template <>                                                                \
+  Status DoTranspose(const DEVICE& device, const Tensor& in,                 \
+                     const gtl::ArraySlice<int32> perm, Tensor* out) {       \
+    return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false,  \
+                                     out);                                   \
+  }                                                                          \
+  template <>                                                                \
+  Status DoConjugateTranspose(const DEVICE& device, const Tensor& in,        \
+                              const gtl::ArraySlice<int32> perm,             \
+                              Tensor* out) {                                 \
+    return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true,   \
+                                     out);                                   \
+  }                                                                          \
+  template <>                                                                \
+  Status DoMatrixTranspose(const DEVICE& device, const Tensor& in,           \
+                           Tensor* out) {                                    \
+    return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false,  \
+                                           out);                             \
+  }                                                                          \
+  template <>                                                                \
+  Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in,  \
+                                    Tensor* out) {                           \
+    return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true,   \
+                                           out);                             \
+  }
 
-template <>
-Status DoConjugateTranspose(const CPUDevice& device, const Tensor& in,
-                            const gtl::ArraySlice<int32> perm, Tensor* out) {
-  return internal::DoTransposeImpl<CPUDevice>::run(device, in, perm,
-                                                   true /* conjugate */, out);
-}
+INSTANTIATE(CPUDevice)
 
 #ifdef TENSORFLOW_USE_SYCL
 typedef Eigen::SyclDevice SYCLDevice;
 namespace internal {
-template <typename Device, typename T>
-void TransposeSYCL(const Device& d, const Tensor& in,
+template <typename T>
+void TransposeSYCL(const SYCLDevice& d, const Tensor& in,
                    const gtl::ArraySlice<int32> perm, bool conjugate,
                    Tensor* out) {
   switch (in.dims()) {
@@ -165,19 +181,11 @@ struct Transpose<SYCLDevice, T, conjugate> {
   }
 };
 
-template <>
-Status DoTranspose(const SYCLDevice& device, const Tensor& in,
-                   const gtl::ArraySlice<int32> perm, Tensor* out) {
-  return internal::DoTransposeImpl<SYCLDevice>::run(device, in, perm,
-                                                    false /* conjugate */, out);
-}
+// Explicit instantiation.
+template struct Transpose<SYCLDevice, string>;
 
-template <>
-Status DoConjugateTranspose(const SYCLDevice& device, const Tensor& in,
-                            const gtl::ArraySlice<int32> perm, Tensor* out) {
-  return internal::DoTransposeImpl<SYCLDevice>::run(device, in, perm,
-                                                    true /* conjugate */, out);
-}
+INSTANTIATE(SYCLDevice)
+#undef INSTANTIATE
 
 #endif  // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index 364baf9a51..493dac9a7c 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -53,8 +53,8 @@ __global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
   }
 }
 
-template <typename Device, typename T>
-void TransposeSimple(const Device& d, const Tensor& in,
+template <typename T>
+void TransposeSimple(const GPUDevice& d, const Tensor& in,
                      const gtl::ArraySlice<int32> perm, Tensor* out) {
   // Ensures we can use 32-bit index.
   const int64 nelem = in.NumElements();
@@ -165,23 +165,9 @@ struct TransposeUsingTile {
   }
 };
 
-}  // end namespace internal
-
-template <>
-Status DoTranspose(const GPUDevice& device, const Tensor& in,
-                   const gtl::ArraySlice<int32> perm, Tensor* out) {
-  return internal::DoTransposeImpl<GPUDevice>::run(device, in, perm,
-                                                   false /* conjugate */, out);
-}
-
-template <>
-Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
-                            const gtl::ArraySlice<int32> perm, Tensor* out) {
-  return internal::DoTransposeImpl<GPUDevice>::run(device, in, perm,
-                                                   true /* conjugate */, out);
-}
+}  // namespace internal
 
-// Transpose kernel specialized for CPU Device.
+// Transpose kernel specialized for GPU Device.
 template <typename T, bool conjugate>
 struct Transpose<GPUDevice, T, conjugate> {
   static void run(const GPUDevice& d, const Tensor& in,
@@ -216,19 +202,43 @@ struct Transpose<GPUDevice, T, conjugate> {
         }
         break;
       default:
-        internal::TransposeSimple<GPUDevice, T>(d, in, perm, out);
+        internal::TransposeSimple<T>(d, in, perm, out);
         break;
     }
   }
 };
 
-template <>
-struct Transpose<GPUDevice, string> {
+template <bool conjugate>
+struct Transpose<GPUDevice, string, conjugate> {
   static void run(const GPUDevice& d, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out) {
     LOG(FATAL) << "Transpose of DT_STRING tensor not supported on GPU.";
   }
 };
+
+// Explicit instantiation.
+template struct Transpose<GPUDevice, string>;
+
+template <>
+Status DoTranspose(const GPUDevice& device, const Tensor& in,
+                   const gtl::ArraySlice<int32> perm, Tensor* out) {
+  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, out);
+}
+template <>
+Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
+                            const gtl::ArraySlice<int32> perm, Tensor* out) {
+  return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true, out);
+}
+template <>
+Status DoMatrixTranspose(const GPUDevice& device, const Tensor& in,
+                         Tensor* out) {
+  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, out);
+}
+template <>
+Status DoConjugateMatrixTranspose(const GPUDevice& device, const Tensor& in,
+                                  Tensor* out) {
+  return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true, out);
+}
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
--
cgit v1.2.3
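
For quick reference, most of this change reduces to the following before/after pattern. This is an illustrative sketch rather than part of the patch: the wrapper names and signatures (DoTranspose, DoMatrixTranspose, DoConjugateMatrixTranspose) are taken from transpose_functor.h above, while the helper names TransposeInnerDimsOld/TransposeInnerDimsNew, the GPUDevice typedef, and the abbreviated includes are hypothetical stand-ins.

#include <numeric>  // std::iota
#include <utility>  // std::swap
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/transpose_functor.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;  // as in the kernels above

// Old pattern, deleted from matrix_solve_op.cc, qr_op_impl.h,
// self_adjoint_eig_v2_op_gpu.cc and svd_op_gpu.cu.cc by this change:
// build an identity permutation, swap the two innermost dimensions,
// and call the generic DoTranspose.
Status TransposeInnerDimsOld(const GPUDevice& device, const Tensor& in,
                             Tensor* out) {
  const int ndims = in.dims();
  std::vector<int32> perm(ndims);
  std::iota(perm.begin(), perm.end(), 0);       // perm = {0, 1, ..., ndims-1}
  std::swap(perm[ndims - 2], perm[ndims - 1]);  // swap the two inner dims
  return DoTranspose(device, in, perm, out);
}

// New pattern: the permutation boilerplate now lives behind the convenience
// wrappers (see internal::DoMatrixTransposeImpl in transpose_functor.h).
Status TransposeInnerDimsNew(const GPUDevice& device, const Tensor& in,
                             Tensor* out) {
  return DoMatrixTranspose(device, in, out);
  // MatrixInverseOp's adjoint path uses the conjugating variant instead:
  //   return DoConjugateMatrixTranspose(device, in, out);
}

}  // namespace tensorflow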
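Separately, the TransposeSimple fallback kept in both device files maps each output element back to an input element with stride arithmetic; the body of that loop is elided by the hunk context above ("int64 t = o_idx;" is the last visible line). The following self-contained sketch, reconstructed under that assumption with a hypothetical helper name and plain int64_t in place of TensorFlow's int64, illustrates the mapping: coordinate i of the output is t / out_strides[i], and it advances the input index by that coordinate times the stride of input dimension perm[i].

#include <cstdint>
#include <vector>

// Maps linear index o_idx of the transposed (output) tensor to the linear
// index of the input element it copies. in_strides and out_strides are the
// row-major strides of the input and output shapes; perm is the dimension
// permutation (output dimension i corresponds to input dimension perm[i]).
int64_t InputIndexFor(int64_t o_idx, const std::vector<int64_t>& in_strides,
                      const std::vector<int64_t>& out_strides,
                      const std::vector<int>& perm) {
  int64_t i_idx = 0;
  int64_t t = o_idx;
  for (size_t i = 0; i < perm.size(); ++i) {
    i_idx += (t / out_strides[i]) * in_strides[perm[i]];  // output coord i
    t %= out_strides[i];  // remainder indexes the remaining dimensions
  }
  return i_idx;
}

As a quick check, for a 2x3 input transposed with perm = {1, 0} (in_strides = {3, 1}, out_strides = {2, 1}), output index 1, which is element (0, 1) of the 3x2 result, maps to input index 3, which is element (1, 0) of the input, as expected.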