diff options
Diffstat (limited to 'tensorflow/core/kernels/transpose_op.cc')
-rw-r--r-- | tensorflow/core/kernels/transpose_op.cc | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 4d303f0173..fb2ceb4a4a 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -180,6 +180,20 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, out); } +#ifdef INTEL_MKL +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER(Name("Transpose") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int32>("Tperm") \ + .HostMemory("perm"), \ + MklTransposeCpuOp); +TF_CALL_ALL_TYPES(REGISTER); +REGISTER(bfloat16); +#undef REGISTER + +#else // INTEL_MKL + #define REGISTER(T) \ REGISTER_KERNEL_BUILDER(Name("Transpose") \ .Device(DEVICE_CPU) \ @@ -190,6 +204,7 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, TF_CALL_ALL_TYPES(REGISTER) REGISTER(bfloat16); #undef REGISTER +#endif // INTEL_MKL #if GOOGLE_CUDA Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, |