aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data_format_ops.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-19 18:20:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 18:24:30 -0700
commit331bbe2886712fffc96ed9a7fb33fc9f09600240 (patch)
tree9745c3830f9d48e3b399b8f84d9f2661e13d63f2 /tensorflow/core/kernels/data_format_ops.h
parentd548cb4e811fc8a04dd10370c576441fc56b03f2 (diff)
Support general permutation.
PiperOrigin-RevId: 189675019
Diffstat (limited to 'tensorflow/core/kernels/data_format_ops.h')
-rw-r--r--tensorflow/core/kernels/data_format_ops.h131
1 files changed, 9 insertions, 122 deletions
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index d27415ed91..2ccc919586 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -23,13 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace functor {
-enum class DataFormat {
- UNKNOWN = 0,
- NHWC,
- NCHW,
- HWNC,
-};
-
// Functor used by DataFormatDimMapOP to do the computations.
template <typename Device, typename T>
struct DataFormatDimMap {
@@ -47,65 +40,8 @@ struct DataFormatDimMap {
};
template <typename T>
-struct VecPermuteNHWCToNCHW {
- Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
- typename TTypes<T>::ConstFlat input) const {
- Eigen::DSizes<Eigen::DenseIndex, 1> result;
- result[0] = input.dimension(0);
- return result;
- }
- template <typename Output, typename Device>
- void eval(typename TTypes<T>::ConstFlat input, Output& output,
- const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(1);
- output.template chip<0>(2).device(d) = input.template chip<0>(6);
- output.template chip<0>(3).device(d) = input.template chip<0>(7);
- output.template chip<0>(4).device(d) = input.template chip<0>(2);
- output.template chip<0>(5).device(d) = input.template chip<0>(3);
- output.template chip<0>(6).device(d) = input.template chip<0>(4);
- output.template chip<0>(7).device(d) = input.template chip<0>(5);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(3);
- output.template chip<0>(2).device(d) = input.template chip<0>(1);
- output.template chip<0>(3).device(d) = input.template chip<0>(2);
- }
- }
-};
-
-template <typename T>
-struct VecPermuteNCHWToNHWC {
- Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
- typename TTypes<T>::ConstFlat input) const {
- Eigen::DSizes<Eigen::DenseIndex, 1> result;
- result[0] = input.dimension(0);
- return result;
- }
- template <typename Output, typename Device>
- void eval(typename TTypes<T>::ConstFlat input, Output& output,
- const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(1);
- output.template chip<0>(2).device(d) = input.template chip<0>(4);
- output.template chip<0>(3).device(d) = input.template chip<0>(5);
- output.template chip<0>(4).device(d) = input.template chip<0>(6);
- output.template chip<0>(5).device(d) = input.template chip<0>(7);
- output.template chip<0>(6).device(d) = input.template chip<0>(2);
- output.template chip<0>(7).device(d) = input.template chip<0>(3);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(0);
- output.template chip<0>(1).device(d) = input.template chip<0>(2);
- output.template chip<0>(2).device(d) = input.template chip<0>(3);
- output.template chip<0>(3).device(d) = input.template chip<0>(1);
- }
- }
-};
-
-template <typename T>
-struct VecPermuteNHWCToHWNC {
+struct VecPermute {
+ VecPermute(const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) : dst_(dst) {}
Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
typename TTypes<T>::ConstFlat input) const {
Eigen::DSizes<Eigen::DenseIndex, 1> result;
@@ -115,71 +51,22 @@ struct VecPermuteNHWCToHWNC {
template <typename Output, typename Device>
void eval(typename TTypes<T>::ConstFlat input, Output& output,
const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(2);
- output.template chip<0>(1).device(d) = input.template chip<0>(3);
- output.template chip<0>(2).device(d) = input.template chip<0>(4);
- output.template chip<0>(3).device(d) = input.template chip<0>(5);
- output.template chip<0>(4).device(d) = input.template chip<0>(0);
- output.template chip<0>(5).device(d) = input.template chip<0>(1);
- output.template chip<0>(6).device(d) = input.template chip<0>(6);
- output.template chip<0>(7).device(d) = input.template chip<0>(7);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(1);
- output.template chip<0>(1).device(d) = input.template chip<0>(2);
- output.template chip<0>(2).device(d) = input.template chip<0>(0);
- output.template chip<0>(3).device(d) = input.template chip<0>(3);
+ for (int i = 0; i < input.size(); ++i) {
+ output.template chip<0>(dst_[i]).device(d) = input.template chip<0>(i);
}
}
-};
-template <typename T>
-struct VecPermuteHWNCToNHWC {
- Eigen::DSizes<Eigen::DenseIndex, 1> dimensions(
- typename TTypes<T>::ConstFlat input) const {
- Eigen::DSizes<Eigen::DenseIndex, 1> result;
- result[0] = input.dimension(0);
- return result;
- }
- template <typename Output, typename Device>
- void eval(typename TTypes<T>::ConstFlat input, Output& output,
- const Device& d) const {
- if (input.size() == 8) {
- output.template chip<0>(0).device(d) = input.template chip<0>(4);
- output.template chip<0>(1).device(d) = input.template chip<0>(5);
- output.template chip<0>(2).device(d) = input.template chip<0>(0);
- output.template chip<0>(3).device(d) = input.template chip<0>(1);
- output.template chip<0>(4).device(d) = input.template chip<0>(2);
- output.template chip<0>(5).device(d) = input.template chip<0>(3);
- output.template chip<0>(6).device(d) = input.template chip<0>(6);
- output.template chip<0>(7).device(d) = input.template chip<0>(7);
- } else {
- output.template chip<0>(0).device(d) = input.template chip<0>(2);
- output.template chip<0>(1).device(d) = input.template chip<0>(0);
- output.template chip<0>(2).device(d) = input.template chip<0>(1);
- output.template chip<0>(3).device(d) = input.template chip<0>(3);
- }
- }
+ private:
+ Eigen::DSizes<Eigen::DenseIndex, 8> dst_;
};
// Functor used by DataFormatVecPermuteOp to do the computations.
template <typename Device, typename T>
struct DataFormatVecPermute {
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
- typename TTypes<T>::Flat y, const DataFormat src_format,
- const DataFormat dst_format) {
- if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
- y.device(d) = x.customOp(VecPermuteNHWCToNCHW<T>());
- } else if (src_format == DataFormat::NCHW &&
- dst_format == DataFormat::NHWC) {
- y.device(d) = x.customOp(VecPermuteNCHWToNHWC<T>());
- } else if (src_format == DataFormat::NHWC &&
- dst_format == DataFormat::HWNC) {
- y.device(d) = x.customOp(VecPermuteNHWCToHWNC<T>());
- } else if (src_format == DataFormat::HWNC &&
- dst_format == DataFormat::NHWC) {
- y.device(d) = x.customOp(VecPermuteHWNCToNHWC<T>());
- }
+ typename TTypes<T>::Flat y,
+ const Eigen::DSizes<Eigen::DenseIndex, 8>& dst) {
+ y.device(d) = x.customOp(VecPermute<T>(dst));
}
};