author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-02-16 09:52:33 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-02-16 11:08:17 -0800
commit     6804c9cafc11fa73be3fdb057e033f0304661622 (patch)
tree       87a1c65806c0bf73263855c659e112977ddb27f1
parent     cf661010261c80b97ab68c5aec383b454ef34f18 (diff)
Rewrite of transpose so that its compilation time is tolerable. Main
approach: 1. Do not instantiate templates for all TF types. Instead, the various types are cast to one of uint8/uint16/uint32/uint64/string based on element size. 2. Use eigen3 for the transpose of rank-2/3/4 tensors and fall back to a naive routine that is templatized only on the type T, not on NDIMS. Change: 114763098
-rw-r--r--  tensorflow/core/BUILD                                |   1
-rw-r--r--  tensorflow/core/kernels/reduction_ops_common.h       |  18
-rw-r--r--  tensorflow/core/kernels/transpose_op.cc              | 153
-rw-r--r--  tensorflow/core/kernels/transpose_op.h               |  37
-rw-r--r--  tensorflow/core/kernels/transpose_op_cpu.cc          | 119
-rw-r--r--  tensorflow/core/kernels/transpose_op_functor.h       |  66
-rw-r--r--  tensorflow/core/kernels/transpose_op_gpu.cu.cc       | 153
-rw-r--r--  tensorflow/python/kernel_tests/transpose_op_test.py  |  43
8 files changed, 401 insertions(+), 189 deletions(-)
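
Approach 1 works because a transpose only moves whole elements, so every element type of a given byte width can share one template instantiation. A minimal standalone sketch of that size-based dispatch (Transpose2D and Transpose2DBySize are illustrative names, not the commit's API):

    #include <cstddef>
    #include <cstdint>

    // Dispatch on element width so that all 4-byte types (float, int32,
    // qint32, ...) share one instantiation.
    template <typename T>
    void Transpose2D(const void* in, void* out, int rows, int cols) {
      const T* src = static_cast<const T*>(in);
      T* dst = static_cast<T*>(out);
      for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c) dst[c * rows + r] = src[r * cols + c];
    }

    bool Transpose2DBySize(const void* in, void* out, int rows, int cols,
                           size_t elem_size) {
      switch (elem_size) {
        case 1: Transpose2D<uint8_t>(in, out, rows, cols); return true;
        case 2: Transpose2D<uint16_t>(in, out, rows, cols); return true;
        case 4: Transpose2D<uint32_t>(in, out, rows, cols); return true;
        case 8: Transpose2D<uint64_t>(in, out, rows, cols); return true;
        default: return false;  // non-POD types (e.g. string) need a typed path
      }
    }

A float matrix goes through the uint32_t instantiation; since a transpose only copies whole elements, the generated code is byte-for-byte equivalent to what a float instantiation would produce.
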
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 51bd76bce8..11788ebd3f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -686,6 +686,7 @@ filegroup(
"//tensorflow/core:kernels/topk_op.cc",
"//tensorflow/core:kernels/training_ops.cc",
"//tensorflow/core:kernels/transpose_op.cc",
+ "//tensorflow/core:kernels/transpose_op_cpu.cc",
"//tensorflow/core:kernels/where_op.cc",
"//tensorflow/core:kernels/xent_op.cc",
],
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index bd03c4183d..f1904b7255 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -23,9 +23,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/reduction_ops.h"
-#include "tensorflow/core/kernels/transpose_op.h"
-
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
@@ -33,6 +30,8 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/reduction_ops.h"
+#include "tensorflow/core/kernels/transpose_op_functor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
@@ -197,7 +196,12 @@ class ReductionHelper {
}
// Shape of shuffled input
- const gtl::ArraySlice<int64> data_reshape() const { return data_reshape_; }
+ TensorShape data_reshape() const {
+ const int dims = data_reshape_.size();
+ TensorShape shape;
+ for (auto s : data_reshape_) shape.AddDim(s);
+ return shape;
+ }
// Shape with all reduction dimensions at the end
TensorShape shuffled_shape() {
@@ -315,12 +319,14 @@ class ReductionOp : public OpKernel {
} else {
// If we don't hit one of the cases above, transpose the data so that
// all reduced dimensions are last and reuse the 2-D -> 1-D case.
+ Tensor data_reshaped;
+ CHECK(data_reshaped.CopyFrom(data, helper.data_reshape()));
Tensor shuffled;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
helper.shuffled_shape(), &shuffled,
alloc_attr));
- TransposeTensor<Device, T>(d, data, helper.data_reshape(),
- helper.permutation(), &shuffled);
+ OP_REQUIRES_OK(
+ ctx, DoTranspose(d, data_reshaped, helper.permutation(), &shuffled));
const int64 unreduced = tmp_out.NumElements();
const int64 reduced = shuffled.NumElements() / unreduced;
const Tensor& const_shuffled = shuffled;
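
As context for the branch above: when the reduced dimensions are not already trailing, the helper permutes them to the end so the problem collapses to reducing the rows of a 2-D view. A small worked example, independent of the TensorFlow sources: summing a [2, 3, 5] tensor over axes {0, 2} keeps axis 1, uses permutation {1, 0, 2} to reach shape [3, 2, 5], and then reduces a [3, 10] matrix row-wise.

    #include <cstdio>

    int main() {
      // x has shape [2][3][5]; reduce (sum) over axes {0, 2}, keeping axis 1.
      float x[2][3][5], sum[3] = {0, 0, 0};
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
          for (int k = 0; k < 5; ++k) x[i][j][k] = i * 15 + j * 5 + k;
      // After the conceptual transpose to [3][2][5] (perm = {1, 0, 2}), the
      // data is a [3][10] matrix; summing each row is the 2-D -> 1-D case.
      for (int j = 0; j < 3; ++j)
        for (int i = 0; i < 2; ++i)
          for (int k = 0; k < 5; ++k) sum[j] += x[i][j][k];
      for (int j = 0; j < 3; ++j) printf("%g ", sum[j]);  // prints 95 145 195
      printf("\n");
      return 0;
    }
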
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index e7294cb3e0..530e333bff 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -18,6 +18,9 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/transpose_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/transpose_op_functor.h"
@@ -27,9 +30,6 @@ limitations under the License.
namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
-
// inv = InvertPermutationOp(T<int32> p) takes a permutation of
// integers 0, 1, ..., n - 1 and returns the inverted
// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
@@ -89,26 +89,21 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation")
// REQUIRES: input.dims() == perm.size().
// REQUIRES: perm is a permutation.
-template <typename Device, typename T>
-TransposeOp<Device, T>::TransposeOp(OpKernelConstruction* context)
- : OpKernel(context) {}
-
-template <typename Device, typename T>
-void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
- const Tensor& input = context->input(0);
- const Tensor& perm = context->input(1);
+void TransposeOp::Compute(OpKernelContext* ctx) {
+ const Tensor& input = ctx->input(0);
+ const Tensor& perm = ctx->input(1);
// Preliminary validation of sizes.
- OP_REQUIRES(context, TensorShapeUtils::IsVector(perm.shape()),
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm.shape()),
errors::InvalidArgument("perm must be a vector, not ",
perm.shape().DebugString()));
auto Vperm = perm.vec<int32>();
const int dims = input.dims();
static const int kMinDims = 0;
static const int kMaxDims = 10;
- OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
+ OP_REQUIRES(ctx, kMinDims <= dims && dims <= kMaxDims,
errors::Unimplemented("Transposing a tensor of rank ", dims,
" is not implemented."));
- OP_REQUIRES(context, dims == Vperm.size(),
+ OP_REQUIRES(ctx, dims == Vperm.size(),
errors::InvalidArgument(
"transpose expects a vector of size ", input.dims(),
". But input(1) is a vector of size ", Vperm.size()));
@@ -120,118 +115,62 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
gtl::InlinedVector<bool, 8> bits(dims);
for (const int32 d : permutation) {
OP_REQUIRES(
- context, 0 <= d && d < dims,
+ ctx, 0 <= d && d < dims,
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
bits[d] = true;
shape.AddDim(input.dim_size(d));
}
for (int i = 0; i < dims; ++i) {
- OP_REQUIRES(context, bits[i], errors::InvalidArgument(
- i, " is missing from {",
- str_util::Join(permutation, ","), "}."));
+ OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
+ i, " is missing from {",
+ str_util::Join(permutation, ","), "}."));
}
// 0-D and 1-D transposes do nothing
if (dims <= 1) {
- context->set_output(0, input);
+ ctx->set_output(0, input);
return;
}
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
- TransposeTensor<Device, T>(context->eigen_device<Device>(), input,
- input.shape().dim_sizes(), permutation, output);
-}
-
-template <typename Device, typename T>
-void TransposeTensor(const Device& device, const Tensor& input,
- const gtl::ArraySlice<int64> input_shape,
- gtl::ArraySlice<int32> permutation, Tensor* output) {
- const int dims = input_shape.size();
- CHECK(permutation.size() == dims);
- if (input.NumElements() == 0) {
- return;
- }
- switch (dims) {
-#define EXPAND_DIM(N) \
- case N: { \
- functor::TransposeFunctor<Device, T, N> func; \
- func(device, output->tensor<T, N>(), input.shaped<T, N>(input_shape), \
- permutation.data()); \
- break; \
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
+ if (shape.num_elements() > 0) {
+ OP_REQUIRES_OK(ctx, DoTranspose(ctx, input, permutation, output));
}
- EXPAND_DIM(2);
- EXPAND_DIM(3);
- EXPAND_DIM(4);
- EXPAND_DIM(5);
- EXPAND_DIM(6);
- EXPAND_DIM(7);
- EXPAND_DIM(8);
- EXPAND_DIM(9);
- EXPAND_DIM(10);
- default:
- LOG(FATAL) << "Unexpected dims: " << dims;
- }
-#undef EXPAND_CASE
}
-namespace functor {
-
-template <typename Device, typename T, int NDIMS>
-void TransposeMaybeInline(const Device& d,
- typename TTypes<T, NDIMS>::Tensor out,
- typename TTypes<T, NDIMS>::ConstTensor in,
- const int* perm) {
- // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU.
- Eigen::array<int, NDIMS> p;
- for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
- if (out.size() * sizeof(T) < 131072) { // Small transpose on a CPU: do inline
- out = in.shuffle(p);
- } else {
- out.device(d) = in.shuffle(p);
- }
+Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) {
+ typedef Eigen::ThreadPoolDevice CPUDevice;
+ return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
+ out);
}
-template <typename T, int NDIMS>
-struct TransposeFunctor<CPUDevice, T, NDIMS> {
- void operator()(const CPUDevice& d, typename TTypes<T, NDIMS>::Tensor out,
- typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
- TransposeMaybeInline<CPUDevice, T, NDIMS>(d, out, in, perm);
- }
-};
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("perm"), \
+ TransposeCpuOp);
+TF_CALL_ALL_TYPES(REGISTER)
+#undef REGISTER
-} // namespace functor
-
-#define REGISTER(D, T) \
- template class TransposeOp<D##Device, T>; \
- REGISTER_KERNEL_BUILDER(Name("Transpose") \
- .Device(DEVICE_##D) \
- .TypeConstraint<T>("T") \
- .HostMemory("perm"), \
- TransposeOp<D##Device, T>); \
- template void TransposeTensor<D##Device, T>( \
- const D##Device&, const Tensor&, const gtl::ArraySlice<int64>, \
- gtl::ArraySlice<int32>, Tensor*);
-REGISTER(CPU, float);
-REGISTER(CPU, double);
-REGISTER(CPU, complex64);
-REGISTER(CPU, uint8);
-REGISTER(CPU, int8);
-REGISTER(CPU, int16);
-REGISTER(CPU, int32);
-REGISTER(CPU, int64);
-REGISTER(CPU, string);
-REGISTER(CPU, bool);
#if GOOGLE_CUDA
-REGISTER(GPU, uint8);
-REGISTER(GPU, int8);
-REGISTER(GPU, int16);
-REGISTER(GPU, int32);
-REGISTER(GPU, int64);
-REGISTER(GPU, float);
-REGISTER(GPU, double);
-REGISTER(GPU, complex64);
-REGISTER(GPU, bool);
-#endif
+Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) {
+ typedef Eigen::GpuDevice GPUDevice;
+ return ::tensorflow::DoTranspose(ctx->eigen_device<GPUDevice>(), in, perm,
+ out);
+}
+
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("perm"), \
+ TransposeGpuOp);
+TF_CALL_NUMBER_TYPES(REGISTER);
#undef REGISTER
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index 15cd1b6488..4d546298b4 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -1,4 +1,4 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
+/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,24 +16,39 @@ limitations under the License.
#ifndef TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
#define TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
-template <typename Device, typename T>
class TransposeOp : public OpKernel {
public:
- explicit TransposeOp(OpKernelConstruction* context);
- void Compute(OpKernelContext* context) override;
+ explicit TransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ virtual Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) = 0;
};
-// Exposed for use in reduction ops
-template <typename Device, typename T>
-void TransposeTensor(const Device& device, const Tensor& input,
- const gtl::ArraySlice<int64> input_shape,
- gtl::ArraySlice<int32> permutation, Tensor* output);
+class TransposeCpuOp : public TransposeOp {
+ public:
+ explicit TransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
+
+ protected:
+ Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) override;
+};
+
+class TransposeGpuOp : public TransposeOp {
+ public:
+ explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
+
+ protected:
+ Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) override;
+};
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op_cpu.cc b/tensorflow/core/kernels/transpose_op_cpu.cc
new file mode 100644
index 0000000000..ea039bf471
--- /dev/null
+++ b/tensorflow/core/kernels/transpose_op_cpu.cc
@@ -0,0 +1,119 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/transpose_op_functor.h"
+
+namespace tensorflow {
+namespace internal {
+
+template <typename Device, typename T>
+void TransposeSimple(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ const int ndims = in.dims();
+ gtl::InlinedVector<int64, 8> in_strides(ndims);
+ ComputeStride(in.shape(), in_strides.data());
+ gtl::InlinedVector<int64, 8> out_strides(ndims);
+ ComputeStride(out->shape(), out_strides.data());
+ const int64 nelem = in.NumElements();
+ const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
+ T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
+
+ // TODO(zhifengc): Shard by range.
+ // TODO(zhifengc): Avoid the division.
+ for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
+ int64 i_idx = 0;
+ int64 t = o_idx;
+ for (int i = 0; i < ndims; ++i) {
+ i_idx += (t / out_strides[i]) * in_strides[perm[i]];
+ t = t % out_strides[i];
+ }
+ q[o_idx] = p[i_idx];
+ }
+}
+
+template <typename Device, typename T, int NDIMS>
+void TransposeUsingEigen(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ Eigen::array<int, NDIMS> p;
+ for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
+ auto x = typename TTypes<T, NDIMS>::ConstTensor(
+ reinterpret_cast<const T*>(in.tensor_data().data()),
+ in.shape().AsEigenDSizes<NDIMS>());
+ auto y = typename TTypes<T, NDIMS>::Tensor(
+ reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
+ out->shape().AsEigenDSizes<NDIMS>());
+ auto nelem = in.NumElements();
+ static const int64 kInlineThreshold = 131072;
+ if (nelem * sizeof(T) < kInlineThreshold) {
+ // Don't bother with a multi-threaded transpose if 'in' is small.
+ y = x.shuffle(p);
+ } else {
+ y.device(d) = x.shuffle(p);
+ }
+}
+
+} // end namespace internal
+
+typedef Eigen::ThreadPoolDevice Device;
+
+template <>
+Status DoTranspose<Device>(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ CHECK_GE(in.dims(), 2);
+ CHECK_EQ(in.dims(), out->dims());
+ CHECK_EQ(in.dims(), perm.size());
+ CHECK_EQ(in.dtype(), out->dtype());
+ switch (in.dtype()) {
+ case DT_BOOL:
+ case DT_INT8:
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_UINT8:
+ internal::Transpose<Device, uint8>(d, in, perm, out);
+ break;
+
+ case DT_BFLOAT16:
+ case DT_INT16:
+ case DT_QINT16:
+ case DT_QUINT16:
+ case DT_UINT16:
+ internal::Transpose<Device, uint16>(d, in, perm, out);
+ break;
+
+ case DT_FLOAT:
+ case DT_INT32:
+ case DT_QINT32:
+ internal::Transpose<Device, uint32>(d, in, perm, out);
+ break;
+
+ case DT_COMPLEX64:
+ case DT_DOUBLE:
+ case DT_INT64:
+ internal::Transpose<Device, uint64>(d, in, perm, out);
+ break;
+
+ case DT_STRING:
+ internal::Transpose<Device, string>(d, in, perm, out);
+ break;
+
+ default:
+ return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
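
The index arithmetic in TransposeSimple above is worth spelling out: for each flat output index o_idx, successive division by the output strides recovers the output coordinates, and each coordinate is re-weighted by the input stride of the permuted dimension. A standalone sketch of the same mapping in plain C++ (outside the TensorFlow build):

    #include <cstdio>
    #include <vector>

    // Same contract as ComputeStride in the commit:
    // strides[i] = prod(dims[(i+1):]).
    std::vector<long long> Strides(const std::vector<long long>& dims) {
      std::vector<long long> s(dims.size());
      long long acc = 1;
      for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
        s[i] = acc;
        acc *= dims[i];
      }
      return s;
    }

    int main() {
      // Transpose a row-major [2, 3] array with perm = {1, 0} into [3, 2].
      std::vector<long long> in_dims = {2, 3}, perm = {1, 0};
      std::vector<long long> out_dims = {in_dims[perm[0]], in_dims[perm[1]]};
      auto in_strides = Strides(in_dims), out_strides = Strides(out_dims);
      int in[6] = {0, 1, 2, 3, 4, 5}, out[6];
      for (long long o_idx = 0; o_idx < 6; ++o_idx) {
        // Peel off output coordinates by dividing by the output strides, and
        // weight each one by the input stride of the permuted dimension.
        long long i_idx = 0, t = o_idx;
        for (size_t i = 0; i < out_dims.size(); ++i) {
          i_idx += (t / out_strides[i]) * in_strides[perm[i]];
          t %= out_strides[i];
        }
        out[o_idx] = in[i_idx];
      }
      for (int v : out) printf("%d ", v);  // prints 0 3 1 4 2 5
      printf("\n");
      return 0;
    }
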
diff --git a/tensorflow/core/kernels/transpose_op_functor.h b/tensorflow/core/kernels/transpose_op_functor.h
index e478c6d966..b79c3c7f2f 100644
--- a/tensorflow/core/kernels/transpose_op_functor.h
+++ b/tensorflow/core/kernels/transpose_op_functor.h
@@ -16,28 +16,66 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
-namespace functor {
-template <typename Device, typename T, int NDIMS>
-void Transpose(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
- typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
- // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU.
- Eigen::array<int, NDIMS> p;
- for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
- out.device(d) = in.shuffle(p);
+// Transpose tensor 'in' into tensor 'out' according to dimension
+// permutation 'perm'.
+//
+// REQUIRES: in.dtype() == out->dtype()
+// REQUIRES: in.dims() == out->dims()
+// REQUIRES: in.dims() == perm.size()
+// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
+template <typename Device>
+Status DoTranspose(const Device& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
+
+// Implementation details.
+namespace internal {
+
+// Helper to compute 'strides' given a tensor 'shape'. I.e.,
+// strides[i] = prod(shape.dim_size[(i+1):])
+template <typename Index>
+void ComputeStride(const TensorShape& shape, Index* strides) {
+ const int ndims = shape.dims();
+ Index stride = 1;
+ for (int i = ndims - 1; i >= 0; --i) {
+ strides[i] = stride;
+ stride *= static_cast<Index>(shape.dim_size(i));
+ }
}
+// Device-specific naive implementation for transpose.
+template <typename Device, typename T>
+void TransposeSimple(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
+
+// Uses Eigen to transpose.
template <typename Device, typename T, int NDIMS>
-struct TransposeFunctor {
- void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
- typename TTypes<T, NDIMS>::ConstTensor in, const int* perm);
-};
+void TransposeUsingEigen(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
-} // namespace functor
+template <typename Device, typename T>
+void Transpose(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ switch (in.dims()) {
+ case 2:
+ TransposeUsingEigen<Device, T, 2>(d, in, perm, out);
+ break;
+ case 3:
+ TransposeUsingEigen<Device, T, 3>(d, in, perm, out);
+ break;
+ case 4:
+ TransposeUsingEigen<Device, T, 4>(d, in, perm, out);
+ break;
+ default:
+ TransposeSimple<Device, T>(d, in, perm, out);
+ break;
+ }
+}
+} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
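
Given the REQUIRES contract above, a caller allocates the output with the permuted shape and invokes the device-templated entry point. A hedged sketch of such a call site, assuming a TensorFlow build with an Eigen thread-pool device available (TransposeExample is illustrative, not part of the commit):

    #define EIGEN_USE_THREADS

    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/kernels/transpose_op_functor.h"
    #include "tensorflow/core/lib/gtl/array_slice.h"

    namespace tensorflow {

    Status TransposeExample(const Eigen::ThreadPoolDevice& d) {
      Tensor in(DT_FLOAT, TensorShape({2, 3, 5}));
      in.flat<float>().setZero();
      // REQUIRES from above: out->dim_size(i) == in.dim_size(perm[i]).
      const int32 perm[] = {2, 0, 1};
      Tensor out(DT_FLOAT, TensorShape({5, 2, 3}));
      return DoTranspose(d, in, gtl::ArraySlice<int32>(perm, 3), &out);
    }

    }  // namespace tensorflow
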
diff --git a/tensorflow/core/kernels/transpose_op_gpu.cu.cc b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
index b42a69c785..238ee0a090 100644
--- a/tensorflow/core/kernels/transpose_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 Google Inc. All Rights Reserved.
+/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,46 +17,123 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/framework/numeric_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/transpose_op_functor.h"
-#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
-namespace functor {
-
-template <typename T, int NDIMS>
-struct TransposeFunctor<Eigen::GpuDevice, T, NDIMS> {
- void operator()(const Eigen::GpuDevice& d,
- typename TTypes<T, NDIMS>::Tensor out,
- typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
- Transpose<Eigen::GpuDevice, T, NDIMS>(d, out, in, perm);
+namespace internal {
+
+template <typename T>
+__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
+ const int32 ndims, T* dst) {
+ const int32* in_strides = buf;
+ const int32* out_strides = buf + ndims;
+ const int32* perm = buf + ndims * 2;
+ CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
+ int32 i_idx = 0;
+ int32 t = o_idx;
+ for (int i = 0; i < ndims; ++i) {
+ i_idx += (t / out_strides[i]) * in_strides[perm[i]];
+ t = t % out_strides[i];
+ }
+ dst[o_idx] = ldg(src + i_idx);
+ }
+}
+
+template <typename Device, typename T>
+void TransposeSimple(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ // Ensures we can use 32-bit index.
+ const int64 nelem = in.NumElements();
+ CHECK_LT(nelem, kint32max) << "Tensor too large to transpose on GPU";
+ // Pack strides and permutation into one buffer.
+ const int32 ndims = in.dims();
+ gtl::InlinedVector<int32, 16> host_buf(ndims * 3);
+ // Input strides.
+ ComputeStride(in.shape(), &host_buf[0]);
+ // Output strides.
+ ComputeStride(out->shape(), &host_buf[ndims]);
+ // Dimension permutation.
+ for (int i = 0; i < ndims; ++i) {
+ host_buf[ndims * 2 + i] = perm[i];
+ }
+ // Copies the input strides, output strides and permutation to the device.
+ auto num_bytes = sizeof(int32) * host_buf.size();  // host_buf holds int32s
+ auto dev_buf = d.allocate(num_bytes);
+ // NOTE: host_buf is not allocated by CudaHostAllocator, and
+ // therefore we are doing a sync copy effectively.
+ d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
+ // Launch kernel to q[...] = p[...].
+ const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
+ T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
+ CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
+ TransposeKernel<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+ cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
+ ndims, q);
+ // Safe to deallocate immediately after the kernel launch.
+ d.deallocate(dev_buf);
+}
+
+template <typename Device, typename T, int NDIMS>
+void TransposeUsingEigen(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ Eigen::array<int, NDIMS> p;
+ for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
+ auto x = typename TTypes<T, NDIMS>::ConstTensor(
+ reinterpret_cast<const T*>(in.tensor_data().data()),
+ in.shape().AsEigenDSizes<NDIMS>());
+ auto y = typename TTypes<T, NDIMS>::Tensor(
+ reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
+ out->shape().AsEigenDSizes<NDIMS>());
+ y.device(d) = x.shuffle(p);
+}
+
+} // end namespace internal
+
+typedef Eigen::GpuDevice Device;
+
+template <>
+Status DoTranspose<Device>(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ CHECK_GE(in.dims(), 2);
+ CHECK_EQ(in.dims(), out->dims());
+ CHECK_EQ(in.dims(), perm.size());
+ CHECK_EQ(in.dtype(), out->dtype());
+ switch (in.dtype()) {
+ case DT_BOOL:
+ case DT_INT8:
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_UINT8:
+ internal::Transpose<Device, uint8>(d, in, perm, out);
+ break;
+
+ case DT_BFLOAT16:
+ case DT_INT16:
+ case DT_QINT16:
+ case DT_QUINT16:
+ case DT_UINT16:
+ internal::Transpose<Device, uint16>(d, in, perm, out);
+ break;
+
+ case DT_FLOAT:
+ case DT_INT32:
+ case DT_QINT32:
+ internal::Transpose<Device, uint32>(d, in, perm, out);
+ break;
+
+ case DT_COMPLEX64:
+ case DT_DOUBLE:
+ case DT_INT64:
+ internal::Transpose<Device, uint64>(d, in, perm, out);
+ break;
+
+ default:
+ return errors::Unimplemented("Unsupported dtype on GPU: ", in.dtype());
}
-};
-
-#define DEFINE(T, N) template struct TransposeFunctor<Eigen::GpuDevice, T, N>;
-#define DEFINE_DIM(T) \
- DEFINE(T, 2); \
- DEFINE(T, 3); \
- DEFINE(T, 4); \
- DEFINE(T, 5); \
- DEFINE(T, 6); \
- DEFINE(T, 7); \
- DEFINE(T, 8); \
- DEFINE(T, 9); \
- DEFINE(T, 10);
-DEFINE_DIM(uint8);
-DEFINE_DIM(int8);
-DEFINE_DIM(int16);
-DEFINE_DIM(int32);
-DEFINE_DIM(int64);
-DEFINE_DIM(float);
-DEFINE_DIM(double);
-DEFINE_DIM(complex64);
-DEFINE_DIM(bool);
-#undef DEFINE_DIM
-#undef DEFINE
-
-} // end namespace functor
-} // end namespace tensorflow
+ return Status::OK();
+}
+} // namespace tensorflow
#endif // GOOGLE_CUDA
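
One detail of the GPU path worth noting: the host packs the input strides, output strides, and permutation into a single int32 buffer so that one host-to-device copy, rather than three, precedes the kernel launch; the kernel then reads the three regions at buf, buf + ndims, and buf + 2 * ndims. The layout, sketched standalone in plain C++ (values chosen for a [2, 3, 5] -> [5, 2, 3] transpose with perm = {2, 0, 1}):

    #include <cstdio>
    #include <vector>

    int main() {
      const int ndims = 3;
      std::vector<int> in_strides = {15, 5, 1};   // input shape  [2, 3, 5]
      std::vector<int> out_strides = {6, 3, 1};   // output shape [5, 2, 3]
      std::vector<int> perm = {2, 0, 1};
      // Single buffer: [in_strides | out_strides | perm], 3 * ndims int32s.
      std::vector<int> host_buf(3 * ndims);
      for (int i = 0; i < ndims; ++i) {
        host_buf[i] = in_strides[i];
        host_buf[ndims + i] = out_strides[i];
        host_buf[2 * ndims + i] = perm[i];
      }
      // On the device, the kernel recovers all three views from one pointer.
      const int* buf = host_buf.data();
      printf("in_strides[0]=%d out_strides[0]=%d perm[0]=%d\n",
             buf[0], buf[ndims], buf[2 * ndims]);
      return 0;
    }
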
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 3e980ba198..db0cc89e20 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -58,6 +58,7 @@ class TransposeTest(tf.test.TestCase):
inx = tf.convert_to_tensor(x)
y = tf.transpose(inx, p)
tf_ans = y.eval()
+
self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
@@ -101,8 +102,9 @@ class TransposeTest(tf.test.TestCase):
self.assertAllClose(tf_a_cpu, tf_a_gpu, 1e-6, 1e-6)
self.assertAllClose(tf_g_cpu, tf_g_gpu, 1e-6, 1e-6)
- def _testCpu(self, x):
+ def _testBoth(self, x):
self._compare(x, use_gpu=False)
+ self._compare(x, use_gpu=True)
def test1D(self):
self._compareCpu(np.arange(0., 2), [0])
@@ -118,33 +120,48 @@ class TransposeTest(tf.test.TestCase):
self._compare_cpu_gpu(np.arange(0, 21).reshape([3, 7]).astype(np.float32))
self._compare_cpu_gpu(
np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float32))
+ self._compare_cpu_gpu(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+ np.float32))
def testDouble(self):
self._compare_cpu_gpu(np.arange(0, 21).reshape([3, 7]).astype(np.float64))
self._compare_cpu_gpu(
np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.float64))
+ self._compare_cpu_gpu(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+ np.float64))
def testSComplex(self):
- self._testCpu(np.complex(1, 2) * np.arange(0, 21).reshape(
- [3, 7]).astype(np.complex64))
- self._testCpu(np.complex(1, 2) * np.arange(0, 210).reshape(
- [2, 3, 5, 7]).astype(np.complex64))
+ self._testBoth(np.complex(1, 2) *
+ np.arange(0, 21).reshape([3, 7]).astype(np.complex64))
+ self._testBoth(np.complex(1, 2) *
+ np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.complex64))
+ self._testBoth(
+ np.complex(1, 2) *
+ np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex64))
def testInt8(self):
- self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
- self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))
+ self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
+ self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))
+ self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+ np.int8))
def testInt16(self):
- self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int16))
- self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int16))
+ self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int16))
+ self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int16))
+ self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+ np.int16))
def testInt32(self):
- self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int32))
- self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int32))
+ self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int32))
+ self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int32))
+ self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+ np.int32))
def testInt64(self):
- self._testCpu(np.arange(0, 21).reshape([3, 7]).astype(np.int64))
- self._testCpu(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int64))
+ self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int64))
+ self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int64))
+ self._testBoth(np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(
+ np.int64))
def testTranspose2DAuto(self):
x_np = [[1, 2, 3], [4, 5, 6]]