From 47249f349d13f5a11a8dc8c4026c54b49c88cfe0 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Wed, 20 Dec 2017 13:30:22 -0800 Subject: Extend DataFormatDimMap to handle tensors. PiperOrigin-RevId: 179726269 --- .../api_def/base_api/api_def_DataFormatDimMap.pbtxt | 5 +++-- tensorflow/core/kernels/data_format_ops.cc | 17 ++++++----------- tensorflow/core/kernels/data_format_ops.h | 4 ++-- tensorflow/core/ops/nn_ops.cc | 5 +++-- tensorflow/python/ops/nn_test.py | 6 +++++- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt index 62098acd38..994d3b8ddb 100644 --- a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt @@ -3,13 +3,14 @@ op { in_arg { name: "x" description: <input(0); - OP_REQUIRES( - context, input.dims() == 0, - errors::InvalidArgument("input must be a scalar, but got shape ", - input.shape().DebugString())); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); functor::DataFormatDimMap()(context->eigen_device(), - input.scalar(), - output->scalar()); + input.flat(), output->flat()); } }; @@ -137,11 +132,11 @@ TF_CALL_int64(REGISTER_KERNEL); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void DataFormatDimMap::operator()( \ - const GPUDevice& d, typename TTypes::ConstScalar x, \ - typename TTypes::Scalar y); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void DataFormatDimMap::operator()( \ + const GPUDevice& d, typename TTypes::ConstFlat x, \ + typename TTypes::Flat y); \ extern template struct DataFormatDimMap; #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); TF_CALL_int32(DECLARE_GPU_SPECS); diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h index 01b7bff1eb..bf704cc35c 100644 --- a/tensorflow/core/kernels/data_format_ops.h +++ b/tensorflow/core/kernels/data_format_ops.h @@ -26,8 +26,8 @@ namespace functor { // Functor used by DataFormatDimMapOP to do the computations. template struct DataFormatDimMap { - void operator()(const Device& d, typename TTypes::ConstScalar x, - typename TTypes::Scalar y) { + void operator()(const Device& d, typename TTypes::ConstFlat x, + typename TTypes::Flat y) { auto zero = x.constant(0); auto one = x.constant(1); auto three = x.constant(3); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index df2d4a7123..d2dfe23888 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -762,8 +762,9 @@ REGISTER_OP("DataFormatDimMap") Returns the dimension index in the destination data format given the one in the source data format. -x: Scalar. Dimension index in source data format. Must be in the range [-4, 4). -y: Scalar. Dimension index in destination data format. +x: A Tensor with each element as a dimension index in source data format. + Must be in the range [-4, 4). +y: A Tensor with each element as a dimension index in destination data format. src_format: source data format. dst_format: destination data format. )doc"); diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index d391e345fe..e8c24643a7 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -960,7 +960,7 @@ class DataFormatDimMapTest(test_lib.TestCase): y = nn_ops.data_format_dim_map(x) with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: y_val = sess.run(y) - self.assertEqual(y_val, y_val_expected) + self.assertAllEqual(y_val, y_val_expected) def test(self): self._test(0, 0) @@ -971,6 +971,10 @@ class DataFormatDimMapTest(test_lib.TestCase): self._test(-2, 3) self._test(-3, 2) self._test(-4, 0) + self._test([1, 3], [2, 1]) + self._test([1, 3, -2], [2, 1, 3]) + self._test([1, -3, -2], [2, 2, 3]) + self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) class DataFormatVectorPermuteTest(test_lib.TestCase): -- cgit v1.2.3