diff options
author | Yao Zhang <yaozhang@google.com> | 2017-12-20 13:30:22 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-20 13:40:18 -0800 |
commit | 47249f349d13f5a11a8dc8c4026c54b49c88cfe0 (patch) | |
tree | 208470b3f2757232f357b3774e1884c366bec38b | |
parent | 76db97fe3961651617371902a1a623df61f9ed81 (diff) |
Extend DataFormatDimMap to handle tensors.
PiperOrigin-RevId: 179726269
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt | 5 | ||||
-rw-r--r-- | tensorflow/core/kernels/data_format_ops.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/kernels/data_format_ops.h | 4 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_test.py | 6 |
5 files changed, 19 insertions, 18 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt index 62098acd38..994d3b8ddb 100644 --- a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt @@ -3,13 +3,14 @@ op { in_arg { name: "x" description: <<END -Scalar. Dimension index in source data format. Must be in the range [-4, 4). +A Tensor with each element as a dimension index in source data format. +Must be in the range [-4, 4). END } out_arg { name: "y" description: <<END -Scalar. Dimension index in destination data format. +A Tensor with each element as a dimension index in destination data format. END } attr { diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc index e32d6545b8..fa67545a0d 100644 --- a/tensorflow/core/kernels/data_format_ops.cc +++ b/tensorflow/core/kernels/data_format_ops.cc @@ -50,16 +50,11 @@ class DataFormatDimMapOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - OP_REQUIRES( - context, input.dims() == 0, - errors::InvalidArgument("input must be a scalar, but got shape ", - input.shape().DebugString())); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &output)); functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), - input.scalar<T>(), - output->scalar<T>()); + input.flat<T>(), output->flat<T>()); } }; @@ -137,11 +132,11 @@ TF_CALL_int64(REGISTER_KERNEL); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void DataFormatDimMap<GPUDevice, T>::operator()( \ - const GPUDevice& d, typename TTypes<T>::ConstScalar x, \ - typename TTypes<T>::Scalar y); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void DataFormatDimMap<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstFlat x, \ + typename TTypes<T>::Flat y); \ extern template struct DataFormatDimMap<GPUDevice, T>; #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); TF_CALL_int32(DECLARE_GPU_SPECS); diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h index 01b7bff1eb..bf704cc35c 100644 --- a/tensorflow/core/kernels/data_format_ops.h +++ b/tensorflow/core/kernels/data_format_ops.h @@ -26,8 +26,8 @@ namespace functor { // Functor used by DataFormatDimMapOP to do the computations. template <typename Device, typename T> struct DataFormatDimMap { - void operator()(const Device& d, typename TTypes<T>::ConstScalar x, - typename TTypes<T>::Scalar y) { + void operator()(const Device& d, typename TTypes<T>::ConstFlat x, + typename TTypes<T>::Flat y) { auto zero = x.constant(0); auto one = x.constant(1); auto three = x.constant(3); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index df2d4a7123..d2dfe23888 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -762,8 +762,9 @@ REGISTER_OP("DataFormatDimMap") Returns the dimension index in the destination data format given the one in the source data format. -x: Scalar. Dimension index in source data format. Must be in the range [-4, 4). -y: Scalar. Dimension index in destination data format. +x: A Tensor with each element as a dimension index in source data format. + Must be in the range [-4, 4). +y: A Tensor with each element as a dimension index in destination data format. src_format: source data format. dst_format: destination data format. )doc"); diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index d391e345fe..e8c24643a7 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -960,7 +960,7 @@ class DataFormatDimMapTest(test_lib.TestCase): y = nn_ops.data_format_dim_map(x) with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: y_val = sess.run(y) - self.assertEqual(y_val, y_val_expected) + self.assertAllEqual(y_val, y_val_expected) def test(self): self._test(0, 0) @@ -971,6 +971,10 @@ class DataFormatDimMapTest(test_lib.TestCase): self._test(-2, 3) self._test(-3, 2) self._test(-4, 0) + self._test([1, 3], [2, 1]) + self._test([1, 3, -2], [2, 1, 3]) + self._test([1, -3, -2], [2, 2, 3]) + self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]]) class DataFormatVectorPermuteTest(test_lib.TestCase): |