diff options
author | 2018-03-19 18:20:12 -0700 | |
---|---|---|
committer | 2018-03-19 18:24:30 -0700 | |
commit | 331bbe2886712fffc96ed9a7fb33fc9f09600240 (patch) | |
tree | 9745c3830f9d48e3b399b8f84d9f2661e13d63f2 /tensorflow/core/kernels/data_format_ops.h | |
parent | d548cb4e811fc8a04dd10370c576441fc56b03f2 (diff) |
Support general permutation.
PiperOrigin-RevId: 189675019
Diffstat (limited to 'tensorflow/core/kernels/data_format_ops.h')
-rw-r--r-- | tensorflow/core/kernels/data_format_ops.h | 131 |
1 files changed, 9 insertions, 122 deletions
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h index d27415ed91..2ccc919586 100644 --- a/tensorflow/core/kernels/data_format_ops.h +++ b/tensorflow/core/kernels/data_format_ops.h @@ -23,13 +23,6 @@ limitations under the License. namespace tensorflow { namespace functor { -enum class DataFormat { - UNKNOWN = 0, - NHWC, - NCHW, - HWNC, -}; - // Functor used by DataFormatDimMapOP to do the computations. template <typename Device, typename T> struct DataFormatDimMap { @@ -47,65 +40,8 @@ struct DataFormatDimMap { }; template <typename T> -struct VecPermuteNHWCToNCHW { - Eigen::DSizes<Eigen::DenseIndex, 1> dimensions( - typename TTypes<T>::ConstFlat input) const { - Eigen::DSizes<Eigen::DenseIndex, 1> result; - result[0] = input.dimension(0); - return result; - } - template <typename Output, typename Device> - void eval(typename TTypes<T>::ConstFlat input, Output& output, - const Device& d) const { - if (input.size() == 8) { - output.template chip<0>(0).device(d) = input.template chip<0>(0); - output.template chip<0>(1).device(d) = input.template chip<0>(1); - output.template chip<0>(2).device(d) = input.template chip<0>(6); - output.template chip<0>(3).device(d) = input.template chip<0>(7); - output.template chip<0>(4).device(d) = input.template chip<0>(2); - output.template chip<0>(5).device(d) = input.template chip<0>(3); - output.template chip<0>(6).device(d) = input.template chip<0>(4); - output.template chip<0>(7).device(d) = input.template chip<0>(5); - } else { - output.template chip<0>(0).device(d) = input.template chip<0>(0); - output.template chip<0>(1).device(d) = input.template chip<0>(3); - output.template chip<0>(2).device(d) = input.template chip<0>(1); - output.template chip<0>(3).device(d) = input.template chip<0>(2); - } - } -}; - -template <typename T> -struct VecPermuteNCHWToNHWC { - Eigen::DSizes<Eigen::DenseIndex, 1> dimensions( - typename TTypes<T>::ConstFlat input) const { - Eigen::DSizes<Eigen::DenseIndex, 1> result; - result[0] = input.dimension(0); - return result; - } - template <typename Output, typename Device> - void eval(typename TTypes<T>::ConstFlat input, Output& output, - const Device& d) const { - if (input.size() == 8) { - output.template chip<0>(0).device(d) = input.template chip<0>(0); - output.template chip<0>(1).device(d) = input.template chip<0>(1); - output.template chip<0>(2).device(d) = input.template chip<0>(4); - output.template chip<0>(3).device(d) = input.template chip<0>(5); - output.template chip<0>(4).device(d) = input.template chip<0>(6); - output.template chip<0>(5).device(d) = input.template chip<0>(7); - output.template chip<0>(6).device(d) = input.template chip<0>(2); - output.template chip<0>(7).device(d) = input.template chip<0>(3); - } else { - output.template chip<0>(0).device(d) = input.template chip<0>(0); - output.template chip<0>(1).device(d) = input.template chip<0>(2); - output.template chip<0>(2).device(d) = input.template chip<0>(3); - output.template chip<0>(3).device(d) = input.template chip<0>(1); - } - } -}; - -template <typename T> -struct VecPermuteNHWCToHWNC { +struct VecPermute { + VecPermute(const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) : dst_(dst) {} Eigen::DSizes<Eigen::DenseIndex, 1> dimensions( typename TTypes<T>::ConstFlat input) const { Eigen::DSizes<Eigen::DenseIndex, 1> result; @@ -115,71 +51,22 @@ struct VecPermuteNHWCToHWNC { template <typename Output, typename Device> void eval(typename TTypes<T>::ConstFlat input, Output& output, const Device& d) const { - if (input.size() == 8) { - output.template chip<0>(0).device(d) = input.template chip<0>(2); - output.template chip<0>(1).device(d) = input.template chip<0>(3); - output.template chip<0>(2).device(d) = input.template chip<0>(4); - output.template chip<0>(3).device(d) = input.template chip<0>(5); - output.template chip<0>(4).device(d) = input.template chip<0>(0); - output.template chip<0>(5).device(d) = input.template chip<0>(1); - output.template chip<0>(6).device(d) = input.template chip<0>(6); - output.template chip<0>(7).device(d) = input.template chip<0>(7); - } else { - output.template chip<0>(0).device(d) = input.template chip<0>(1); - output.template chip<0>(1).device(d) = input.template chip<0>(2); - output.template chip<0>(2).device(d) = input.template chip<0>(0); - output.template chip<0>(3).device(d) = input.template chip<0>(3); + for (int i = 0; i < input.size(); ++i) { + output.template chip<0>(dst_[i]).device(d) = input.template chip<0>(i); } } -}; -template <typename T> -struct VecPermuteHWNCToNHWC { - Eigen::DSizes<Eigen::DenseIndex, 1> dimensions( - typename TTypes<T>::ConstFlat input) const { - Eigen::DSizes<Eigen::DenseIndex, 1> result; - result[0] = input.dimension(0); - return result; - } - template <typename Output, typename Device> - void eval(typename TTypes<T>::ConstFlat input, Output& output, - const Device& d) const { - if (input.size() == 8) { - output.template chip<0>(0).device(d) = input.template chip<0>(4); - output.template chip<0>(1).device(d) = input.template chip<0>(5); - output.template chip<0>(2).device(d) = input.template chip<0>(0); - output.template chip<0>(3).device(d) = input.template chip<0>(1); - output.template chip<0>(4).device(d) = input.template chip<0>(2); - output.template chip<0>(5).device(d) = input.template chip<0>(3); - output.template chip<0>(6).device(d) = input.template chip<0>(6); - output.template chip<0>(7).device(d) = input.template chip<0>(7); - } else { - output.template chip<0>(0).device(d) = input.template chip<0>(2); - output.template chip<0>(1).device(d) = input.template chip<0>(0); - output.template chip<0>(2).device(d) = input.template chip<0>(1); - output.template chip<0>(3).device(d) = input.template chip<0>(3); - } - } + private: + Eigen::DSizes<Eigen::DenseIndex, 8> dst_; }; // Functor used by DataFormatVecPermuteOp to do the computations. template <typename Device, typename T> struct DataFormatVecPermute { void operator()(const Device& d, typename TTypes<T>::ConstFlat x, - typename TTypes<T>::Flat y, const DataFormat src_format, - const DataFormat dst_format) { - if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) { - y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>()); - } else if (src_format == DataFormat::NCHW && - dst_format == DataFormat::NHWC) { - y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>()); - } else if (src_format == DataFormat::NHWC && - dst_format == DataFormat::HWNC) { - y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>()); - } else if (src_format == DataFormat::HWNC && - dst_format == DataFormat::NHWC) { - y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>()); - } + typename TTypes<T>::Flat y, + const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) { + y.device(d) = x.customOp(VecPermute<T>(dst)); } }; |