author    A. Unique TensorFlower <gardener@tensorflow.org>    2017-10-11 18:15:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>     2017-10-11 18:20:06 -0700
commit    9b26ed77dc2740314f47bcc4c991dd7f729b8d23 (patch)
tree      1dad368dce604fa97eed4545f7835eee476f5b65
parent    c69b9597995ab6510f9a21b615fe765b417d9cbb (diff)
Implement NCHW_VECT_C support for tf.depth_to_space on GPU.
PiperOrigin-RevId: 171904046
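
The sketch below is not part of the commit; it only illustrates, following the test added in depthtospace_op_test.py further down, how the new NCHW_VECT_C path is exercised: a float NHWC tensor is rearranged into the int8x4-vectorized NCHW_VECT_C layout, quantized to qint8, passed through depth_to_space, then dequantized and converted back for comparison. The TF 1.x internal modules and test_util helpers are assumptions taken from that test, and a GPU session is needed to actually reach the new qint8 kernel.

# Minimal usage sketch (assumed TF 1.x internals, mirroring the test below).
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops

batch, height, width, in_channels, block_size = 3, 2, 3, 32, 2
total_size = batch * height * width * in_channels
# qint8-friendly values cycling through -127..127, as in the test below.
x = [((f + 128) % 255) - 127 for f in range(total_size)]
t = constant_op.constant(
    x, shape=[batch, height, width, in_channels], dtype=dtypes.float32)

t = test_util.NHWCToNCHW_VECT_C(t)                  # 5-D [N, C/4, H, W, 4]
t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
t = array_ops.depth_to_space(t, block_size, data_format="NCHW_VECT_C")
t = gen_array_ops.dequantize(t, -128, 127)
result_nhwc = test_util.NCHW_VECT_CToNHWC(t)        # back to 4-D NHWC float
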
-rw-r--r--  tensorflow/core/kernels/depthtospace_op.cc         68
-rw-r--r--  tensorflow/core/kernels/depthtospace_op.h           4
-rw-r--r--  tensorflow/core/kernels/depthtospace_op_gpu.cu.cc  11
-rw-r--r--  tensorflow/core/ops/array_ops.cc                    7
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                4
-rw-r--r--  tensorflow/python/kernel_tests/depthtospace_op_test.py  84
6 files changed, 114 insertions(+), 64 deletions(-)
diff --git a/tensorflow/core/kernels/depthtospace_op.cc b/tensorflow/core/kernels/depthtospace_op.cc
index 4cf7de0df4..39aa3e9eb0 100644
--- a/tensorflow/core/kernels/depthtospace_op.cc
+++ b/tensorflow/core/kernels/depthtospace_op.cc
@@ -49,34 +49,33 @@ class DepthToSpaceOp : public OpKernel {
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
+ OP_REQUIRES(context, block_size_ > 1,
+ errors::InvalidArgument("Block size should be > 1, but was: ",
+ block_size_));
+
if (std::is_same<Device, CPUDevice>::value) {
OP_REQUIRES(
context, data_format_ == FORMAT_NHWC,
errors::InvalidArgument(
"Only NHWC data_format supported on CPU. Got ", data_format_str));
}
-
- // TODO(pauldonnelly): Implement NCHW_VECT_C kernel for the GPU.
- OP_REQUIRES(
- context, data_format_ != FORMAT_NCHW_VECT_C,
- errors::InvalidArgument("NHWC_VECT_C kernel not yet implemented."));
-
- OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
-
- OP_REQUIRES(
- context, block_size_ > 1,
- errors::InvalidArgument("Block size should be > 1: ", block_size_));
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
-
- // Check on the input dimensions first.
- // The input is presumed to be [batch, height, width, depth]
const int dims = input.dims();
- constexpr int kRequiredDims = 4;
- OP_REQUIRES(context, kRequiredDims == dims,
- errors::InvalidArgument("Input rank should be: ", kRequiredDims,
+
+ // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
+ constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
+ OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
+ errors::InvalidArgument(
+ "qint8 should be used with data_format NCHW_VECT_C."));
+
+ constexpr int kVect = is_int8x4 ? 4 : 1;
+ constexpr int kDims = is_int8x4 ? 5 : 4;
+ OP_REQUIRES(context, kDims == dims,
+ errors::InvalidArgument("Input rank should be: ", kDims,
" instead of: ", dims));
constexpr int kNumSpatialDims = 2;
@@ -87,7 +86,8 @@ class DepthToSpaceOp : public OpKernel {
const int input_width =
input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'W'));
const int input_depth =
- input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C'));
+ input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C')) *
+ kVect;
const int block_size_sq = block_size_ * block_size_;
@@ -109,13 +109,30 @@ class DepthToSpaceOp : public OpKernel {
ShapeFromFormat(data_format_, batch_size, output_height,
output_width, output_depth),
&outputs_tensor));
- auto Tinput = input.tensor<T, kRequiredDims>();
- auto Toutput = outputs_tensor->tensor<T, kRequiredDims>();
+ auto Tinput = input.tensor<T, kDims>();
+ auto Toutput = outputs_tensor->tensor<T, kDims>();
+
+ if (std::is_same<Device, GPUDevice>::value) {
+ if (is_int8x4) {
+ // NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
+ auto Tinput_v = input.template reinterpret_last_dimension<int32, 4>();
+ auto Toutput_v = outputs_tensor->reinterpret_last_dimension<int32, 4>();
+ functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW> functor;
+ functor(context->eigen_device<GPUDevice>(), Tinput_v, block_size_,
+ Toutput_v);
+ return;
+ } else if (data_format_ == FORMAT_NCHW) {
+ functor::DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> functor;
+ functor(context->eigen_device<GPUDevice>(), Tinput, block_size_,
+ Toutput);
+ return;
+ }
+ }
- if (std::is_same<Device, GPUDevice>::value && data_format_ == FORMAT_NCHW) {
- functor::DepthToSpaceOpFunctor<Device, T, FORMAT_NCHW> functor;
- functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
- } else {
+ // NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected
+ // (CPU && data_format_ != FORMAT_NHWC) in the constructor.
+
+ if (!is_int8x4) {
functor::DepthToSpaceOpFunctor<Device, T, FORMAT_NHWC> functor;
functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
}
@@ -170,6 +187,9 @@ TF_CALL_ALL_TYPES(REGISTER);
REGISTER_KERNEL_BUILDER(
Name("DepthToSpace").Device(DEVICE_GPU).TypeConstraint<float>("T"),
DepthToSpaceOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("DepthToSpace").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
+ DepthToSpaceOp<GPUDevice, qint8>);
#endif // GOOGLE_CUDA
} // end namespace tensorflow
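
A side note on the reinterpret_last_dimension call above (not part of the commit): because the NCHW_VECT_C inner dimension packs 4 qint8 values into the same bytes as one int32, the 5-D qint8 tensor can be treated as a 4-D int32 NCHW tensor and routed to the existing NCHW functor. The standalone numpy sketch below (hypothetical, no TensorFlow involved) shows that this reinterpretation is just a view of the same bytes.

# Standalone numpy illustration of the int8x4 <-> int32 layout equivalence.
import numpy as np

n, c4, h, w = 1, 2, 2, 3
int8x4 = np.arange(n * c4 * h * w * 4, dtype=np.int8).reshape(n, c4, h, w, 4)
as_int32 = int8x4.view(np.int32)          # same bytes, shape (n, c4, h, w, 1)
as_int32 = as_int32.reshape(n, c4, h, w)  # drop the trailing 1: a plain NCHW-like tensor

# Round-trip: reinterpret back to int8x4 and confirm no lanes were reordered.
back = as_int32.reshape(n, c4, h, w, 1).view(np.int8).reshape(n, c4, h, w, 4)
assert np.array_equal(back, int8x4)
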
diff --git a/tensorflow/core/kernels/depthtospace_op.h b/tensorflow/core/kernels/depthtospace_op.h
index fca375f58b..272468b740 100644
--- a/tensorflow/core/kernels/depthtospace_op.h
+++ b/tensorflow/core/kernels/depthtospace_op.h
@@ -44,6 +44,10 @@ template <typename Device, typename T, TensorFormat data_format>
struct DepthToSpaceOpFunctor {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
int block_size, typename TTypes<T, 4>::Tensor output);
+
+ // This 5-D version is to support NCHW_VECT_C.
+ void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
+ int block_size, typename TTypes<T, 5>::Tensor output);
};
} // namespace functor
diff --git a/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc
index 8f07c809e6..357c1f1be4 100644
--- a/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc
@@ -124,6 +124,10 @@ struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NHWC> {
input_height, input_width, input_depth, output_height, output_width,
output_depth, output.data());
}
+ void operator()(const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input,
+ int block_size, typename TTypes<T, 5>::Tensor output) {
+ LOG(FATAL) << "5-D tensors should not be used with NHWC format";
+ }
};
template <typename T>
@@ -143,6 +147,10 @@ struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> {
config.virtual_thread_count, input.data(), block_size, input_width,
output_depth * input_height, output.data());
}
+ void operator()(const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input,
+ int block_size, typename TTypes<T, 5>::Tensor output) {
+ LOG(FATAL) << "5-D tensors should not be used with NCHW format";
+ }
};
} // end namespace functor
@@ -150,6 +158,9 @@ struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> {
template struct functor::DepthToSpaceOpFunctor<GPUDevice, float, FORMAT_NCHW>;
template struct functor::DepthToSpaceOpFunctor<GPUDevice, float, FORMAT_NHWC>;
+// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
+template struct functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW>;
+
} // end namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 25a7c9eb39..14b87f0edf 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -4244,13 +4244,16 @@ REGISTER_OP("DepthToSpace")
TensorFormat data_format;
FormatFromString(data_format_str, &data_format);
+ constexpr int num_spatial_dims = 2;
+ const int dims =
+ GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
+
ShapeHandle input;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), dims, &input));
int32 block_size;
TF_RETURN_IF_ERROR(c->GetAttr("block_size", &block_size));
- constexpr int num_spatial_dims = 2;
DimensionHandle batch_size =
c->Dim(input, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
DimensionHandle input_height =
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index b8a7444f45..6beebbf48f 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1330,7 +1330,7 @@ cuda_py_test(
cuda_py_test(
name = "depthtospace_op_test",
- size = "small",
+ size = "medium",
srcs = ["depthtospace_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1898,7 +1898,7 @@ cuda_py_test(
cuda_py_test(
name = "spacetodepth_op_test",
- size = "small",
+ size = "medium",
srcs = ["spacetodepth_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py
index 6d5dc3846b..792806642a 100644
--- a/tensorflow/python/kernel_tests/depthtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py
@@ -26,9 +26,11 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
class DepthToSpaceTest(test.TestCase):
@@ -201,7 +203,8 @@ class DepthToSpaceTest(test.TestCase):
_ = array_ops.space_to_depth(x_np, block_size)
def testUnknownShape(self):
- t = array_ops.depth_to_space(array_ops.placeholder(dtypes.float32), block_size=4)
+ t = array_ops.depth_to_space(
+ array_ops.placeholder(dtypes.float32), block_size=4)
self.assertEqual(4, t.get_shape().ndims)
def depthToSpaceUsingTranspose(self, tensor, block_size, data_format):
@@ -224,49 +227,58 @@ class DepthToSpaceTest(test.TestCase):
tensor = array_ops.reshape(tensor, [b, oc, oh, ow])
return tensor
- def compareToTranspose(self, data_format, batch_size, in_height, in_width,
- out_channels, block_size, use_gpu):
- if use_gpu and not test.is_gpu_available():
- print("gpu not available")
- return
-
- dtype = dtypes.float32
+ def compareToTranspose(self, batch_size, in_height, in_width, out_channels,
+ block_size, data_format, use_gpu):
in_channels = out_channels * block_size * block_size
-
- if data_format == "NHWC":
- input_shape = [batch_size, in_height, in_width, in_channels]
- elif data_format == "NCHW":
- input_shape = [batch_size, in_channels, in_height, in_width]
+ nhwc_input_shape = [batch_size, in_height, in_width, in_channels]
+ nchw_input_shape = [batch_size, in_channels, in_height, in_width]
+ total_size = np.prod(nhwc_input_shape)
+
+ if data_format == "NCHW_VECT_C":
+ # Initialize the input tensor with qint8 values that circle -127..127.
+ x = [((f + 128) % 255) - 127 for f in range(total_size)]
+ t = constant_op.constant(x, shape=nhwc_input_shape, dtype=dtypes.float32)
+ expected = self.depthToSpaceUsingTranspose(t, block_size, "NHWC")
+ t = test_util.NHWCToNCHW_VECT_C(t)
+ t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8)
+ t = array_ops.depth_to_space(t, block_size, data_format="NCHW_VECT_C")
+ t = gen_array_ops.dequantize(t, -128, 127)
+ actual = test_util.NCHW_VECT_CToNHWC(t)
else:
- assert False, "unsupported format"
-
- # Initialize the input tensor with ascending whole numbers.
- total_size = 1
- for dim_size in input_shape:
- total_size *= dim_size
- x = [f for f in range(total_size)]
- inputs = constant_op.constant(x, shape=input_shape, dtype=dtype)
-
- expected = self.depthToSpaceUsingTranspose(inputs, block_size, data_format)
- actual = array_ops.depth_to_space(
- inputs, block_size, data_format=data_format)
+ # Initialize the input tensor with ascending whole numbers as floats.
+ x = [f * 1.0 for f in range(total_size)]
+ shape = nchw_input_shape if data_format == "NCHW" else nhwc_input_shape
+ t = constant_op.constant(x, shape=shape, dtype=dtypes.float32)
+ expected = self.depthToSpaceUsingTranspose(t, block_size, data_format)
+ actual = array_ops.depth_to_space(t, block_size, data_format=data_format)
with self.test_session(use_gpu=use_gpu) as sess:
actual_vals, expected_vals = sess.run([actual, expected])
self.assertTrue(np.array_equal(actual_vals, expected_vals))
def testAgainstTranspose(self):
- self.compareToTranspose("NHWC", 3, 2, 3, 1, 2, False)
- self.compareToTranspose("NHWC", 3, 2, 3, 2, 2, False)
- self.compareToTranspose("NHWC", 3, 2, 3, 1, 2, True)
- self.compareToTranspose("NHWC", 3, 2, 3, 2, 2, True)
-
- self.compareToTranspose("NCHW", 3, 2, 3, 1, 2, True)
- self.compareToTranspose("NCHW", 3, 2, 3, 2, 2, True)
- self.compareToTranspose("NCHW", 3, 2, 3, 1, 3, True)
- self.compareToTranspose("NCHW", 3, 2, 3, 2, 3, True)
- self.compareToTranspose("NCHW", 5, 7, 11, 3, 2, True)
- self.compareToTranspose("NCHW", 3, 200, 300, 32, 2, True)
+ self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
+ self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False)
+ self.compareToTranspose(1, 2, 3, 2, 3, "NHWC", False)
+
+ if not test.is_gpu_available():
+ tf_logging.info("skipping gpu tests since gpu not available")
+ return
+
+ self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", True)
+ self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", True)
+ self.compareToTranspose(3, 2, 3, 1, 2, "NCHW", True)
+ self.compareToTranspose(3, 2, 3, 2, 2, "NCHW", True)
+ self.compareToTranspose(3, 2, 3, 1, 3, "NCHW", True)
+ self.compareToTranspose(3, 2, 3, 2, 3, "NCHW", True)
+ self.compareToTranspose(5, 7, 11, 3, 2, "NCHW", True)
+ self.compareToTranspose(3, 200, 300, 32, 2, "NCHW", True)
+
+ self.compareToTranspose(3, 2, 3, 8, 2, "NCHW_VECT_C", True)
+ self.compareToTranspose(3, 2, 3, 4, 3, "NCHW_VECT_C", True)
+ self.compareToTranspose(3, 2, 3, 8, 3, "NCHW_VECT_C", True)
+ self.compareToTranspose(5, 7, 11, 12, 2, "NCHW_VECT_C", True)
+ self.compareToTranspose(3, 200, 300, 32, 2, "NCHW_VECT_C", True)
class DepthToSpaceGradientTest(test.TestCase):