aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-12-20 13:30:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-20 13:40:18 -0800
commit47249f349d13f5a11a8dc8c4026c54b49c88cfe0 (patch)
tree208470b3f2757232f357b3774e1884c366bec38b
parent76db97fe3961651617371902a1a623df61f9ed81 (diff)
Extend DataFormatDimMap to handle tensors.
PiperOrigin-RevId: 179726269
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt5
-rw-r--r--tensorflow/core/kernels/data_format_ops.cc17
-rw-r--r--tensorflow/core/kernels/data_format_ops.h4
-rw-r--r--tensorflow/core/ops/nn_ops.cc5
-rw-r--r--tensorflow/python/ops/nn_test.py6
5 files changed, 19 insertions, 18 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt
index 62098acd38..994d3b8ddb 100644
--- a/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DataFormatDimMap.pbtxt
@@ -3,13 +3,14 @@ op {
in_arg {
name: "x"
description: <<END
-Scalar. Dimension index in source data format. Must be in the range [-4, 4).
+A Tensor with each element as a dimension index in source data format.
+Must be in the range [-4, 4).
END
}
out_arg {
name: "y"
description: <<END
-Scalar. Dimension index in destination data format.
+A Tensor with each element as a dimension index in destination data format.
END
}
attr {
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index e32d6545b8..fa67545a0d 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -50,16 +50,11 @@ class DataFormatDimMapOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- OP_REQUIRES(
- context, input.dims() == 0,
- errors::InvalidArgument("input must be a scalar, but got shape ",
- input.shape().DebugString()));
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
- input.scalar<T>(),
- output->scalar<T>());
+ input.flat<T>(), output->flat<T>());
}
};
@@ -137,11 +132,11 @@ TF_CALL_int64(REGISTER_KERNEL);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T) \
- template <> \
- void DataFormatDimMap<GPUDevice, T>::operator()( \
- const GPUDevice& d, typename TTypes<T>::ConstScalar x, \
- typename TTypes<T>::Scalar y); \
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void DataFormatDimMap<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
+ typename TTypes<T>::Flat y); \
extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index 01b7bff1eb..bf704cc35c 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -26,8 +26,8 @@ namespace functor {
// Functor used by DataFormatDimMapOP to do the computations.
template <typename Device, typename T>
struct DataFormatDimMap {
- void operator()(const Device& d, typename TTypes<T>::ConstScalar x,
- typename TTypes<T>::Scalar y) {
+ void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
+ typename TTypes<T>::Flat y) {
auto zero = x.constant(0);
auto one = x.constant(1);
auto three = x.constant(3);
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index df2d4a7123..d2dfe23888 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -762,8 +762,9 @@ REGISTER_OP("DataFormatDimMap")
Returns the dimension index in the destination data format given the one in
the source data format.
-x: Scalar. Dimension index in source data format. Must be in the range [-4, 4).
-y: Scalar. Dimension index in destination data format.
+x: A Tensor with each element as a dimension index in source data format.
+ Must be in the range [-4, 4).
+y: A Tensor with each element as a dimension index in destination data format.
src_format: source data format.
dst_format: destination data format.
)doc");
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index d391e345fe..e8c24643a7 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -960,7 +960,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
y = nn_ops.data_format_dim_map(x)
with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
y_val = sess.run(y)
- self.assertEqual(y_val, y_val_expected)
+ self.assertAllEqual(y_val, y_val_expected)
def test(self):
self._test(0, 0)
@@ -971,6 +971,10 @@ class DataFormatDimMapTest(test_lib.TestCase):
self._test(-2, 3)
self._test(-3, 2)
self._test(-4, 0)
+ self._test([1, 3], [2, 1])
+ self._test([1, 3, -2], [2, 1, 3])
+ self._test([1, -3, -2], [2, 2, 3])
+ self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]])
class DataFormatVectorPermuteTest(test_lib.TestCase):