From 1822073137e1ac431250ea6f89b2719aac8d4782 Mon Sep 17 00:00:00 2001 From: Vijay Vasudevan Date: Fri, 7 Apr 2017 14:17:02 -0700 Subject: Allow uses of over-parameterized separable_conv. Fixes #4330. RELNOTES: Allow uses of over-parameterized separable convolution. PiperOrigin-RevId: 157035904 --- tensorflow/python/kernel_tests/conv_ops_test.py | 13 ------------- tensorflow/python/ops/nn_impl.py | 21 --------------------- 2 files changed, 34 deletions(-) diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 6184610bc0..db0adfc794 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -1317,19 +1317,6 @@ class SeparableConv2DTest(test.TestCase): return self._testSeparableConv2DEqualInputOutputDepth("NCHW") - def testSeparableConv2DIllegalCases(self): - # Output depth less then input depth. - with self.assertRaisesRegexp( - ValueError, - "Refusing to perform an overparameterized separable convolution"): - self._VerifyValues( - tensor_in_sizes=[1, 4, 4, 2], - depthwise_filter_in_sizes=[2, 2, 2, 3], - pointwise_filter_in_sizes=[1, 1, 6, 5], - stride=1, - padding="SAME", - expected=None) - class DeepConv2DTest(test.TestCase): diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 0a00e3d765..254a8432d3 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -448,10 +448,6 @@ def separable_conv2d(input, A 4-D `Tensor` with shape according to 'data_format'. For example, with data_format="NHWC", shape is [batch, out_height, out_width, out_channels]. - - Raises: - ValueError: If channel_multiplier * in_channels > out_channels, - which means that the separable convolution is overparameterized. """ with ops.name_scope(name, "separable_conv2d", [input, depthwise_filter, pointwise_filter]) as name: @@ -465,26 +461,9 @@ def separable_conv2d(input, pointwise_filter_shape[0].assert_is_compatible_with(1) pointwise_filter_shape[1].assert_is_compatible_with(1) - channel_multiplier = depthwise_filter.get_shape().with_rank(4)[3] - if data_format and data_format == "NCHW": - in_channels = input.get_shape().with_rank(4)[1] - else: - in_channels = input.get_shape().with_rank(4)[3] - - out_channels = pointwise_filter_shape[3] - if rate is None: rate = [1, 1] - # If any of channel numbers is unknown, then the comparison below returns - # None. See TensorShape.__gt__(). - if channel_multiplier * in_channels > out_channels: - raise ValueError("Refusing to perform an overparameterized separable " - "convolution: channel_multiplier * in_channels = " - "%d * %d = %d > %d = out_channels" % - (channel_multiplier, in_channels, - channel_multiplier * in_channels, out_channels)) - # The layout of the ops in the graph are expected to be as follows: # depthwise_conv2d // Conv2D op corresponding to native deptwise conv. # separable_conv2d // Conv2D op corresponding to the pointwise conv. -- cgit v1.2.3