aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-04-07 14:17:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-24 14:22:11 -0700
commit1822073137e1ac431250ea6f89b2719aac8d4782 (patch)
tree7f38332629b82078b89234e875a537c02468086f
parent7b09fd76fb2a14df771744226b9eb22624e65c0d (diff)
Allow uses of over-parameterized separable_conv.
Fixes #4330. RELNOTES: Allow uses of over-parameterized separable convolution. PiperOrigin-RevId: 157035904
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py13
-rw-r--r--tensorflow/python/ops/nn_impl.py21
2 files changed, 0 insertions, 34 deletions
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 6184610bc0..db0adfc794 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1317,19 +1317,6 @@ class SeparableConv2DTest(test.TestCase):
return
self._testSeparableConv2DEqualInputOutputDepth("NCHW")
- def testSeparableConv2DIllegalCases(self):
- # Output depth less then input depth.
- with self.assertRaisesRegexp(
- ValueError,
- "Refusing to perform an overparameterized separable convolution"):
- self._VerifyValues(
- tensor_in_sizes=[1, 4, 4, 2],
- depthwise_filter_in_sizes=[2, 2, 2, 3],
- pointwise_filter_in_sizes=[1, 1, 6, 5],
- stride=1,
- padding="SAME",
- expected=None)
-
class DeepConv2DTest(test.TestCase):
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 0a00e3d765..254a8432d3 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -448,10 +448,6 @@ def separable_conv2d(input,
A 4-D `Tensor` with shape according to 'data_format'. For
example, with data_format="NHWC", shape is [batch, out_height,
out_width, out_channels].
-
- Raises:
- ValueError: If channel_multiplier * in_channels > out_channels,
- which means that the separable convolution is overparameterized.
"""
with ops.name_scope(name, "separable_conv2d",
[input, depthwise_filter, pointwise_filter]) as name:
@@ -465,26 +461,9 @@ def separable_conv2d(input,
pointwise_filter_shape[0].assert_is_compatible_with(1)
pointwise_filter_shape[1].assert_is_compatible_with(1)
- channel_multiplier = depthwise_filter.get_shape().with_rank(4)[3]
- if data_format and data_format == "NCHW":
- in_channels = input.get_shape().with_rank(4)[1]
- else:
- in_channels = input.get_shape().with_rank(4)[3]
-
- out_channels = pointwise_filter_shape[3]
-
if rate is None:
rate = [1, 1]
- # If any of channel numbers is unknown, then the comparison below returns
- # None. See TensorShape.__gt__().
- if channel_multiplier * in_channels > out_channels:
- raise ValueError("Refusing to perform an overparameterized separable "
- "convolution: channel_multiplier * in_channels = "
- "%d * %d = %d > %d = out_channels" %
- (channel_multiplier, in_channels,
- channel_multiplier * in_channels, out_channels))
-
# The layout of the ops in the graph are expected to be as follows:
# depthwise_conv2d // Conv2D op corresponding to native deptwise conv.
# separable_conv2d // Conv2D op corresponding to the pointwise conv.