author    | A. Unique TensorFlower <nobody@tensorflow.org>  | 2016-02-16 09:52:33 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-16 11:08:17 -0800
commit    | 6804c9cafc11fa73be3fdb057e033f0304661622
tree      | 87a1c65806c0bf73263855c659e112977ddb27f1
parent    | cf661010261c80b97ab68c5aec383b454ef34f18
Rewrite of transpose so that its compilation time is tolerable. Main
approach:
1. Do not instantiate templates for all tf types. Instead, each type is
   cast to one of uint8/uint16/uint32/uint64/string according to its
   byte width.
2. Use eigen3 to transpose tensors of rank 2/3/4, and fall back to a
   naive routine that is templatized only on the type T, not on NDIMS.
Change: 114763098
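A note on point 1 for readers skimming the diff below: a transpose only moves bytes, so every plain-old-data dtype of a given width can share one template instantiation (float and int32 both dispatch as uint32, double and int64 as uint64), and only string needs its own copy. Combined with point 2, this shrinks the instantiation count from roughly ten dtypes times nine ranks per device to five erased types times four code paths. Below is a minimal standalone sketch of the width-based dispatch; `DType`, `Transpose2D`, and `DoTranspose2D` are illustrative names for this sketch, not TensorFlow helpers:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// 2-D row-major transpose, templatized only on the element's byte width U,
// never on the logical dtype.
template <typename U>
void Transpose2D(const U* in, U* out, int rows, int cols) {
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) out[c * rows + r] = in[r * cols + c];
}

enum class DType { kBool, kUint8, kFloat, kInt32, kDouble, kInt64 };

// Erase the dtype: every type of the same width shares one instantiation,
// mirroring the patch's uint8/uint16/uint32/uint64/string buckets.
void DoTranspose2D(DType t, const void* in, void* out, int rows, int cols) {
  switch (t) {
    case DType::kBool:
    case DType::kUint8:  // 1-byte dtypes reinterpret as raw bytes.
      Transpose2D(static_cast<const std::uint8_t*>(in),
                  static_cast<std::uint8_t*>(out), rows, cols);
      break;
    case DType::kFloat:
    case DType::kInt32:  // 4-byte dtypes reinterpret as raw 32-bit words.
      Transpose2D(static_cast<const std::uint32_t*>(in),
                  static_cast<std::uint32_t*>(out), rows, cols);
      break;
    case DType::kDouble:
    case DType::kInt64:  // 8-byte dtypes reinterpret as raw 64-bit words.
      Transpose2D(static_cast<const std::uint64_t*>(in),
                  static_cast<std::uint64_t*>(out), rows, cols);
      break;
  }
}

int main() {
  std::vector<float> in = {1, 2, 3, 4, 5, 6}, out(6);  // 2x3 -> 3x2
  DoTranspose2D(DType::kFloat, in.data(), out.data(), 2, 3);
  for (float v : out) std::printf("%g ", v);  // prints: 1 4 2 5 3 6
}
```

Because only the element width matters, transposing a float tensor and an int32 tensor of the same shape runs identical machine code, which is why the patch can reinterpret_cast tensor_data() to unsigned words.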
-rw-r--r-- | tensorflow/core/BUILD                               |   1
-rw-r--r-- | tensorflow/core/kernels/reduction_ops_common.h      |  18
-rw-r--r-- | tensorflow/core/kernels/transpose_op.cc             | 153
-rw-r--r-- | tensorflow/core/kernels/transpose_op.h              |  37
-rw-r--r-- | tensorflow/core/kernels/transpose_op_cpu.cc         | 119
-rw-r--r-- | tensorflow/core/kernels/transpose_op_functor.h      |  66
-rw-r--r-- | tensorflow/core/kernels/transpose_op_gpu.cu.cc      | 153
-rw-r--r-- | tensorflow/python/kernel_tests/transpose_op_test.py |  43
8 files changed, 401 insertions, 189 deletions
```diff
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 51bd76bce8..11788ebd3f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -686,6 +686,7 @@ filegroup(
         "//tensorflow/core:kernels/topk_op.cc",
         "//tensorflow/core:kernels/training_ops.cc",
         "//tensorflow/core:kernels/transpose_op.cc",
+        "//tensorflow/core:kernels/transpose_op_cpu.cc",
         "//tensorflow/core:kernels/where_op.cc",
         "//tensorflow/core:kernels/xent_op.cc",
     ],
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index bd03c4183d..f1904b7255 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -23,9 +23,6 @@ limitations under the License.

 #define EIGEN_USE_THREADS

-#include "tensorflow/core/kernels/reduction_ops.h"
-#include "tensorflow/core/kernels/transpose_op.h"
-
 #include "third_party/eigen3/Eigen/Core"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
@@ -33,6 +30,8 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/reduction_ops.h"
+#include "tensorflow/core/kernels/transpose_op_functor.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/logging.h"
@@ -197,7 +196,12 @@ class ReductionHelper {
   }

   // Shape of shuffled input
-  const gtl::ArraySlice<int64> data_reshape() const { return data_reshape_; }
+  TensorShape data_reshape() const {
+    const int dims = data_reshape_.size();
+    TensorShape shape;
+    for (auto s : data_reshape_) shape.AddDim(s);
+    return shape;
+  }

   // Shape with all reduction dimensions at the end
   TensorShape shuffled_shape() {
@@ -315,12 +319,14 @@ class ReductionOp : public OpKernel {
     } else {
       // If we don't hit one of the cases above, transpose the data so that
       // all reduced dimensions are last and reuse the 2-D -> 1-D case.
+      Tensor data_reshaped;
+      CHECK(data_reshaped.CopyFrom(data, helper.data_reshape()));
       Tensor shuffled;
       OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                              helper.shuffled_shape(),
                                              &shuffled, alloc_attr));
-      TransposeTensor<Device, T>(d, data, helper.data_reshape(),
-                                 helper.permutation(), &shuffled);
+      OP_REQUIRES_OK(
+          ctx, DoTranspose(d, data_reshaped, helper.permutation(), &shuffled));
       const int64 unreduced = tmp_out.NumElements();
       const int64 reduced = shuffled.NumElements() / unreduced;
       const Tensor& const_shuffled = shuffled;
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index e7294cb3e0..530e333bff 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -18,6 +18,9 @@ limitations under the License.
 #define EIGEN_USE_THREADS

 #include "tensorflow/core/kernels/transpose_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/transpose_op_functor.h"
@@ -27,9 +30,6 @@ limitations under the License.

 namespace tensorflow {

-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
-
 // inv = InvertPermutationOp(T<int32> p) takes a permutation of
 // integers 0, 1, ..., n - 1 and returns the inverted
 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
@@ -89,26 +89,21 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
 // REQUIRES: input.dims() == perm.size().
 // REQUIRES: perm is a permutation.
-template <typename Device, typename T>
-TransposeOp<Device, T>::TransposeOp(OpKernelConstruction* context)
-    : OpKernel(context) {}
-
-template <typename Device, typename T>
-void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
-  const Tensor& input = context->input(0);
-  const Tensor& perm = context->input(1);
+void TransposeOp::Compute(OpKernelContext* ctx) {
+  const Tensor& input = ctx->input(0);
+  const Tensor& perm = ctx->input(1);

   // Preliminary validation of sizes.
-  OP_REQUIRES(context, TensorShapeUtils::IsVector(perm.shape()),
+  OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
               errors::InvalidArgument("perm must be a vector, not ",
                                       perm.shape().DebugString()));
   auto Vperm = perm.vec<int32>();
   const int dims = input.dims();
   static const int kMinDims = 0;
   static const int kMaxDims = 10;
-  OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
+  OP_REQUIRES(ctx, kMinDims <= dims && dims <= kMaxDims,
               errors::Unimplemented("Transposing a tensor of rank ", dims,
                                     " is not implemented."));
-  OP_REQUIRES(context, dims == Vperm.size(),
+  OP_REQUIRES(ctx, dims == Vperm.size(),
               errors::InvalidArgument(
                   "transpose expects a vector of size ", input.dims(),
                   ". But input(1) is a vector of size ", Vperm.size()));
@@ -120,118 +115,62 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
   gtl::InlinedVector<bool, 8> bits(dims);
   for (const int32 d : permutation) {
     OP_REQUIRES(
-        context, 0 <= d && d < dims,
+        ctx, 0 <= d && d < dims,
         errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
     bits[d] = true;
     shape.AddDim(input.dim_size(d));
   }
   for (int i = 0; i < dims; ++i) {
-    OP_REQUIRES(context, bits[i], errors::InvalidArgument(
-                                      i, " is missing from {",
-                                      str_util::Join(permutation, ","), "}."));
+    OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
+                                  i, " is missing from {",
+                                  str_util::Join(permutation, ","), "}."));
   }

   // 0-D and 1-D transposes do nothing
   if (dims <= 1) {
-    context->set_output(0, input);
+    ctx->set_output(0, input);
     return;
   }

   Tensor* output = nullptr;
-  OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
-  TransposeTensor<Device, T>(context->eigen_device<Device>(), input,
-                             input.shape().dim_sizes(), permutation, output);
-}
-
-template <typename Device, typename T>
-void TransposeTensor(const Device& device, const Tensor& input,
-                     const gtl::ArraySlice<int64> input_shape,
-                     gtl::ArraySlice<int32> permutation, Tensor* output) {
-  const int dims = input_shape.size();
-  CHECK(permutation.size() == dims);
-  if (input.NumElements() == 0) {
-    return;
-  }
-  switch (dims) {
-#define EXPAND_DIM(N)                                                     \
-  case N: {                                                               \
-    functor::TransposeFunctor<Device, T, N> func;                         \
-    func(device, output->tensor<T, N>(), input.shaped<T, N>(input_shape), \
-         permutation.data());                                             \
-    break;                                                                \
+  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
+  if (shape.num_elements() > 0) {
+    OP_REQUIRES_OK(ctx, DoTranspose(ctx, input, permutation, output));
   }
-  EXPAND_DIM(2);
-  EXPAND_DIM(3);
-  EXPAND_DIM(4);
-  EXPAND_DIM(5);
-  EXPAND_DIM(6);
-  EXPAND_DIM(7);
-  EXPAND_DIM(8);
-  EXPAND_DIM(9);
-  EXPAND_DIM(10);
-  default:
-    LOG(FATAL) << "Unexpected dims: " << dims;
-  }
-#undef EXPAND_CASE
 }

-namespace functor {
-
-template <typename Device, typename T, int NDIMS>
-void TransposeMaybeInline(const Device& d,
-                          typename TTypes<T, NDIMS>::Tensor out,
-                          typename TTypes<T, NDIMS>::ConstTensor in,
-                          const int* perm) {
-  // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU.
-  Eigen::array<int, NDIMS> p;
-  for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
-  if (out.size() * sizeof(T) < 131072) {  // Small transpose on a CPU: do inline
-    out = in.shuffle(p);
-  } else {
-    out.device(d) = in.shuffle(p);
-  }
+Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
+                                   gtl::ArraySlice<int32> perm, Tensor* out) {
+  typedef Eigen::ThreadPoolDevice CPUDevice;
+  return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
+                                   out);
 }

-template <typename T, int NDIMS>
-struct TransposeFunctor<CPUDevice, T, NDIMS> {
-  void operator()(const CPUDevice& d, typename TTypes<T, NDIMS>::Tensor out,
-                  typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
-    TransposeMaybeInline<CPUDevice, T, NDIMS>(d, out, in, perm);
-  }
-};
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_CPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeCpuOp);
+TF_CALL_ALL_TYPES(REGISTER)
+#undef REGISTER

-}  // namespace functor
-
-#define REGISTER(D, T)                                               \
-  template class TransposeOp<D##Device, T>;                          \
-  REGISTER_KERNEL_BUILDER(Name("Transpose")                          \
-                              .Device(DEVICE_##D)                    \
-                              .TypeConstraint<T>("T")                \
-                              .HostMemory("perm"),                   \
-                          TransposeOp<D##Device, T>);                \
-  template void TransposeTensor<D##Device, T>(                       \
-      const D##Device&, const Tensor&, const gtl::ArraySlice<int64>, \
-      gtl::ArraySlice<int32>, Tensor*);
-REGISTER(CPU, float);
-REGISTER(CPU, double);
-REGISTER(CPU, complex64);
-REGISTER(CPU, uint8);
-REGISTER(CPU, int8);
-REGISTER(CPU, int16);
-REGISTER(CPU, int32);
-REGISTER(CPU, int64);
-REGISTER(CPU, string);
-REGISTER(CPU, bool);
 #if GOOGLE_CUDA
-REGISTER(GPU, uint8);
-REGISTER(GPU, int8);
-REGISTER(GPU, int16);
-REGISTER(GPU, int32);
-REGISTER(GPU, int64);
-REGISTER(GPU, float);
-REGISTER(GPU, double);
-REGISTER(GPU, complex64);
-REGISTER(GPU, bool);
-#endif
+Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
+                                   gtl::ArraySlice<int32> perm, Tensor* out) {
+  typedef Eigen::GpuDevice GPUDevice;
+  return ::tensorflow::DoTranspose(ctx->eigen_device<GPUDevice>(), in, perm,
+                                   out);
+}
+
+#define REGISTER(T)                                   \
+  REGISTER_KERNEL_BUILDER(Name("Transpose")           \
+                              .Device(DEVICE_GPU)     \
+                              .TypeConstraint<T>("T") \
+                              .HostMemory("perm"),    \
+                          TransposeGpuOp);
+TF_CALL_NUMBER_TYPES(REGISTER);
 #undef REGISTER
+#endif
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index 15cd1b6488..4d546298b4 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -1,4 +1,4 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
+/* Copyright 2016 Google Inc. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -16,24 +16,39 @@ limitations under the License.
 #ifndef TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
 #define TENSORFLOW_KERNELS_TRANSPOSE_OP_H_

-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/tensor.h"

 namespace tensorflow {

-template <typename Device, typename T>
 class TransposeOp : public OpKernel {
  public:
-  explicit TransposeOp(OpKernelConstruction* context);
-  void Compute(OpKernelContext* context) override;
+  explicit TransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override;
+
+ protected:
+  virtual Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+                             gtl::ArraySlice<int32> perm, Tensor* out) = 0;
 };

-// Exposed for use in reduction ops
-template <typename Device, typename T>
-void TransposeTensor(const Device& device, const Tensor& input,
-                     const gtl::ArraySlice<int64> input_shape,
-                     gtl::ArraySlice<int32> permutation, Tensor* output);
+class TransposeCpuOp : public TransposeOp {
+ public:
+  explicit TransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
+
+ protected:
+  Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+                     gtl::ArraySlice<int32> perm, Tensor* out) override;
+};
+
+class TransposeGpuOp : public TransposeOp {
+ public:
+  explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
+
+ protected:
+  Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+                     gtl::ArraySlice<int32> perm, Tensor* out) override;
+};

 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op_cpu.cc b/tensorflow/core/kernels/transpose_op_cpu.cc
new file
index 0000000000..ea039bf471
--- /dev/null
+++ b/tensorflow/core/kernels/transpose_op_cpu.cc
@@ -0,0 +1,119 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/transpose_op_functor.h"
+
+namespace tensorflow {
+namespace internal {
+
+template <typename Device, typename T>
+void TransposeSimple(const Device& d, const Tensor& in,
+                     const gtl::ArraySlice<int32> perm, Tensor* out) {
+  const int ndims = in.dims();
+  gtl::InlinedVector<int64, 8> in_strides(ndims);
+  ComputeStride(in.shape(), in_strides.data());
+  gtl::InlinedVector<int64, 8> out_strides(ndims);
+  ComputeStride(out->shape(), out_strides.data());
+  const int64 nelem = in.NumElements();
+  const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
+  T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
+
+  // TODO(zhifengc): Shard by range.
+  // TODO(zhifengc): Avoids the division.
+  for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
+    int64 i_idx = 0;
+    int64 t = o_idx;
+    for (int i = 0; i < ndims; ++i) {
+      i_idx += (t / out_strides[i]) * in_strides[perm[i]];
+      t = t % out_strides[i];
+    }
+    q[o_idx] = p[i_idx];
+  }
+}
+
+template <typename Device, typename T, int NDIMS>
+void TransposeUsingEigen(const Device& d, const Tensor& in,
+                         const gtl::ArraySlice<int32> perm, Tensor* out) {
+  Eigen::array<int, NDIMS> p;
+  for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
+  auto x = typename TTypes<T, NDIMS>::ConstTensor(
+      reinterpret_cast<const T*>(in.tensor_data().data()),
+      in.shape().AsEigenDSizes<NDIMS>());
+  auto y = typename TTypes<T, NDIMS>::Tensor(
+      reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
+      out->shape().AsEigenDSizes<NDIMS>());
+  auto nelem = in.NumElements();
+  static const int64 kInlineThreshold = 131072;
+  if (nelem * sizeof(T) < kInlineThreshold) {
+    // Don't bother multi-threaded transpose if 'in' is small.
+    y = x.shuffle(p);
+  } else {
+    y.device(d) = x.shuffle(p);
+  }
+}
+
+}  // end namespace internal
+
+typedef Eigen::ThreadPoolDevice Device;
+
+template <>
+Status DoTranspose<Device>(const Device& d, const Tensor& in,
+                           const gtl::ArraySlice<int32> perm, Tensor* out) {
+  CHECK_GE(in.dims(), 2);
+  CHECK_EQ(in.dims(), out->dims());
+  CHECK_EQ(in.dims(), perm.size());
+  CHECK_EQ(in.dtype(), out->dtype());
+  switch (in.dtype()) {
+    case DT_BOOL:
+    case DT_INT8:
+    case DT_QINT8:
+    case DT_QUINT8:
+    case DT_UINT8:
+      internal::Transpose<Device, uint8>(d, in, perm, out);
+      break;
+
+    case DT_BFLOAT16:
+    case DT_INT16:
+    case DT_QINT16:
+    case DT_QUINT16:
+    case DT_UINT16:
+      internal::Transpose<Device, uint16>(d, in, perm, out);
+      break;
+
+    case DT_FLOAT:
+    case DT_INT32:
+    case DT_QINT32:
+      internal::Transpose<Device, uint32>(d, in, perm, out);
+      break;
+
+    case DT_COMPLEX64:
+    case DT_DOUBLE:
+    case DT_INT64:
+      internal::Transpose<Device, uint64>(d, in, perm, out);
+      break;
+
+    case DT_STRING:
+      internal::Transpose<Device, string>(d, in, perm, out);
+      break;
+
+    default:
+      return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op_functor.h b/tensorflow/core/kernels/transpose_op_functor.h
index e478c6d966..b79c3c7f2f 100644
--- a/tensorflow/core/kernels/transpose_op_functor.h
+++ b/tensorflow/core/kernels/transpose_op_functor.h
@@ -16,28 +16,66 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
 #define TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_

-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"

 namespace tensorflow {
-namespace functor {

-template <typename Device, typename T, int NDIMS>
-void Transpose(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
-               typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
-  // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU.
-  Eigen::array<int, NDIMS> p;
-  for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
-  out.device(d) = in.shuffle(p);
+// Transpose tensor 'in' into tensor 'out' according to dimension
+// permutation 'perm'.
+//
+// REQUIRES: in.dtype() == out->dtype()
+// REQUIRES: in.dims() == out->dims()
+// REQUIRES: in.dims() == perm.size()
+// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
+template <typename Device>
+Status DoTranspose(const Device& device, const Tensor& in,
+                   const gtl::ArraySlice<int32> perm, Tensor* out);
+
+// Implementation details.
+namespace internal {
+
+// Helper to compute 'strides' given a tensor 'shape'. I.e.,
+// strides[i] = prod(shape.dim_size[(i+1):])
+template <typename Index>
+void ComputeStride(const TensorShape& shape, Index* strides) {
+  const int ndims = shape.dims();
+  Index stride = 1;
+  for (int i = ndims - 1; i >= 0; --i) {
+    strides[i] = stride;
+    stride *= static_cast<Index>(shape.dim_size(i));
+  }
 }

+// Device-specific naive implementation for tranpose.
+template <typename Device, typename T>
+void TransposeSimple(const Device& d, const Tensor& in,
+                     const gtl::ArraySlice<int32> perm, Tensor* out);
+
+// Uses Eigen to transpose.
 template <typename Device, typename T, int NDIMS>
-struct TransposeFunctor {
-  void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
-                  typename TTypes<T, NDIMS>::ConstTensor in, const int* perm);
-};
+void TransposeUsingEigen(const Device& d, const Tensor& in,
+                         const gtl::ArraySlice<int32> perm, Tensor* out);

-}  // namespace functor
+template <typename Device, typename T>
+void Transpose(const Device& d, const Tensor& in,
+               const gtl::ArraySlice<int32> perm, Tensor* out) {
+  switch (in.dims()) {
+    case 2:
+      TransposeUsingEigen<Device, T, 2>(d, in, perm, out);
+      break;
+    case 3:
+      TransposeUsingEigen<Device, T, 3>(d, in, perm, out);
+      break;
+    case 4:
+      TransposeUsingEigen<Device, T, 4>(d, in, perm, out);
+      break;
+    default:
+      TransposeSimple<Device, T>(d, in, perm, out);
+      break;
+  }
+}
+}  // namespace internal
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/transpose_op_gpu.cu.cc b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
index b42a69c785..238ee0a090 100644
--- a/tensorflow/core/kernels/transpose_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
+/* Copyright 2016 Google Inc. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,46 +17,123 @@ limitations under the License.
 #define EIGEN_USE_GPU

-#include "tensorflow/core/framework/numeric_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/kernels/transpose_op_functor.h"
-#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"

 namespace tensorflow {
-namespace functor {
-
-template <typename T, int NDIMS>
-struct TransposeFunctor<Eigen::GpuDevice, T, NDIMS> {
-  void operator()(const Eigen::GpuDevice& d,
-                  typename TTypes<T, NDIMS>::Tensor out,
-                  typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
-    Transpose<Eigen::GpuDevice, T, NDIMS>(d, out, in, perm);
+namespace internal {
+
+template <typename T>
+__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
+                                const int32 ndims, T* dst) {
+  const int32* in_strides = buf;
+  const int32* out_strides = buf + ndims;
+  const int32* perm = buf + ndims * 2;
+  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
+    int32 i_idx = 0;
+    int32 t = o_idx;
+    for (int i = 0; i < ndims; ++i) {
+      i_idx += (t / out_strides[i]) * in_strides[perm[i]];
+      t = t % out_strides[i];
+    }
+    dst[o_idx] = ldg(src + i_idx);
+  }
+}
+
+template <typename Device, typename T>
+void TransposeSimple(const Device& d, const Tensor& in,
+                     const gtl::ArraySlice<int32> perm, Tensor* out) {
+  // Ensures we can use 32-bit index.
+  const int64 nelem = in.NumElements();
+  CHECK_LT(nelem, kint32max) << "Tensor too large to transpose on GPU";
+  // Pack strides and permutation into one buffer.
+  const int32 ndims = in.dims();
+  gtl::InlinedVector<int32, 16> host_buf(ndims * 3);
+  // Input strides.
+  ComputeStride(in.shape(), &host_buf[0]);
+  // Output strides.
+  ComputeStride(out->shape(), &host_buf[ndims]);
+  // Dimension permutation.
+  for (int i = 0; i < ndims; ++i) {
+    host_buf[ndims * 2 + i] = perm[i];
+  }
+  // Copies the input strides, output strides and permutation to the device.
+  auto num_bytes = sizeof(int64) * host_buf.size();
+  auto dev_buf = d.allocate(num_bytes);
+  // NOTE: host_buf is not allocated by CudaHostAllocator, and
+  // therefore we are doing a sync copy effectively.
+  d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
+  // Launch kernel to q[...] = p[...].
+  const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
+  T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
+  CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
+  TransposeKernel<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+      cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
+      ndims, q);
+  // Safe to deallocate immediately after the kernel launch.
+  d.deallocate(dev_buf);
+}
+
+template <typename Device, typename T, int NDIMS>
+void TransposeUsingEigen(const Device& d, const Tensor& in,
+                         const gtl::ArraySlice<int32> perm, Tensor* out) {
+  Eigen::array<int, NDIMS> p;
+  for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
+  auto x = typename TTypes<T, NDIMS>::ConstTensor(
+      reinterpret_cast<const T*>(in.tensor_data().data()),
+      in.shape().AsEigenDSizes<NDIMS>());
+  auto y = typename TTypes<T, NDIMS>::Tensor(
+      reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
+      out->shape().AsEigenDSizes<NDIMS>());
+  y.device(d) = x.shuffle(p);
+}
+
+}  // end namespace internal
+
+typedef Eigen::GpuDevice Device;
+
+template <>
+Status DoTranspose<Device>(const Device& d, const Tensor& in,
+                           const gtl::ArraySlice<int32> perm, Tensor* out) {
+  CHECK_GE(in.dims(), 2);
+  CHECK_EQ(in.dims(), out->dims());
+  CHECK_EQ(in.dims(), perm.size());
+  CHECK_EQ(in.dtype(), out->dtype());
+  switch (in.dtype()) {
+    case DT_BOOL:
+    case DT_INT8:
+    case DT_QINT8:
+    case DT_QUINT8:
+    case DT_UINT8:
+      internal::Transpose<Device, uint8>(d, in, perm, out);
+      break;
+
+    case DT_BFLOAT16:
+    case DT_INT16:
+    case DT_QINT16:
+    case DT_QUINT16:
+    case DT_UINT16:
+      internal::Transpose<Device, uint16>(d, in, perm, out);
+      break;
+
+    case DT_FLOAT:
+    case DT_INT32:
+    case DT_QINT32:
+      internal::Transpose<Device, uint32>(d, in, perm, out);
+      break;
+
+    case DT_COMPLEX64:
+    case DT_DOUBLE:
+    case DT_INT64:
+      internal::Transpose<Device, uint64>(d, in, perm, out);
+      break;
+
+    default:
+      return errors::Unimplemented("Unsupported dtype on GPU: ", in.dtype());
   }
-};
-
-#define DEFINE(T, N) template struct TransposeFunctor<Eigen::GpuDevice, T, N>;
-#define DEFINE_DIM(T) \
-  DEFINE(T, 2);       \
-  DEFINE(T, 3);       \
-  DEFINE(T, 4);       \
-  DEFINE(T, 5);       \
-  DEFINE(T, 6);       \
-  DEFINE(T, 7);       \
-  DEFINE(T, 8);       \
-  DEFINE(T, 9);       \
-  DEFINE(T, 10);
-DEFINE_DIM(uint8);
-DEFINE_DIM(int8);
-DEFINE_DIM(int16);
-DEFINE_DIM(int32);
-DEFINE_DIM(int64);
-DEFINE_DIM(float);
-DEFINE_DIM(double);
-DEFINE_DIM(complex64);
-DEFINE_DIM(bool);
-#undef DEFINE_DIM
-#undef DEFINE
-
-}  // end namespace functor
-}  // end namespace tensorflow
+  return Status::OK();
+}
+}  // namespace tensorflow

 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 3e980ba198..db0cc89e20 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -58,6 +58,7 @@ class TransposeTest(tf.test.TestCase):
       inx = tf.convert_to_tensor(x)
       y = tf.transpose(inx, p)
       tf_ans = y.eval()
+      self.assertAllEqual(np_ans, tf_ans)
     self.assertShapeEqual(np_ans, y)

@@ -101,8 +102,9 @@ class TransposeTest(tf.test.TestCase):
     self.assertAllClose(tf_a_cpu, tf_a_gpu, 1e-6, 1e-6)
     self.assertAllClose(tf_g_cpu, tf_g_gpu, 1e-6, 1e-6)

-  def _testCpu(self, x):
+  def _testBoth(self, x):
     self._compare(x, use_gpu=False)
+    self._compare(x, use_gpu=True)

   def test1D(self):
     self._compareCpu(np.arange(0., 2), [0])
@@ -118,33 +120,48 @@ class TransposeTest(tf.test.TestCase):
     self._compare_cpu_gpu(np.arange(0, 21).reshape([3, 7]).astype(np.float32))
     self._compare_cpu_gpu(
         np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float32))
+    self._compare_cpu_gpu(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+        np.float32))

   def testDouble(self):
     self._compare_cpu_gpu(np.arange(0, 21).reshape([3, 7]).astype(np.float64))
     self._compare_cpu_gpu(
         np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float64))
+    self._compare_cpu_gpu(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+        np.float64))

   def testSComplex(self):
-    self._testCpu(np.complex(1, 2) * np.arange(0, 21).reshape(
-        [3, 7]).astype(np.complex64))
-    self._testCpu(np.complex(1, 2) * np.arange(0, 210).reshape(
-        [2, 3, 5, 7]).astype(np.complex64))
+    self._testBoth(np.complex(1, 2) *
+                   np.arange(0, 21).reshape([3, 7]).astype(np.complex64))
+    self._testBoth(np.complex(1, 2) *
+                   np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.complex64))
+    self._testBoth(
+        np.complex(1, 2) *
+        np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex64))

   def testInt8(self):
-    self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
-    self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))
+    self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
+    self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))
+    self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+        np.int8))

   def testInt16(self):
-    self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int16))
-    self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int16))
+    self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int16))
+    self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int16))
+    self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+        np.int16))

   def testInt32(self):
-    self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int32))
-    self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int32))
+    self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int32))
+    self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int32))
+    self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+        np.int32))

   def testInt64(self):
-    self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int64))
-    self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int64))
+    self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int64))
+    self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int64))
+    self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+        np.int64))

   def testTranspose2DAuto(self):
     x_np = [[1, 2, 3], [4, 5, 6]]
```
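As a worked example of the index arithmetic shared by the CPU `TransposeSimple` loop and the GPU `TransposeKernel` above: `strides[i]` is the product of all dimension sizes after `i`, so repeatedly dividing a flat output index by `out_strides` decodes it into output coordinates, and multiplying each coordinate by the permuted input stride re-encodes the matching input offset. Here is a self-contained sketch under those definitions (`std::vector` stands in for `Tensor`; this illustrates the technique and is not the TensorFlow implementation):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// strides[i] = product of shape[i+1..]; same contract as the patch's
// internal::ComputeStride helper.
static void ComputeStride(const std::vector<std::int64_t>& shape,
                          std::vector<std::int64_t>* strides) {
  strides->assign(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    (*strides)[i] = (*strides)[i + 1] * shape[i + 1];
}

// Rank-generic transpose mirroring the o_idx -> i_idx remap in the diff:
// decode o_idx into output coordinates via out_strides, then re-encode the
// coordinates using the permuted input strides.
template <typename T>
void TransposeSimple(const std::vector<T>& in,
                     const std::vector<std::int64_t>& in_shape,
                     const std::vector<int>& perm, std::vector<T>* out) {
  const int ndims = static_cast<int>(in_shape.size());
  std::vector<std::int64_t> out_shape(ndims);
  for (int i = 0; i < ndims; ++i) out_shape[i] = in_shape[perm[i]];
  std::vector<std::int64_t> in_strides, out_strides;
  ComputeStride(in_shape, &in_strides);
  ComputeStride(out_shape, &out_strides);
  out->resize(in.size());
  for (std::int64_t o_idx = 0;
       o_idx < static_cast<std::int64_t>(in.size()); ++o_idx) {
    std::int64_t i_idx = 0, t = o_idx;
    for (int i = 0; i < ndims; ++i) {
      i_idx += (t / out_strides[i]) * in_strides[perm[i]];  // coordinate i
      t %= out_strides[i];                                  // remainder
    }
    (*out)[o_idx] = in[i_idx];
  }
}

int main() {
  // 2x3 input has strides {3, 1}; perm {1, 0} gives output shape 3x2 with
  // strides {2, 1}. E.g. o_idx 3 decodes to output coords (1, 1), which
  // re-encode to i_idx = 1*1 + 1*3 = 4, so out[3] = in[4] = 5.
  std::vector<int> in = {1, 2, 3, 4, 5, 6}, out;
  TransposeSimple(in, {2, 3}, {1, 0}, &out);
  for (int v : out) std::printf("%d ", v);  // prints: 1 4 2 5 3 6
}
```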