diff options
author | Yao Zhang <yaozhang@google.com> | 2017-12-05 20:43:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-05 20:51:49 -0800 |
commit | 63dd8ffea012bf7c743d6848faf5f406ede94c05 (patch) | |
tree | dfb38566ace474bc0f3f2b70e7d1c29d08aa2711 | |
parent | 041dc3349d100dc8a26a337d6656efe1543a6b9d (diff) |
Add DataFormatDimMap op.
PiperOrigin-RevId: 178051608
-rw-r--r-- | tensorflow/contrib/makefile/tf_op_files.txt | 1 | ||||
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt | 31 | ||||
-rw-r--r-- | tensorflow/core/kernels/BUILD | 9 | ||||
-rw-r--r-- | tensorflow/core/kernels/data_format_ops.cc | 103 | ||||
-rw-r--r-- | tensorflow/core/kernels/data_format_ops.h | 45 | ||||
-rw-r--r-- | tensorflow/core/kernels/data_format_ops_gpu.cu.cc | 31 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 17 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_test.py | 20 |
8 files changed, 257 insertions, 0 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 9fc9aeb785..5f27566398 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -148,6 +148,7 @@ tensorflow/core/kernels/dynamic_stitch_op.cc tensorflow/core/kernels/dynamic_partition_op.cc tensorflow/core/kernels/decode_bmp_op.cc tensorflow/core/kernels/depthtospace_op.cc +tensorflow/core/kernels/data_format_ops.cc tensorflow/core/kernels/spacetodepth_op.cc tensorflow/core/kernels/dense_update_ops.cc tensorflow/core/kernels/deep_conv2d.cc diff --git a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt new file mode 100644 index 0000000000..62098acd38 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt @@ -0,0 +1,31 @@ +op { + graph_op_name: "DataFormatDimMap" + in_arg { + name: "x" + description: <<END +Scalar. Dimension index in source data format. Must be in the range [-4, 4). +END + } + out_arg { + name: "y" + description: <<END +Scalar. Dimension index in destination data format. +END + } + attr { + name: "src_format" + description: <<END +source data format. +END + } + attr { + name: "dst_format" + description: <<END +destination data format. +END + } + summary: "Returns the dimension index in the destination data format given the one in" + description: <<END +the source data format. +END +} diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c8359b4480..77ca8f5fcb 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3115,6 +3115,7 @@ cc_library( ":batch_norm_op", ":bias_op", ":conv_ops", + ":data_format_ops", ":depthwise_conv_grad_op", ":depthwise_conv_op", ":dilation_ops", @@ -3153,6 +3154,12 @@ tf_kernel_library( ) tf_kernel_library( + name = "data_format_ops", + prefix = "data_format_ops", + deps = NN_DEPS, +) + +tf_kernel_library( name = "bias_op", prefix = "bias_op", deps = NN_DEPS, @@ -4603,6 +4610,7 @@ filegroup( "control_flow_ops.h", "conv_2d.h", "conv_ops.h", + "data_format_ops.h", "depthtospace_op.h", "depthwise_conv_op.h", "fake_quant_ops_functor.h", @@ -4716,6 +4724,7 @@ filegroup( "cwise_op_squared_difference.cc", "cwise_op_sub.cc", "cwise_op_tanh.cc", + "data_format_ops.cc", "decode_wav_op.cc", "deep_conv2d.cc", "deep_conv2d.h", diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc new file mode 100644 index 0000000000..047188f754 --- /dev/null +++ b/tensorflow/core/kernels/data_format_ops.cc @@ -0,0 +1,103 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/data_format_ops.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template <typename Device, typename T> +class DataFormatDimMapOp : public OpKernel { + public: + explicit DataFormatDimMapOp(OpKernelConstruction* context) + : OpKernel(context) { + string src_format; + OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); + string dst_format; + OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); + OP_REQUIRES( + context, src_format == "NHWC", + errors::InvalidArgument(strings::StrCat( + "Current implementation doesn't support source data format ", + src_format))); + OP_REQUIRES(context, dst_format == "NCHW", + errors::InvalidArgument(strings::StrCat( + "Current implementation doesn't support dst data format ", + dst_format))); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 0, + errors::InvalidArgument("input must be a scalar", + input.shape().DebugString())); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), + input.scalar<T>(), + output->scalar<T>()); + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + DataFormatDimMapOp<CPUDevice, T>); + +TF_CALL_int32(REGISTER_KERNEL); +TF_CALL_int64(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void DataFormatDimMap<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstScalar x, \ + typename TTypes<T>::Scalar y); \ + extern template struct DataFormatDimMap<GPUDevice, T>; + +#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); + +TF_CALL_int32(DECLARE_GPU_SPECS); +TF_CALL_int64(DECLARE_GPU_SPECS); +#undef DECLARE_GPU_SPEC +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ + DataFormatDimMapOp<GPUDevice, T>); + +TF_CALL_int32(REGISTER_GPU_KERNEL); +TF_CALL_int64(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h new file mode 100644 index 0000000000..079e76c0d9 --- /dev/null +++ b/tensorflow/core/kernels/data_format_ops.h @@ -0,0 +1,45 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_ +#define TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_ +// Functor definition for data format dim mapping ops, must be compilable +// by nvcc. +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by DataFormatDimMapOP to do the computations. +template <typename Device, typename T> +struct DataFormatDimMap { + void operator()(const Device& d, typename TTypes<T>::ConstScalar x, + typename TTypes<T>::Scalar y) { + auto zero = x.constant(0); + auto one = x.constant(1); + auto three = x.constant(3); + auto four = x.constant(4); + auto x_mod = (x + four) % 4; + auto is_zero = (x_mod == zero); + auto is_three = (x_mod == three); + y.device(d) = is_zero.select(zero, is_three.select(one, x_mod + one)); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_ diff --git a/tensorflow/core/kernels/data_format_ops_gpu.cu.cc b/tensorflow/core/kernels/data_format_ops_gpu.cu.cc new file mode 100644 index 0000000000..09340a7d87 --- /dev/null +++ b/tensorflow/core/kernels/data_format_ops_gpu.cu.cc @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/data_format_ops.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +template struct functor::DataFormatDimMap<GPUDevice, int32>; +template struct functor::DataFormatDimMap<GPUDevice, int64>; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 102de94787..f58425db0a 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -751,6 +751,23 @@ Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) { } // namespace +REGISTER_OP("DataFormatDimMap") + .Input("x: T") + .Output("y: T") + .Attr("T: {int32, int64} = DT_INT32") + .Attr("src_format: string = 'NHWC'") + .Attr("dst_format: string = 'NCHW'") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Returns the dimension index in the destination data format given the one in +the source data format. + +x: Scalar. Dimension index in source data format. Must be in the range [-4, 4). +y: Scalar. Dimension index in destination data format. +src_format: source data format. +dst_format: destination data format. +)doc"); + REGISTER_OP("FusedResizeAndPadConv2D") .Input("input: T") .Input("size: int32") diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 3b918e4f74..ac79354fb7 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -953,5 +953,25 @@ class MomentsTest(test_lib.TestCase): self.doOutputTest((10, 10, 10, 30), (1, 2, 3)) +class DataFormatDimMapTest(test_lib.TestCase): + + def _test(self, x_val, y_val_expected): + x = constant_op.constant(x_val) + y = nn_ops.data_format_dim_map(x) + with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: + y_val = sess.run(y) + self.assertEqual(y_val, y_val_expected) + + def test(self): + self._test(0, 0) + self._test(1, 2) + self._test(2, 3) + self._test(3, 1) + self._test(-1, 1) + self._test(-2, 3) + self._test(-3, 2) + self._test(-4, 0) + + if __name__ == "__main__": test_lib.main() |