aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-11 16:53:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-11 16:57:26 -0700
commitd835d677ade78a41e0e097f67c87b6ab8588a90a (patch)
treee267977299903a88efd655ea62bf6ad396da8785
parent6285db2546f03296a4f30071ce96217ccd17c452 (diff)
Extend the transpose ops in TensorFlow to support conjugate (a.k.a. Hermitian) transposition. Currently, this can only be accomplished by adding extra conjugation ops, which means reading the tensor data from memory twice. More importantly, Hermitian transpose is the most common transposition operation when using complex arithmetic, so using it in new code helps prevent "conjugation bugs" by making the math work for real and complex types alike. The alias tf.linalg.adjoint was added to help with the latter.
Optimized fused conjugate transpose op for GPU will be added in a followup. Get rid of some duplication of code among CPU/GPU/SYCL in transpose_functor. Support accelerating 2D transpose ops using MKL in more cases. PiperOrigin-RevId: 171895454
-rw-r--r--tensorflow/core/kernels/BUILD31
-rw-r--r--tensorflow/core/kernels/mkl_transpose_op.cc94
-rw-r--r--tensorflow/core/kernels/transpose_functor.h106
-rw-r--r--tensorflow/core/kernels/transpose_functor_cpu.cc207
-rw-r--r--tensorflow/core/kernels/transpose_functor_gpu.cu.cc233
-rw-r--r--tensorflow/core/kernels/transpose_op.cc72
-rw-r--r--tensorflow/core/kernels/transpose_op.h53
-rw-r--r--tensorflow/core/ops/array_ops.cc131
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py9
-rw-r--r--tensorflow/python/kernel_tests/linalg_ops_test.py15
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py24
-rw-r--r--tensorflow/python/ops/array_ops.py35
-rw-r--r--tensorflow/python/ops/hidden_ops.txt1
-rw-r--r--tensorflow/python/ops/linalg/linalg_impl.py25
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
16 files changed, 662 insertions, 384 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 3b7d803bea..dbf6449bc2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1350,10 +1350,11 @@ tf_kernel_library(
],
visibility = [":friends"],
deps = [
+ ":conv_ops",
+ ":cwise_op",
":ops_util",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/kernels:conv_ops",
"//third_party/eigen3",
],
alwayslink = 0,
@@ -2276,13 +2277,15 @@ LINALG_DEPS = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:linalg_ops_op_lib",
-]
+] + if_cuda([
+ ":cuda_solvers",
+ ":transpose_functor",
+])
tf_kernel_library(
name = "cholesky_op",
prefix = "cholesky_op",
deps = if_cuda([
- ":cuda_solvers",
":matrix_band_part_op",
]) + LINALG_DEPS,
)
@@ -2297,7 +2300,6 @@ tf_kernel_library(
name = "determinant_op",
prefix = "determinant_op",
deps = if_cuda([
- ":cuda_solvers",
":fill_functor",
]) + LINALG_DEPS,
)
@@ -2314,17 +2316,13 @@ tf_kernel_library(
deps = LINALG_DEPS + if_cuda([
":cast_op",
":cwise_op",
- ":cuda_solvers",
- ":transpose_functor",
]),
)
tf_kernel_library(
name = "matrix_inverse_op",
prefix = "matrix_inverse_op",
- deps = if_cuda([
- ":cuda_solvers",
- ]) + LINALG_DEPS,
+ deps = LINALG_DEPS,
)
tf_kernel_library(
@@ -2336,10 +2334,7 @@ tf_kernel_library(
tf_kernel_library(
name = "matrix_solve_op",
prefix = "matrix_solve_op",
- deps = if_cuda([
- ":cuda_solvers",
- ":transpose_functor",
- ]) + LINALG_DEPS,
+ deps = LINALG_DEPS,
)
tf_kernel_library(
@@ -2354,20 +2349,15 @@ tf_kernel_library(
name = "qr_op",
prefix = "qr_op",
deps = LINALG_DEPS + if_cuda([
- ":cuda_solvers",
":cwise_op",
":matrix_band_part_op",
- ":transpose_functor",
]),
)
tf_kernel_library(
name = "svd_op",
prefix = "svd_op",
- deps = LINALG_DEPS + if_cuda([
- ":cuda_solvers",
- ":transpose_functor",
- ]),
+ deps = LINALG_DEPS,
)
cc_library(
@@ -2457,7 +2447,6 @@ tf_cc_tests(
MATH_DEPS = [
":bounds_check",
":fill_functor",
- ":transpose_functor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -2617,7 +2606,7 @@ tf_kernel_library(
name = "reduction_ops",
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
prefix = "reduction_ops",
- deps = MATH_DEPS + if_cuda(["@cub_archive//:cub"]),
+ deps = MATH_DEPS + [":transpose_functor"] + if_cuda(["@cub_archive//:cub"]),
)
tf_kernel_library(
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index 50d25ac511..89a1d5e8a7 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -39,28 +39,86 @@ namespace tensorflow {
// REQUIRES: input.dims() == perm.size().
// REQUIRES: perm is a permutation.
-Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
- gtl::ArraySlice<int32> perm,
- Tensor* out) {
- if (in.dims() == 2 && in.dtype() == DT_FLOAT) {
- float* user_o = out->flat<float>().data();
- const float* user_i = in.flat<float>().data();
-
- // Documentation here: https://software.intel.com/en-us/node/520863
- // Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
- // alpha (for scaling), array, dist_bet_adjacent_cols/rows
- // (source), array, dist_bet_adjacent_cols/rows (dest))
- mkl_somatcopy('R', 'T', in.dim_size(0), in.dim_size(1), 1, user_i,
- in.dim_size(1), user_o, in.dim_size(0));
+namespace {
+template <typename T>
+void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
+// Documentation here: https://software.intel.com/en-us/node/520863
+// Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
+// alpha (for scaling), array, dist_bet_adjacent_cols/rows
+// (source), array, dist_bet_adjacent_cols/rows (dest))
+
+#define INSTANTIATE(T, PREFIX) \
+ template <> \
+ Status MKLTranspose2D<T>(const char trans, const Tensor& in, Tensor* out) { \
+ mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1, \
+ in.flat<T>().data(), in.dim_size(1), \
+ out->flat<T>().data(), in.dim_size(0)); \
return Status::OK();
}
- // Fallback to eigen if transpose parameters not supported by MKL
- typedef Eigen::ThreadPoolDevice CPUDevice;
- return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
- out);
-} // MklTransposeCpuOp::DoTranspose
+ INSTANTIATE(float, s)
+ INSTANTIATE(double, d)
+ INSTANTIATE(complex64, c)
+ INSTANTIATE(complex128, z)
+#undef INSTANTIATE
+
+ static const char kMKLTranspose = 'T';
+ static const char kMKLConjugateTranspose = 'C';
+
+ } // namespace tensorflow
+
+ Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm,
+ Tensor* out) {
+ if (in.dims() == 2) {
+ switch (in.dtype()) {
+ case DT_FLOAT:
+ return MKLTranspose2D<float>(kMKLTranspose, in, out);
+ case DT_DOUBLE:
+ return MKLTranspose2D<double>(kMKLTranspose, in, out);
+ case DT_COMPLEX64:
+ return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
+ case DT_COMPLEX128:
+ return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
+ default:
+ break;
+ }
+ }
+ // Fallback to eigen if transpose parameters not supported by MKL
+ typedef Eigen::ThreadPoolDevice CPUDevice;
+ return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
+ out);
+ }
+
+ Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
+ const Tensor& in,
+ gtl::ArraySlice<int32> perm,
+ Tensor* out) {
+ if (in.dims() == 2) {
+ // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
+ // for any transpose that can be reduced to swapping the last two
+ // dimensions in a rank-3 tensor. We can even run each outer dimension in
+ // a separate thread.
+ switch (in.dtype()) {
+ case DT_FLOAT:
+ return MKLTranspose2D<float>(kMKLTranspose, in, out);
+ case DT_DOUBLE:
+ return MKLTranspose2D<double>(kMKLTranspose, in, out);
+ case DT_COMPLEX64:
+ return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
+ case DT_COMPLEX128:
+ return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
+ default:
+ break;
+ }
+ }
+ // Fallback to eigen if transpose parameters not supported by MKL
+ typedef Eigen::ThreadPoolDevice CPUDevice;
+ return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(),
+ in, perm, out);
+ }
+
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h
index 498030fdfe..317a534fd6 100644
--- a/tensorflow/core/kernels/transpose_functor.h
+++ b/tensorflow/core/kernels/transpose_functor.h
@@ -32,6 +32,24 @@ template <typename Device>
Status DoTranspose(const Device& device, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out);
+// Conjugate and transpose tensor 'in' into tensor 'out' according to dimension
+// permutation 'perm'.
+//
+// REQUIRES: in.dtype() == out->dtype()
+// REQUIRES: in.dims() == out->dims()
+// REQUIRES: in.dims() == perm.size()
+// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
+template <typename Device>
+Status DoConjugateTranspose(const Device& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
+
+// Primary device specific functor to be specialized for each device and type.
+template <typename Device, typename T, bool conjugate = false>
+struct Transpose {
+ static void run(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out);
+};
+
// Implementation details.
namespace internal {
@@ -111,14 +129,15 @@ inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
}
// Device-specific naive implementation for transpose.
-template <typename Device, typename T>
+template <typename Device, typename T, bool conjugate>
void TransposeSimple(const Device& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out);
// Uses Eigen to transpose.
template <typename Device, typename T, int NDIMS>
void TransposeUsingEigen(const Device& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
+ const gtl::ArraySlice<int32> perm, bool conjugate,
+ Tensor* out) {
Eigen::array<int, NDIMS> p;
for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
auto x = typename TTypes<T, NDIMS>::ConstTensor(
@@ -127,24 +146,87 @@ void TransposeUsingEigen(const Device& d, const Tensor& in,
auto y = typename TTypes<T, NDIMS>::Tensor(
reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
out->shape().AsEigenDSizes<NDIMS>());
- y.device(d) = x.shuffle(p);
+ if (conjugate) {
+ y.device(d) = x.conjugate().shuffle(p);
+ } else {
+ y.device(d) = x.shuffle(p);
+ }
}
+template <typename Device>
+struct DoTransposeImpl {
+ static Status run(const Device& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, bool conjugate,
+ Tensor* out) {
+ CHECK_GE(in.dims(), 2);
+ CHECK_EQ(in.dims(), out->dims());
+ CHECK_EQ(in.dims(), perm.size());
+ CHECK_EQ(in.dtype(), out->dtype());
+ switch (in.dtype()) {
+ case DT_BOOL:
+ case DT_INT8:
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_UINT8:
+ Transpose<Device, uint8>::run(d, in, perm, out);
+ break;
+
+ case DT_BFLOAT16:
+ case DT_HALF:
+ case DT_INT16:
+ case DT_QINT16:
+ case DT_QUINT16:
+ case DT_UINT16:
+ Transpose<Device, uint16>::run(d, in, perm, out);
+ break;
+
+ case DT_FLOAT:
+ case DT_INT32:
+ case DT_QINT32:
+ Transpose<Device, uint32>::run(d, in, perm, out);
+ break;
+
+ case DT_DOUBLE:
+ case DT_INT64:
+ Transpose<Device, uint64>::run(d, in, perm, out);
+ break;
+
+ case DT_COMPLEX64:
+ if (conjugate) {
+ Transpose<Device, complex64, true>::run(d, in, perm, out);
+ } else {
+ Transpose<Device, complex64, false>::run(d, in, perm, out);
+ }
+ break;
+
+ case DT_COMPLEX128:
+ if (conjugate) {
+ Transpose<Device, complex128, true>::run(d, in, perm, out);
+ } else {
+ Transpose<Device, complex128, false>::run(d, in, perm, out);
+ }
+ break;
+
+ case DT_STRING:
+ Transpose<Device, string>::run(d, in, perm, out);
+ break;
+
+ default:
+ return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
+ }
+ return Status::OK();
+ }
+};
#ifdef TENSORFLOW_USE_SYCL
// For SYCL lets always go through Eigen
template <typename Device, typename T>
void TransposeSYCL(const Device& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out);
-#endif // TENSORFLOW_USE_SYCL
-} // namespace internal
-
-template <typename Device, typename T>
-struct Transpose {
- static void run(const Device& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out);
-};
+ const gtl::ArraySlice<int32> perm, bool conjugate,
+ Tensor* out);
+#endif // TENSORFLOW_USE_SYCL
+} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc
index c3f3df722f..b983bf695c 100644
--- a/tensorflow/core/kernels/transpose_functor_cpu.cc
+++ b/tensorflow/core/kernels/transpose_functor_cpu.cc
@@ -19,10 +19,12 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/transpose_functor.h"
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
namespace tensorflow {
namespace internal {
-template <typename Device, typename T>
+template <typename Device, typename T, bool conjugate>
void TransposeSimple(const Device& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
const int ndims = in.dims();
@@ -41,122 +43,90 @@ void TransposeSimple(const Device& d, const Tensor& in,
i_idx += (t / out_strides[i]) * in_strides[perm[i]];
t = t % out_strides[i];
}
- q[o_idx] = p[i_idx];
+ if (conjugate) {
+ q[o_idx] = Eigen::numext::conj(p[i_idx]);
+ } else {
+ q[o_idx] = p[i_idx];
+ }
}
}
} // end namespace internal
-typedef Eigen::ThreadPoolDevice CPUDevice;
-
-template <typename T>
-struct Transpose<CPUDevice, T> {
+template <typename T, bool conjugate>
+struct Transpose<CPUDevice, T, conjugate> {
static void run(const CPUDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
switch (in.dims()) {
case 2:
- internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, out);
+ internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, conjugate,
+ out);
break;
case 3:
- internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, out);
+ internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, conjugate,
+ out);
break;
case 4:
- internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, out);
+ internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, conjugate,
+ out);
break;
case 5:
- internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, out);
+ internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate,
+ out);
break;
default:
- internal::TransposeSimple<CPUDevice, T>(d, in, perm, out);
+ internal::TransposeSimple<CPUDevice, T, conjugate>(d, in, perm, out);
break;
}
}
};
-// TODO(yangzihao): Merge this code with its GPU counterpart to reduce code
-// duplication.
template <>
-Status DoTranspose<CPUDevice>(const CPUDevice& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- typedef CPUDevice Device;
- CHECK_GE(in.dims(), 2);
- CHECK_EQ(in.dims(), out->dims());
- CHECK_EQ(in.dims(), perm.size());
- CHECK_EQ(in.dtype(), out->dtype());
- switch (in.dtype()) {
- case DT_BOOL:
- case DT_INT8:
- case DT_QINT8:
- case DT_QUINT8:
- case DT_UINT8:
- Transpose<Device, uint8>::run(d, in, perm, out);
- break;
-
- case DT_BFLOAT16:
- case DT_HALF:
- case DT_INT16:
- case DT_QINT16:
- case DT_QUINT16:
- case DT_UINT16:
- Transpose<Device, uint16>::run(d, in, perm, out);
- break;
-
- case DT_FLOAT:
- case DT_INT32:
- case DT_QINT32:
- Transpose<Device, uint32>::run(d, in, perm, out);
- break;
-
- case DT_COMPLEX64:
- case DT_DOUBLE:
- case DT_INT64:
- Transpose<Device, uint64>::run(d, in, perm, out);
- break;
-
- case DT_COMPLEX128:
- Transpose<Device, complex128>::run(d, in, perm, out);
- break;
-
- case DT_STRING:
- Transpose<Device, string>::run(d, in, perm, out);
- break;
+Status DoTranspose(const CPUDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl<CPUDevice>::run(device, in, perm,
+ false /* conjugate */, out);
+}
- default:
- return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
- }
- return Status::OK();
+template <>
+Status DoConjugateTranspose(const CPUDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl<CPUDevice>::run(device, in, perm,
+ true /* conjugate */, out);
}
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
+namespace internal {
template <typename Device, typename T>
void TransposeSYCL(const Device& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
+ const gtl::ArraySlice<int32> perm, bool conjugate,
+ Tensor* out) {
switch (in.dims()) {
case 1:
- internal::TransposeUsingEigen<Device, T, 1>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 1>(d, in, perm, conjugate, out);
break;
case 2:
- internal::TransposeUsingEigen<Device, T, 2>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 2>(d, in, perm, conjugate, out);
break;
case 3:
- internal::TransposeUsingEigen<Device, T, 3>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 3>(d, in, perm, conjugate, out);
break;
case 4:
- internal::TransposeUsingEigen<Device, T, 4>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 4>(d, in, perm, conjugate, out);
break;
case 5:
- internal::TransposeUsingEigen<Device, T, 5>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 5>(d, in, perm, conjugate, out);
break;
case 6:
- internal::TransposeUsingEigen<Device, T, 6>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 6>(d, in, perm, conjugate, out);
break;
case 7:
- internal::TransposeUsingEigen<Device, T, 7>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 7>(d, in, perm, conjugate, out);
break;
case 8:
- internal::TransposeUsingEigen<Device, T, 8>(d, in, perm, out);
+ TransposeUsingEigen<SYCLDevice, T, 8>(d, in, perm, conjugate, out);
break;
default:
LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims();
@@ -164,87 +134,38 @@ void TransposeSYCL(const Device& d, const Tensor& in,
}
}
-template <typename T>
-struct Transpose<SYCLDevice, T> {
+} // namespace internal
+
+template <typename T, bool conjugate>
+struct Transpose<SYCLDevice, T, conjugate> {
static void run(const SYCLDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
- switch (in.dims()) {
- case 1:
- internal::TransposeUsingEigen<SYCLDevice, T, 1>(d, in, perm, out);
- break;
- case 2:
- internal::TransposeUsingEigen<SYCLDevice, T, 2>(d, in, perm, out);
- break;
- case 3:
- internal::TransposeUsingEigen<SYCLDevice, T, 3>(d, in, perm, out);
- break;
- case 4:
- internal::TransposeUsingEigen<SYCLDevice, T, 4>(d, in, perm, out);
- break;
- case 5:
- internal::TransposeUsingEigen<SYCLDevice, T, 5>(d, in, perm, out);
- break;
- case 6:
- internal::TransposeUsingEigen<SYCLDevice, T, 6>(d, in, perm, out);
- break;
- case 7:
- internal::TransposeUsingEigen<SYCLDevice, T, 7>(d, in, perm, out);
- break;
- case 8:
- internal::TransposeUsingEigen<SYCLDevice, T, 8>(d, in, perm, out);
- break;
- default:
- LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims();
- break;
- }
+ internal::TransposeSycl(d, in, perm, conjugate, out);
}
};
-template <>
-Status DoTranspose<SYCLDevice>(const SYCLDevice& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- CHECK_GE(in.dims(), 2);
- CHECK_EQ(in.dims(), out->dims());
- CHECK_EQ(in.dims(), perm.size());
- CHECK_EQ(in.dtype(), out->dtype());
- switch (in.dtype()) {
- case DT_BOOL:
- case DT_INT8:
- case DT_QINT8:
- case DT_QUINT8:
- case DT_UINT8:
- TransposeSYCL<SYCLDevice, uint8>(d, in, perm, out);
- break;
-
- case DT_BFLOAT16:
- case DT_HALF:
- case DT_INT16:
- case DT_QINT16:
- case DT_QUINT16:
- case DT_UINT16:
- TransposeSYCL<SYCLDevice, uint16>(d, in, perm, out);
- break;
- case DT_FLOAT:
- case DT_INT32:
- case DT_QINT32:
- TransposeSYCL<SYCLDevice, uint32>(d, in, perm, out);
- break;
-
- case DT_COMPLEX64:
- case DT_DOUBLE:
- case DT_INT64:
- TransposeSYCL<SYCLDevice, uint64>(d, in, perm, out);
- break;
+template <bool conjugate>
+struct Transpose<SYCLDevice, string, conjugate> {
+ static void run(const SYCLDevice& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ LOG(FATAL) << "DT_STRING not supported on SYCL device.";
+ }
+};
- case DT_COMPLEX128:
- TransposeSYCL<SYCLDevice, complex128>(d, in, perm, out);
- break;
+template <>
+Status DoTranspose(const SYCLDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl<SYCLDevice>::run(device, in, perm,
+ false /* conjugate */, out);
+}
- default:
- return errors::Unimplemented("Unsupported dtype on SYCL: ", in.dtype());
- }
- return Status::OK();
+template <>
+Status DoConjugateTranspose(const SYCLDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl<SYCLDevice>::run(device, in, perm,
+ true /* conjugate */, out);
}
+
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index a118cc80c9..87af1ba0c4 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -18,18 +18,21 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/transpose_functor.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
// TODO(yangzihao): Remove the dependency of conv_2d.h once we move all
// GPU util functions and transpose kernels into separate files.
#include "tensorflow/core/kernels/conv_2d.h"
+typedef Eigen::GpuDevice GPUDevice;
+
namespace tensorflow {
namespace internal {
-template <typename T>
+template <typename T, bool conjugate>
__global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
const int32 ndims, T* dst) {
const int32* in_strides = buf;
@@ -42,11 +45,15 @@ __global__ void TransposeKernel(int nthreads, const T* src, const int32* buf,
i_idx += (t / out_strides[i]) * in_strides[perm[i]];
t = t % out_strides[i];
}
- dst[o_idx] = ldg(src + i_idx);
+ if (conjugate) {
+ dst[o_idx] = Eigen::numext::conj(ldg(src + i_idx));
+ } else {
+ dst[o_idx] = ldg(src + i_idx);
+ }
}
}
-template <typename Device, typename T>
+template <typename Device, typename T, bool conjugate>
void TransposeSimple(const Device& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
// Ensures we can use 32-bit index.
@@ -73,9 +80,10 @@ void TransposeSimple(const Device& d, const Tensor& in,
const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
- TransposeKernel<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
- cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
- ndims, q);
+ TransposeKernel<T, conjugate>
+ <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+ cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
+ ndims, q);
// Safe to deallocate immediately after the kernel launch.
d.deallocate(dev_buf);
}
@@ -84,133 +92,152 @@ void TransposeSimple(const Device& d, const Tensor& in,
// then call special kernels to swap either dimension 1 and dimension 2 or
// dimension 0 and dimension 2. It returns true if the operation is success,
// false otherwise.
-template <typename T>
-bool TransposeUsingTile(const Eigen::GpuDevice& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- // First try to reduce the dimensions of the input tensor.
- TransposePermsVec new_perm;
- TransposeDimsVec new_dims;
- ReduceTransposeDimensions(in.shape(), perm, &new_perm, &new_dims);
-
- // Only use special GPU kernel when dimension is 2 or 3.
- int dims = new_dims.size();
- if (dims < 2 || dims > 3) return false;
- auto in_data = reinterpret_cast<const T*>(in.tensor_data().data());
- auto out_data =
- reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data()));
- switch (dims) {
- case 2:
- if (new_perm[0] == 1 && new_perm[1] == 0) {
- // Add the first dimension size as 1.
- new_dims.insert(new_dims.begin(), 1);
- tensorflow::functor::SwapDimension1And2InTensor3<Eigen::GpuDevice, T>()(
- d, in_data, new_dims, out_data);
- return true;
- }
- break;
- case 3:
- if (new_perm == TransposePermsVec({0, 2, 1})) {
- tensorflow::functor::SwapDimension1And2InTensor3<Eigen::GpuDevice, T>()(
- d, in_data, new_dims, out_data);
- return true;
- } else if (new_perm == TransposePermsVec({2, 1, 0})) {
- tensorflow::functor::SwapDimension0And2InTensor3<Eigen::GpuDevice, T>()(
- d, in_data, new_dims, out_data);
- return true;
- } else {
- // do not handle other 3D permutations
+template <typename T, bool conjugate = false>
+struct TransposeUsingTile {
+ static bool run(const Eigen::GpuDevice& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ // First try to reduce the dimensions of the input tensor.
+ TransposePermsVec new_perm;
+ TransposeDimsVec new_dims;
+ ReduceTransposeDimensions(in.shape(), perm, &new_perm, &new_dims);
+
+ // Only use special GPU kernel when dimension is 2 or 3.
+ int dims = new_dims.size();
+ if (dims < 2 || dims > 3) return false;
+ auto in_data = reinterpret_cast<const T*>(in.tensor_data().data());
+ auto out_data =
+ reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data()));
+ switch (dims) {
+ case 2:
+ if (new_perm[0] == 1 && new_perm[1] == 0) {
+ // Add the first dimension size as 1.
+ new_dims.insert(new_dims.begin(), 1);
+ tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T>()(
+ d, in_data, new_dims, out_data);
+ return true;
+ }
+ break;
+ case 3:
+ if (new_perm == TransposePermsVec({0, 2, 1})) {
+ tensorflow::functor::SwapDimension1And2InTensor3<GPUDevice, T>()(
+ d, in_data, new_dims, out_data);
+ return true;
+ } else if (new_perm == TransposePermsVec({2, 1, 0})) {
+ tensorflow::functor::SwapDimension0And2InTensor3<GPUDevice, T>()(
+ d, in_data, new_dims, out_data);
+ return true;
+ } else {
+ // do not handle other 3D permutations
+ return false;
+ }
+ break;
+ default:
return false;
- }
- break;
- default:
+ }
+ return false;
+ }
+};
+
+template <bool conjugate>
+struct TransposeUsingTile<complex64, conjugate> {
+ static bool run(const Eigen::GpuDevice& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ if (!TransposeUsingTile<uint64>::run(d, in, perm, out)) {
return false;
+ }
+ if (conjugate) {
+ // TODO(rmlarsen): Get rid of this call and conjugate on-the-fly in the
+ // transposition kernels so we only touch the memory once.
+ functor::UnaryFunctor<GPUDevice, functor::conj<complex64>> conj;
+ conj(d, out->flat<complex64>() /*out*/,
+ const_cast<const Tensor*>(out)->flat<complex64>() /*in*/);
+ }
+ return true;
}
- return false;
-}
+};
+
+template <bool conjugate>
+struct TransposeUsingTile<complex128, conjugate> {
+ static bool run(const Eigen::GpuDevice& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ if (!TransposeUsingTile<float4>::run(d, in, perm, out)) {
+ return false;
+ }
+ if (conjugate) {
+ // TODO(rmlarsen): Get rid of this call and conjugate on-the-fly in the
+ // transposition kernels so we only touch the memory once.
+ functor::UnaryFunctor<GPUDevice, functor::conj<complex128>> conj;
+ conj(d, out->flat<complex128>() /*out*/,
+ const_cast<const Tensor*>(out)->flat<complex128>() /*in*/);
+ }
+ return true;
+ }
+};
} // end namespace internal
-typedef Eigen::GpuDevice GPUDevice;
+template <>
+Status DoTranspose(const GPUDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl<GPUDevice>::run(device, in, perm,
+ false /* conjugate */, out);
+}
+
+template <>
+Status DoConjugateTranspose(const GPUDevice& device, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ return internal::DoTransposeImpl<GPUDevice>::run(device, in, perm,
+ true /* conjugate */, out);
+}
// Transpose kernel specialized for CPU Device.
-template <typename T>
-struct Transpose<GPUDevice, T> {
+template <typename T, bool conjugate>
+struct Transpose<GPUDevice, T, conjugate> {
static void run(const GPUDevice& d, const Tensor& in,
const gtl::ArraySlice<int32> perm, Tensor* out) {
switch (in.dims()) {
case 2:
- if (!internal::TransposeUsingTile<T>(d, in, perm, out)) {
- internal::TransposeUsingEigen<GPUDevice, T, 2>(d, in, perm, out);
+ if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
+ out)) {
+ internal::TransposeUsingEigen<GPUDevice, T, 2>(d, in, perm, conjugate,
+ out);
}
break;
case 3:
- if (!internal::TransposeUsingTile<T>(d, in, perm, out)) {
- internal::TransposeUsingEigen<GPUDevice, T, 3>(d, in, perm, out);
+ if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
+ out)) {
+ internal::TransposeUsingEigen<GPUDevice, T, 3>(d, in, perm, conjugate,
+ out);
}
break;
case 4:
- if (!internal::TransposeUsingTile<T>(d, in, perm, out)) {
- internal::TransposeUsingEigen<GPUDevice, T, 4>(d, in, perm, out);
+ if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
+ out)) {
+ internal::TransposeUsingEigen<GPUDevice, T, 4>(d, in, perm, conjugate,
+ out);
}
break;
case 5:
- if (!internal::TransposeUsingTile<T>(d, in, perm, out)) {
- internal::TransposeUsingEigen<GPUDevice, T, 5>(d, in, perm, out);
+ if (!internal::TransposeUsingTile<T, conjugate>::run(d, in, perm,
+ out)) {
+ internal::TransposeUsingEigen<GPUDevice, T, 5>(d, in, perm, conjugate,
+ out);
}
break;
default:
- internal::TransposeSimple<GPUDevice, T>(d, in, perm, out);
+ internal::TransposeSimple<GPUDevice, T, conjugate>(d, in, perm, out);
break;
}
}
};
template <>
-Status DoTranspose<GPUDevice>(const GPUDevice& d, const Tensor& in,
- const gtl::ArraySlice<int32> perm, Tensor* out) {
- CHECK_GE(in.dims(), 2);
- CHECK_EQ(in.dims(), out->dims());
- CHECK_EQ(in.dims(), perm.size());
- CHECK_EQ(in.dtype(), out->dtype());
- switch (in.dtype()) {
- case DT_BOOL:
- case DT_INT8:
- case DT_QINT8:
- case DT_QUINT8:
- case DT_UINT8:
- Transpose<GPUDevice, uint8>::run(d, in, perm, out);
- break;
-
- case DT_BFLOAT16:
- case DT_HALF:
- case DT_INT16:
- case DT_QINT16:
- case DT_QUINT16:
- case DT_UINT16:
- Transpose<GPUDevice, uint16>::run(d, in, perm, out);
- break;
-
- case DT_FLOAT:
- case DT_INT32:
- case DT_QINT32:
- Transpose<GPUDevice, uint32>::run(d, in, perm, out);
- break;
-
- case DT_COMPLEX64:
- case DT_DOUBLE:
- case DT_INT64:
- Transpose<GPUDevice, uint64>::run(d, in, perm, out);
- break;
-
- case DT_COMPLEX128:
- Transpose<GPUDevice, float4>::run(d, in, perm, out);
- break;
-
- default:
- return errors::Unimplemented("Unsupported dtype on GPU: ", in.dtype());
+struct Transpose<GPUDevice, string> {
+ static void run(const GPUDevice& d, const Tensor& in,
+ const gtl::ArraySlice<int32> perm, Tensor* out) {
+ LOG(FATAL) << "Transpose of DT_STRING tensor not supported on GPU.";
}
- return Status::OK();
-}
+};
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index d3305fb83a..e151b38d90 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -142,17 +142,18 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
}
}
for (int i = 0; i < dims; ++i) {
- OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
- i, " is missing from {",
- str_util::Join(permutation, ","), "}."));
+ OP_REQUIRES(
+ ctx, bits[i],
+ errors::InvalidArgument(i, " is missing from {",
+ str_util::Join(permutation, ","), "}."));
}
// 0-D, 1-D, and identity transposes do nothing.
- if (dims <= 1 || is_identity) {
+ if (!IsConjugate() && (dims <= 1 || is_identity)) {
ctx->set_output(0, input);
return;
- } else if (internal::NonSingletonDimensionsAlign(input.shape(),
- permutation)) {
+ } else if (!IsConjugate() && internal::NonSingletonDimensionsAlign(
+ input.shape(), permutation)) {
Tensor output;
OP_REQUIRES(ctx, output.CopyFrom(input, shape),
errors::Unknown("Error reshaping Tensor."));
@@ -174,6 +175,15 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
out);
}
+Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
+ const Tensor& in,
+ gtl::ArraySlice<int32> perm,
+ Tensor* out) {
+ typedef Eigen::ThreadPoolDevice CPUDevice;
+ return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
+ perm, out);
+}
+
#ifdef INTEL_MKL
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
@@ -181,7 +191,13 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tperm") \
.HostMemory("perm"), \
- MklTransposeCpuOp);
+ MklTransposeCpuOp); \
+ REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
+ MklConjugateTransposeCpuOp);
TF_CALL_ALL_TYPES(REGISTER);
REGISTER(bfloat16);
#undef REGISTER
@@ -194,7 +210,13 @@ REGISTER(bfloat16);
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tperm") \
.HostMemory("perm"), \
- TransposeCpuOp);
+ TransposeCpuOp); \
+ REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
+ ConjugateTransposeCpuOp);
TF_CALL_ALL_TYPES(REGISTER)
REGISTER(bfloat16);
#undef REGISTER
@@ -207,6 +229,14 @@ Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
return ::tensorflow::DoTranspose(ctx->eigen_device<GPUDevice>(), in, perm,
out);
}
+Status ConjugateTransposeGpuOp::DoTranspose(OpKernelContext* ctx,
+ const Tensor& in,
+ gtl::ArraySlice<int32> perm,
+ Tensor* out) {
+ typedef Eigen::GpuDevice GPUDevice;
+ return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<GPUDevice>(), in,
+ perm, out);
+}
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
@@ -214,25 +244,45 @@ Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tperm") \
.HostMemory("perm"), \
- TransposeGpuOp);
+ TransposeGpuOp); \
+ REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
+ ConjugateTransposeGpuOp);
TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif
#ifdef TENSORFLOW_USE_SYCL
Status TransposeSyclOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
- gtl::ArraySlice<int32> perm, Tensor* out) {
+ gtl::ArraySlice<int32> perm, Tensor* out) {
typedef Eigen::SyclDevice SYCLDevice;
return ::tensorflow::DoTranspose(ctx->eigen_device<SYCLDevice>(), in, perm,
out);
}
+Status ConjugateTransposeSyclOp::DoTranspose(OpKernelContext* ctx,
+ const Tensor& in,
+ gtl::ArraySlice<int32> perm,
+ Tensor* out) {
+ typedef Eigen::SyclDevice SYCLDevice;
+ return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<SYCLDevice>(), in,
+ perm, out);
+}
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_SYCL) \
.TypeConstraint<T>("T") \
.TypeConstraint<int32>("Tperm") \
.HostMemory("perm"), \
- TransposeSyclOp);
+ TransposeSyclOp); \
+ REGISTER_KERNEL_BUILDER(Name("ConjugateTranspose") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
+ ConjugateTransposeSyclOp);
TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
#endif
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index a69eecc2f8..ff9cf5d4ff 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -30,6 +30,7 @@ class TransposeOp : public OpKernel {
protected:
virtual Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
gtl::ArraySlice<int32> perm, Tensor* out) = 0;
+ virtual bool IsConjugate() const { return false; }
};
class TransposeCpuOp : public TransposeOp {
@@ -70,7 +71,57 @@ class TransposeSyclOp : public TransposeOp {
Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
gtl::ArraySlice<int32> perm, Tensor* out) override;
};
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
+// Conjugating transpose ops.
+class ConjugateTransposeCpuOp : public TransposeOp {
+ public:
+ explicit ConjugateTransposeCpuOp(OpKernelConstruction* ctx)
+ : TransposeOp(ctx) {}
+
+ protected:
+ Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) override;
+ bool IsConjugate() const override { return true; }
+};
+
+#ifdef INTEL_MKL
+template <bool conjugate = false>
+class MklConjugateTransposeCpuOp : public TransposeOp {
+ public:
+ explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx)
+ : TransposeOp(ctx) {}
+
+ protected:
+ Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) override;
+ bool IsConjugate() const override { return true; }
+};
+#endif // INTEL_MKL
+
+class ConjugateTransposeGpuOp : public TransposeOp {
+ public:
+ explicit ConjugateTransposeGpuOp(OpKernelConstruction* ctx)
+ : TransposeOp(ctx) {}
+
+ protected:
+ Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) override;
+ bool IsConjugate() const override { return true; }
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+class ConjugateTransposeSyclOp : public TransposeOp {
+ public:
+ explicit ConjugateTransposeSyclOp(OpKernelConstruction* ctx)
+ : TransposeOp(ctx) {}
+
+ protected:
+ Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
+ gtl::ArraySlice<int32> perm, Tensor* out) override;
+ bool IsConjugate() const override { return true; }
+};
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 108c29ed6e..25a7c9eb39 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -110,6 +110,64 @@ Status PadShapeFn(InferenceContext* c) {
}
}
+Status TransposeShapeFn(InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ ShapeHandle perm_shape = c->input(1);
+ const Tensor* perm = c->input_tensor(1);
+ DimensionHandle perm_elems = c->NumElements(perm_shape);
+ // If we don't have rank information on the input or value information on
+ // perm we can't return any shape information, otherwise we have enough
+ // information to at least find the rank of the output.
+ if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) && perm == nullptr) {
+ c->set_output(0, c->UnknownShape());
+ return Status::OK();
+ }
+
+ // Find our value of the rank.
+ int64 rank;
+ if (c->RankKnown(input)) {
+ rank = c->Rank(input);
+ } else if (c->ValueKnown(perm_elems)) {
+ rank = c->Value(perm_elems);
+ } else {
+ rank = perm->NumElements();
+ }
+ std::vector<DimensionHandle> dims;
+ dims.resize(rank);
+ TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
+ // Ensure that perm is a vector and has rank elements.
+ TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
+ TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
+
+ // If we know the rank of the input and the value of perm, we can return
+ // all shape informantion, otherwise we can only return rank information,
+ // but no information for the dimensions.
+ if (perm != nullptr) {
+ std::vector<int64> data;
+ if (perm->dtype() == DT_INT32) {
+ data = AsInt64<int32>(perm, rank);
+ } else {
+ data = AsInt64<int64>(perm, rank);
+ }
+
+ for (int32 i = 0; i < rank; ++i) {
+ int64 in_idx = data[i];
+ if (in_idx >= rank) {
+ return errors::InvalidArgument("perm dim ", in_idx,
+ " is out of range of input rank ", rank);
+ }
+ dims[i] = c->Dim(input, in_idx);
+ }
+ } else {
+ for (int i = 0; i < rank; ++i) {
+ dims[i] = c->UnknownDim();
+ }
+ }
+
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+}
+
Status SetOutputShapeForReshape(InferenceContext* c) {
ShapeHandle in = c->input(0);
ShapeHandle out;
@@ -1913,69 +1971,28 @@ REGISTER_OP("Transpose")
.Output("y: T")
.Attr("T: type")
.Attr("Tperm: {int32, int64} = DT_INT32")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle input = c->input(0);
- ShapeHandle perm_shape = c->input(1);
- const Tensor* perm = c->input_tensor(1);
- DimensionHandle perm_elems = c->NumElements(perm_shape);
- // If we don't have rank information on the input or value information on
- // perm we can't return any shape information, otherwise we have enough
- // information to at least find the rank of the output.
- if (!c->RankKnown(input) && !c->ValueKnown(perm_elems) &&
- perm == nullptr) {
- c->set_output(0, c->UnknownShape());
- return Status::OK();
- }
-
- // Find our value of the rank.
- int64 rank;
- if (c->RankKnown(input)) {
- rank = c->Rank(input);
- } else if (c->ValueKnown(perm_elems)) {
- rank = c->Value(perm_elems);
- } else {
- rank = perm->NumElements();
- }
- std::vector<DimensionHandle> dims;
- dims.resize(rank);
- TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
- // Ensure that perm is a vector and has rank elements.
- TF_RETURN_IF_ERROR(c->WithRank(perm_shape, 1, &perm_shape));
- TF_RETURN_IF_ERROR(c->WithValue(perm_elems, rank, &perm_elems));
-
- // If we know the rank of the input and the value of perm, we can return
- // all shape informantion, otherwise we can only return rank information,
- // but no information for the dimensions.
- if (perm != nullptr) {
- std::vector<int64> data;
- if (perm->dtype() == DT_INT32) {
- data = AsInt64<int32>(perm, rank);
- } else {
- data = AsInt64<int64>(perm, rank);
- }
+ .SetShapeFn(TransposeShapeFn)
+ .Doc(R"doc(
+Shuffle dimensions of x according to a permutation.
- for (int32 i = 0; i < rank; ++i) {
- int64 in_idx = data[i];
- if (in_idx >= rank) {
- return errors::InvalidArgument(
- "perm dim ", in_idx, " is out of range of input rank ", rank);
- }
- dims[i] = c->Dim(input, in_idx);
- }
- } else {
- for (int i = 0; i < rank; ++i) {
- dims[i] = c->UnknownDim();
- }
- }
+The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+ `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+)doc");
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- })
+// --------------------------------------------------------------------------
+REGISTER_OP("ConjugateTranspose")
+ .Input("x: T")
+ .Input("perm: Tperm")
+ .Output("y: T")
+ .Attr("T: type")
+ .Attr("Tperm: {int32, int64} = DT_INT32")
+ .SetShapeFn(TransposeShapeFn)
.Doc(R"doc(
-Shuffle dimensions of x according to a permutation.
+Shuffle dimensions of x according to a permutation and conjugate the result.
The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
`y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+ `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])`
)doc");
// --------------------------------------------------------------------------
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 1792886417..8f4c94f318 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -51,6 +51,15 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
self.assertEqual((3, 2), transposed.get_shape())
self.assertAllEqual(expected_transposed, transposed.eval())
+ def testConjugate(self):
+ m = [[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 6 + 6j]]
+ expected_transposed = [[1 - 1j, 4 - 4j], [2 - 2j, 5 - 5j], [3 - 3j, 6 - 6j]]
+ with self.test_session():
+ matrix = ops.convert_to_tensor(m)
+ transposed = array_ops.matrix_transpose(matrix, conjugate=True)
+ self.assertEqual((3, 2), transposed.get_shape())
+ self.assertAllEqual(expected_transposed, transposed.eval())
+
def testBatchMatrix(self):
matrix_0 = [[1, 2, 3], [4, 5, 6]]
matrix_0_t = [[1, 4], [2, 5], [3, 6]]
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index 8bb583ce1b..2f28d37eff 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@@ -120,6 +121,20 @@ class SlogdetTest(test.TestCase):
self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
+class AdjointTest(test.TestCase):
+
+ def test_compare_to_numpy(self):
+ for dtype in np.float64, np.float64, np.complex64, np.complex128:
+ matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
+ 6 + 6j]]).astype(dtype)
+ expected_transposed = np.conj(matrix_np.T)
+ with self.test_session():
+ matrix = ops.convert_to_tensor(matrix_np)
+ transposed = linalg.adjoint(matrix)
+ self.assertEqual((3, 2), transposed.get_shape())
+ self.assertAllEqual(expected_transposed, transposed.eval())
+
+
class EyeTest(test.TestCase):
pass # Will be filled in below
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 9e1f83395b..3b352937c8 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -38,14 +38,16 @@ class TransposeTest(test.TestCase):
ret = ret.transpose(perm)
return ret
- def _compareCpu(self, x, p):
+ def _compareCpu(self, x, p, conjugate=False):
np_ans = self._np_transpose(x, p)
+ if conjugate:
+ np_ans = np.conj(np_ans)
with self.test_session(use_gpu=False):
inx = ops.convert_to_tensor(x)
- y = array_ops.transpose(inx, p)
+ y = array_ops.transpose(inx, p, conjugate=conjugate)
tf_ans = y.eval()
- self.assertAllEqual(np_ans, tf_ans)
self.assertShapeEqual(np_ans, y)
+ self.assertAllEqual(np_ans, tf_ans)
jacob_t = None
# Gradient check on CPU.
@@ -62,11 +64,13 @@ class TransposeTest(test.TestCase):
return tf_ans, jacob_t
- def _compareGpu(self, x, p):
+ def _compareGpu(self, x, p, conjugate=False):
np_ans = self._np_transpose(x, p)
+ if conjugate:
+ np_ans = np.conj(np_ans)
with self.test_session(use_gpu=True):
inx = ops.convert_to_tensor(x)
- y = array_ops.transpose(inx, p)
+ y = array_ops.transpose(inx, p, conjugate=conjugate)
tf_ans = y.eval()
self.assertAllEqual(np_ans, tf_ans)
@@ -92,10 +96,12 @@ class TransposeTest(test.TestCase):
# generate all permutations of [0, 1, ... n-1] in random order.
all_perm = np.random.permutation(
[p for p in itertools.permutations(range(n))]).astype(np.int32)
- for p in all_perm[:2]:
- self._compareCpu(x, p)
- if use_gpu:
- self._compareGpu(x, p)
+ cs = [False, True] if x.dtype in [np.complex64, np.complex128] else [False]
+ for c in cs:
+ for p in all_perm[:2]:
+ self._compareCpu(x, p, conjugate=c)
+ if use_gpu:
+ self._compareGpu(x, p, conjugate=c)
def _compare_cpu_gpu(self, x):
n = np.ndim(x)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 61405e3f45..dc3aa735da 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1283,13 +1283,15 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"):
name=name)
-def transpose(a, perm=None, name="transpose"):
+def transpose(a, perm=None, name="transpose", conjugate=False):
"""Transposes `a`. Permutes the dimensions according to `perm`.
The returned tensor's dimension i will correspond to the input dimension
`perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
the rank of the input tensor. Hence by default, this operation performs a
- regular matrix transpose on 2-D input Tensors.
+ regular matrix transpose on 2-D input Tensors. If conjugate is True and
+ `a.dtype` is either `complex64` or `complex128` then the values of `a`
+ are conjugated and transposed.
For example:
@@ -1304,6 +1306,13 @@ def transpose(a, perm=None, name="transpose"):
# [2, 5]
# [3, 6]]
+ # If x is complex, setting conjugate=True gives the conjugate transpose
+ x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
+ [4 + 4j, 5 + 5j, 6 + 6j]])
+ tf.transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
+ # [2 - 2j, 5 - 5j],
+ # [3 - 3j, 6 - 6j]]
+
# 'perm' is more useful for n-dimensional tensors, for n > 2
x = tf.constant([[[ 1, 2, 3],
[ 4, 5, 6]],
@@ -1311,6 +1320,7 @@ def transpose(a, perm=None, name="transpose"):
[10, 11, 12]]])
# Take the transpose of the matrices in dimension-0
+ # (this common operation has a shorthand `matrix_transpose`)
tf.transpose(x, perm=[0, 2, 1]) # [[[1, 4],
# [2, 5],
# [3, 6]],
@@ -1323,15 +1333,20 @@ def transpose(a, perm=None, name="transpose"):
a: A `Tensor`.
perm: A permutation of the dimensions of `a`.
name: A name for the operation (optional).
+ conjugate: Optional bool. Setting it to `True` is mathematically equivalent
+ to tf.conj(tf.transpose(input)).
Returns:
A transposed `Tensor`.
"""
with ops.name_scope(name, "transpose", [a]) as name:
+ transpose_fn = (
+ gen_array_ops._conjugate_transpose
+ if conjugate else gen_array_ops.transpose)
if perm is None:
rank = gen_array_ops.rank(a)
perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
- ret = gen_array_ops.transpose(a, perm, name=name)
+ ret = transpose_fn(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function.
if context.in_graph_mode():
@@ -1339,12 +1354,12 @@ def transpose(a, perm=None, name="transpose"):
if input_shape is not None:
ret.set_shape(input_shape[::-1])
else:
- ret = gen_array_ops.transpose(a, perm, name=name)
+ ret = transpose_fn(a, perm, name=name)
return ret
# pylint: disable=invalid-name
-def matrix_transpose(a, name="matrix_transpose"):
+def matrix_transpose(a, name="matrix_transpose", conjugate=False):
"""Transposes last two dimensions of tensor `a`.
For example:
@@ -1355,6 +1370,12 @@ def matrix_transpose(a, name="matrix_transpose"):
# [2, 5],
# [3, 6]]
+ x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
+ [4 + 4j, 5 + 5j, 6 + 6j]])
+ tf.matrix_transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
+ # [2 - 2j, 5 - 5j],
+ # [3 - 3j, 6 - 6j]]
+
# Matrix with two batch dimensions.
# x.shape is [1, 2, 3, 4]
# tf.matrix_transpose(x) is shape [1, 2, 4, 3]
@@ -1374,6 +1395,8 @@ def matrix_transpose(a, name="matrix_transpose"):
Args:
a: A `Tensor` with `rank >= 2`.
name: A name for the operation (optional).
+ conjugate: Optional bool. Setting it to `True` is mathematically equivalent
+ to tf.conj(tf.matrix_transpose(input)).
Returns:
A transposed batch matrix `Tensor`.
@@ -1401,7 +1424,7 @@ def matrix_transpose(a, name="matrix_transpose"):
perm = concat((gen_math_ops._range(0, a_rank - 2, 1),
[a_rank - 1, a_rank - 2]), 0)
- return transpose(a, perm=perm)
+ return transpose(a, perm=perm, conjugate=conjugate)
# pylint: enable=invalid-name
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index d27e867583..fcd378e3c0 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -5,6 +5,7 @@ BroadcastGradientArgs
ConcatOffset
Concat
ConcatV2
+ConjugateTranspose
Const
DebugGradientIdentity
EditDistance
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index ca57653d14..32d1b31d7d 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -54,3 +54,28 @@ def logdet(matrix, name=None):
return 2.0 * math_ops.reduce_sum(
math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
reduction_indices=[-1])
+
+
+def adjoint(matrix, name=None):
+ """Conjugates and transposes the last two dimensions of tensor `matrix`.
+
+ For example:
+
+ ```python
+ x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
+ [4 + 4j, 5 + 5j, 6 + 6j]])
+ tf.linalg.adjoint(x) # [[1 - 1j, 4 - 4j],
+ # [2 - 2j, 5 - 5j],
+ # [3 - 3j, 6 - 6j]]
+
+ Args:
+ matrix: A `Tensor`. Must be `float32`, `float64`, `complex64`, or
+ `complex128` with shape `[..., M, M]`.
+ name: A name to give this `Op` (optional).
+
+ Returns:
+ The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
+ matrix.
+ """
+ with ops.name_scope(name, 'adjoint', [matrix]):
+ return array_ops.matrix_transpose(matrix, conjugate=True)
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
index 4c94863caa..0d62585ff4 100644
--- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
@@ -33,6 +33,10 @@ tf_module {
mtype: "<class \'abc.ABCMeta\'>"
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "band_part"
argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -118,7 +122,7 @@ tf_module {
}
member_method {
name: "transpose"
- argspec: "args=[\'a\', \'name\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\'], "
+ argspec: "args=[\'a\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\', \'False\'], "
}
member_method {
name: "triangular_solve"
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index d77f8fd253..d56a59de72 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1354,7 +1354,7 @@ tf_module {
}
member_method {
name: "matrix_transpose"
- argspec: "args=[\'a\', \'name\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\'], "
+ argspec: "args=[\'a\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\', \'False\'], "
}
member_method {
name: "matrix_triangular_solve"
@@ -1990,7 +1990,7 @@ tf_module {
}
member_method {
name: "transpose"
- argspec: "args=[\'a\', \'perm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'transpose\'], "
+ argspec: "args=[\'a\', \'perm\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'None\', \'transpose\', \'False\'], "
}
member_method {
name: "truediv"