aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-12-12 16:21:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-12 16:24:12 -0800
commitc373a16f61bff835181163dc07417e3cba6f47bc (patch)
treee7b3fa4d9f4bdaa0631acf0df9221604b7818619
parent47b674c938a38c6d88f27244a12ce3944c2f0464 (diff)
Return unimplemented error when trying to use dilated rate > 1 combined with NHWC format on the CPU.
Add test for unimplemented errors in Conv2D op. PiperOrigin-RevId: 178832407
-rw-r--r--tensorflow/core/kernels/conv_ops.cc13
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py46
2 files changed, 53 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ba40c428e4..985586d626 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -112,9 +112,9 @@ struct LaunchGeneric {
template <typename T>
struct LaunchConv2DOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter,
- int /*row_dilation*/, int /*col_dilation*/, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
+ const Tensor& input, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
TensorFormat data_format) {
if (data_format != FORMAT_NHWC) {
ctx->SetStatus(
@@ -122,6 +122,13 @@ struct LaunchConv2DOp<CPUDevice, T> {
"NHWC tensor format for now."));
return;
}
+ // TODO(yangzihao): Add the CPU implementation of dilated conv 2D.
+ if (row_dilation > 1 || col_dilation > 1) {
+ ctx->SetStatus(
+ errors::Unimplemented("Generic conv implementation only supports "
+ "dilated rate of 1 for now."));
+ return;
+ }
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
padding, output, data_format);
}
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index a85134c288..a7cbc76b87 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -164,8 +164,8 @@ class Conv2DTest(test.TestCase):
# as we will be using its gradients as reference for fp16 gradients.
return [dtypes.float32, dtypes.float16]
- def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, strides,
- padding, data_format, dtype, use_gpu):
+ def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, dilations,
+ strides, padding, data_format, dtype, use_gpu):
"""Verifies the output values of the convolution function.
Args:
@@ -173,6 +173,7 @@ class Conv2DTest(test.TestCase):
[batch, input_rows, input_cols, input_depth].
filter_in_sizes: Filter tensor dimensions in
[kernel_rows, kernel_cols, input_depth, output_depth].
+ dilations: Dilated rate: [col_dilation, row_dilation]
strides: Stride: [col_stride, row_stride]
padding: Padding type.
data_format: Format of the data tensors.
@@ -196,11 +197,18 @@ class Conv2DTest(test.TestCase):
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
strides = [1] + strides + [1]
+ dilations = [1] + dilations + [1]
if data_format == "NCHW":
t1 = test_util.NHWCToNCHW(t1)
strides = test_util.NHWCToNCHW(strides)
+ dilations = test_util.NHWCToNCHW(dilations)
conv = nn_ops.conv2d(
- t1, t2, strides=strides, padding=padding, data_format=data_format)
+ t1,
+ t2,
+ dilations=dilations,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
if data_format == "NCHW":
conv = test_util.NCHWToNHWC(conv)
@@ -316,11 +324,13 @@ class Conv2DTest(test.TestCase):
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
expected):
tensors = []
+ dilations = [1, 1]
for (data_format, use_gpu) in GetTestConfigs():
for dtype in self._DtypesToTest(use_gpu):
result = self._SetupValuesForDevice(
tensor_in_sizes,
filter_in_sizes,
+ dilations,
strides,
padding,
data_format,
@@ -1498,6 +1508,36 @@ class Conv2DTest(test.TestCase):
strides=[1, 1, 1, 1],
padding="VALID"))
+ def testCPUConv2DNCHWUnimplemented(self):
+ with self.test_session(use_gpu=False):
+ with self.assertRaisesRegexp(errors_impl.UnimplementedError,
+ "NHWC tensor format for now"):
+ conv = self._SetupValuesForDevice(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ dilations=[1, 1],
+ strides=[1, 1],
+ padding="VALID",
+ data_format="NCHW",
+ dtype=dtypes.float32,
+ use_gpu=False)
+ self.evaluate(conv)
+
+ def testCPUConv2DDilatedUnimplemented(self):
+ with self.test_session(use_gpu=False):
+ with self.assertRaisesRegexp(errors_impl.UnimplementedError,
+ "dilated rate of 1 for now"):
+ conv = self._SetupValuesForDevice(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ dilations=[2, 1],
+ strides=[1, 1],
+ padding="VALID",
+ data_format="NHWC",
+ dtype=dtypes.float32,
+ use_gpu=False)
+ self.evaluate(conv)
+
class DepthwiseConv2DTest(test.TestCase):