aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-12-05 20:43:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-05 20:51:49 -0800
commit63dd8ffea012bf7c743d6848faf5f406ede94c05 (patch)
treedfb38566ace474bc0f3f2b70e7d1c29d08aa2711
parent041dc3349d100dc8a26a337d6656efe1543a6b9d (diff)
Add DataFormatDimMap op.
PiperOrigin-RevId: 178051608
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt1
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt31
-rw-r--r--tensorflow/core/kernels/BUILD9
-rw-r--r--tensorflow/core/kernels/data_format_ops.cc103
-rw-r--r--tensorflow/core/kernels/data_format_ops.h45
-rw-r--r--tensorflow/core/kernels/data_format_ops_gpu.cu.cc31
-rw-r--r--tensorflow/core/ops/nn_ops.cc17
-rw-r--r--tensorflow/python/ops/nn_test.py20
8 files changed, 257 insertions, 0 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 9fc9aeb785..5f27566398 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -148,6 +148,7 @@ tensorflow/core/kernels/dynamic_stitch_op.cc
tensorflow/core/kernels/dynamic_partition_op.cc
tensorflow/core/kernels/decode_bmp_op.cc
tensorflow/core/kernels/depthtospace_op.cc
+tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/spacetodepth_op.cc
tensorflow/core/kernels/dense_update_ops.cc
tensorflow/core/kernels/deep_conv2d.cc
diff --git a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt
new file mode 100644
index 0000000000..62098acd38
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt
@@ -0,0 +1,31 @@
+op {
+ graph_op_name: "DataFormatDimMap"
+ in_arg {
+ name: "x"
+ description: <<END
+Scalar. Dimension index in source data format. Must be in the range [-4, 4).
+END
+ }
+ out_arg {
+ name: "y"
+ description: <<END
+Scalar. Dimension index in destination data format.
+END
+ }
+ attr {
+ name: "src_format"
+ description: <<END
+source data format.
+END
+ }
+ attr {
+ name: "dst_format"
+ description: <<END
+destination data format.
+END
+ }
+ summary: "Returns the dimension index in the destination data format given the one in"
+ description: <<END
+the source data format.
+END
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index c8359b4480..77ca8f5fcb 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3115,6 +3115,7 @@ cc_library(
":batch_norm_op",
":bias_op",
":conv_ops",
+ ":data_format_ops",
":depthwise_conv_grad_op",
":depthwise_conv_op",
":dilation_ops",
@@ -3153,6 +3154,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "data_format_ops",
+ prefix = "data_format_ops",
+ deps = NN_DEPS,
+)
+
+tf_kernel_library(
name = "bias_op",
prefix = "bias_op",
deps = NN_DEPS,
@@ -4603,6 +4610,7 @@ filegroup(
"control_flow_ops.h",
"conv_2d.h",
"conv_ops.h",
+ "data_format_ops.h",
"depthtospace_op.h",
"depthwise_conv_op.h",
"fake_quant_ops_functor.h",
@@ -4716,6 +4724,7 @@ filegroup(
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
+ "data_format_ops.cc",
"decode_wav_op.cc",
"deep_conv2d.cc",
"deep_conv2d.h",
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
new file mode 100644
index 0000000000..047188f754
--- /dev/null
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -0,0 +1,103 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/data_format_ops.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class DataFormatDimMapOp : public OpKernel {
+ public:
+ explicit DataFormatDimMapOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string src_format;
+ OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
+ string dst_format;
+ OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
+ OP_REQUIRES(
+ context, src_format == "NHWC",
+ errors::InvalidArgument(strings::StrCat(
+ "Current implementation doesn't support source data format ",
+ src_format)));
+ OP_REQUIRES(context, dst_format == "NCHW",
+ errors::InvalidArgument(strings::StrCat(
+ "Current implementation doesn't support dst data format ",
+ dst_format)));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 0,
+ errors::InvalidArgument("input must be a scalar",
+ input.shape().DebugString()));
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
+ input.scalar<T>(),
+ output->scalar<T>());
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ DataFormatDimMapOp<CPUDevice, T>);
+
+TF_CALL_int32(REGISTER_KERNEL);
+TF_CALL_int64(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void DataFormatDimMap<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstScalar x, \
+ typename TTypes<T>::Scalar y); \
+ extern template struct DataFormatDimMap<GPUDevice, T>;
+
+#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
+
+TF_CALL_int32(DECLARE_GPU_SPECS);
+TF_CALL_int64(DECLARE_GPU_SPECS);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ DataFormatDimMapOp<GPUDevice, T>);
+
+TF_CALL_int32(REGISTER_GPU_KERNEL);
+TF_CALL_int64(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
new file mode 100644
index 0000000000..079e76c0d9
--- /dev/null
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -0,0 +1,45 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
+#define TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
+// Functor definition for data format dim mapping ops, must be compilable
+// by nvcc.
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by DataFormatDimMapOP to do the computations.
+template <typename Device, typename T>
+struct DataFormatDimMap {
+ void operator()(const Device& d, typename TTypes<T>::ConstScalar x,
+ typename TTypes<T>::Scalar y) {
+ auto zero = x.constant(0);
+ auto one = x.constant(1);
+ auto three = x.constant(3);
+ auto four = x.constant(4);
+ auto x_mod = (x + four) % 4;
+ auto is_zero = (x_mod == zero);
+ auto is_three = (x_mod == three);
+ y.device(d) = is_zero.select(zero, is_three.select(one, x_mod + one));
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
diff --git a/tensorflow/core/kernels/data_format_ops_gpu.cu.cc b/tensorflow/core/kernels/data_format_ops_gpu.cu.cc
new file mode 100644
index 0000000000..09340a7d87
--- /dev/null
+++ b/tensorflow/core/kernels/data_format_ops_gpu.cu.cc
@@ -0,0 +1,31 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/data_format_ops.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+template struct functor::DataFormatDimMap<GPUDevice, int32>;
+template struct functor::DataFormatDimMap<GPUDevice, int64>;
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 102de94787..f58425db0a 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -751,6 +751,23 @@ Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
} // namespace
+REGISTER_OP("DataFormatDimMap")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {int32, int64} = DT_INT32")
+ .Attr("src_format: string = 'NHWC'")
+ .Attr("dst_format: string = 'NCHW'")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Returns the dimension index in the destination data format given the one in
+the source data format.
+
+x: Scalar. Dimension index in source data format. Must be in the range [-4, 4).
+y: Scalar. Dimension index in destination data format.
+src_format: source data format.
+dst_format: destination data format.
+)doc");
+
REGISTER_OP("FusedResizeAndPadConv2D")
.Input("input: T")
.Input("size: int32")
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 3b918e4f74..ac79354fb7 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -953,5 +953,25 @@ class MomentsTest(test_lib.TestCase):
self.doOutputTest((10, 10, 10, 30), (1, 2, 3))
+class DataFormatDimMapTest(test_lib.TestCase):
+
+ def _test(self, x_val, y_val_expected):
+ x = constant_op.constant(x_val)
+ y = nn_ops.data_format_dim_map(x)
+ with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
+ y_val = sess.run(y)
+ self.assertEqual(y_val, y_val_expected)
+
+ def test(self):
+ self._test(0, 0)
+ self._test(1, 2)
+ self._test(2, 3)
+ self._test(3, 1)
+ self._test(-1, 1)
+ self._test(-2, 3)
+ self._test(-3, 2)
+ self._test(-4, 0)
+
+
if __name__ == "__main__":
test_lib.main()