author     Jonathan Shen <jonathanasdf@google.com>             2017-04-25 14:03:45 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>     2017-04-25 15:24:32 -0700
commit     fa7f79541d1ba1fa6f54ea3474cbf3ce1411bfb4 (patch)
tree       76f02110b46ae420aed18f344ebfe548132cdb81 /tensorflow
parent     1d8ac37322983c71f02d0300766740172ad00485 (diff)
Deprecate the use of is_training in resnet_arg_scope. is_training should be
directly passed to the network instead.
Change: 154226768
Diffstat (limited to 'tensorflow')
7 files changed, 366 insertions, 135 deletions
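In practice the migration this change asks for is mechanical: stop passing is_training to resnet_arg_scope and pass it to the network function instead. Below is a minimal before/after sketch based on the updated docstrings in resnet_v1.py; the TF 1.x-era imports and the placeholder input are illustrative and not part of this diff.

```python
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim.python.slim.nets import resnet_v1

# Illustrative input: [batch, height, width, channels].
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Deprecated: the training flag baked into the arg_scope.
# with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)):
#   net, end_points = resnet_v1.resnet_v1_101(inputs, 1000)

# After this change: keep the arg_scope generic and pass is_training directly.
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
  net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
```

The deprecated form still works for now (the deprecated_args decorator added to resnet_utils.resnet_arg_scope gives 2017-08-01 as the deprecation date), but an is_training value passed explicitly to the network takes precedence over it.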
diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index 8077818216..737bbbe57b 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -286,6 +286,26 @@ py_test( ], ) +py_test( + name = "resnet_is_training_test", + size = "medium", + srcs = ["resnet_is_training_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":resnet_utils", + ":resnet_v1", + ":resnet_v2", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//third_party/py/numpy", + ], +) + py_library( name = "vgg", srcs = ["vgg.py"], diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py new file mode 100644 index 0000000000..9a165577b6 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py @@ -0,0 +1,154 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Specifying is_training in resnet_arg_scope is being deprecated. + +Test that everything behaves as expected in the meantime. + +Note: This test modifies the layers.batch_norm function. +Other tests that use layers.batch_norm may not work if added to this file. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import layers +from tensorflow.contrib.framework.python.ops import add_arg_scope +from tensorflow.contrib.framework.python.ops import arg_scope +from tensorflow.contrib.slim.python.slim.nets import resnet_utils +from tensorflow.contrib.slim.python.slim.nets import resnet_v1 +from tensorflow.contrib.slim.python.slim.nets import resnet_v2 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def create_test_input(batch, height, width, channels): + """Create test input tensor.""" + if None in [batch, height, width, channels]: + return array_ops.placeholder(dtypes.float32, (batch, height, width, + channels)) + else: + return math_ops.to_float( + np.tile( + np.reshape( + np.reshape(np.arange(height), [height, 1]) + + np.reshape(np.arange(width), [1, width]), + [1, height, width, 1]), + [batch, 1, 1, channels])) + + +class ResnetIsTrainingTest(test.TestCase): + + def _testDeprecatingIsTraining(self, network_fn): + batch_norm_fn = layers.batch_norm + + @add_arg_scope + def batch_norm_expect_is_training(*args, **kwargs): + assert kwargs['is_training'] + return batch_norm_fn(*args, **kwargs) + + @add_arg_scope + def batch_norm_expect_is_not_training(*args, **kwargs): + assert not kwargs['is_training'] + return batch_norm_fn(*args, **kwargs) + + global_pool = True + num_classes = 10 + inputs = create_test_input(2, 224, 224, 3) + + # Default argument for resnet_arg_scope + layers.batch_norm = batch_norm_expect_is_training + with arg_scope(resnet_utils.resnet_arg_scope()): + network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet1') + + layers.batch_norm = batch_norm_expect_is_training + with arg_scope(resnet_utils.resnet_arg_scope()): + network_fn( + inputs, + num_classes, + is_training=True, + global_pool=global_pool, + scope='resnet2') + + layers.batch_norm = batch_norm_expect_is_not_training + with arg_scope(resnet_utils.resnet_arg_scope()): + network_fn( + inputs, + num_classes, + is_training=False, + global_pool=global_pool, + scope='resnet3') + + # resnet_arg_scope with is_training set to True (deprecated) + layers.batch_norm = batch_norm_expect_is_training + with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): + network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet4') + + layers.batch_norm = batch_norm_expect_is_training + with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): + network_fn( + inputs, + num_classes, + is_training=True, + global_pool=global_pool, + scope='resnet5') + + layers.batch_norm = batch_norm_expect_is_not_training + with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): + network_fn( + inputs, + num_classes, + is_training=False, + global_pool=global_pool, + scope='resnet6') + + # resnet_arg_scope with is_training set to False (deprecated) + layers.batch_norm = batch_norm_expect_is_not_training + with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): + network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet7') + + layers.batch_norm = batch_norm_expect_is_training + with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): + network_fn( + inputs, + num_classes, + is_training=True, + global_pool=global_pool, + scope='resnet8') + + layers.batch_norm = 
batch_norm_expect_is_not_training + with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): + network_fn( + inputs, + num_classes, + is_training=False, + global_pool=global_pool, + scope='resnet9') + + layers.batch_norm = batch_norm_fn + + def testDeprecatingIsTrainingResnetV1(self): + self._testDeprecatingIsTraining(resnet_v1.resnet_v1_50) + + def testDeprecatingIsTrainingResnetV2(self): + self._testDeprecatingIsTraining(resnet_v2.resnet_v2_50) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py index 89d27438e5..58614a998a 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py @@ -41,6 +41,7 @@ from __future__ import print_function import collections from tensorflow.contrib import layers as layers_lib +from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.ops import add_arg_scope from tensorflow.contrib.framework.python.ops import arg_scope from tensorflow.contrib.layers.python.layers import initializers @@ -222,6 +223,10 @@ def stack_blocks_dense(net, return net +@deprecated_args( + '2017-08-01', + 'Pass is_training directly to the network instead of the arg_scope.', + 'is_training') def resnet_arg_scope(is_training=True, weight_decay=0.0001, batch_norm_decay=0.997, @@ -236,7 +241,7 @@ def resnet_arg_scope(is_training=True, Args: is_training: Whether or not we are training the parameters in the batch - normalization layers of the model. + normalization layers of the model. (deprecated) weight_decay: The weight decay to use for regularizing the model. batch_norm_decay: The moving average decay when estimating layer activation statistics in batch normalization. @@ -261,8 +266,7 @@ def resnet_arg_scope(is_training=True, weights_regularizer=regularizers.l2_regularizer(weight_decay), weights_initializer=initializers.variance_scaling_initializer(), activation_fn=nn_ops.relu, - normalizer_fn=layers.batch_norm, - normalizer_params=batch_norm_params): + normalizer_fn=layers.batch_norm): with arg_scope([layers.batch_norm], **batch_norm_params): # The following implies padding='SAME' for pool1, which makes feature # alignment easier for dense prediction tasks. 
This is also used in diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py index fe13ce1b0e..90f93d46e3 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py @@ -40,15 +40,16 @@ Typical use: ResNet-101 for image classification into 1000 classes: # inputs has shape [batch, 224, 224, 3] - with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training)): - net, end_points = resnet_v1.resnet_v1_101(inputs, 1000) + with slim.arg_scope(resnet_v1.resnet_arg_scope()): + net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) ResNet-101 for semantic segmentation into 21 classes: # inputs has shape [batch, 513, 513, 3] - with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training)): + with slim.arg_scope(resnet_v1.resnet_arg_scope()): net, end_points = resnet_v1.resnet_v1_101(inputs, 21, + is_training=False, global_pool=False, output_stride=16) """ @@ -127,6 +128,7 @@ def bottleneck(inputs, def resnet_v1(inputs, blocks, num_classes=None, + is_training=None, global_pool=True, output_stride=None, include_root_block=True, @@ -161,6 +163,8 @@ def resnet_v1(inputs, is a resnet_utils.Block object describing the units in the block. num_classes: Number of predicted classes for classification tasks. If None we return the features before the logit layer. + is_training: whether is training or not. If None, the value inherited from + the resnet_arg_scope is used. Specifying value None is deprecated. global_pool: If True, we perform global average pooling before computing the logits. Set to True for image classification, False for dense prediction. output_stride: If None, then the output will be computed at the nominal @@ -192,30 +196,36 @@ def resnet_v1(inputs, with arg_scope( [layers.conv2d, bottleneck, resnet_utils.stack_blocks_dense], outputs_collections=end_points_collection): - net = inputs - if include_root_block: - if output_stride is not None: - if output_stride % 4 != 0: - raise ValueError('The output_stride needs to be a multiple of 4.') - output_stride /= 4 - net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') - net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope='pool1') - net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) - if global_pool: - # Global average pooling. - net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) - if num_classes is not None: - net = layers.conv2d( - net, - num_classes, [1, 1], - activation_fn=None, - normalizer_fn=None, - scope='logits') - # Convert end_points_collection into a dictionary of end_points. - end_points = utils.convert_collection_to_dict(end_points_collection) - if num_classes is not None: - end_points['predictions'] = layers_lib.softmax(net, scope='predictions') - return net, end_points + if is_training is not None: + bn_scope = arg_scope([layers.batch_norm], is_training=is_training) + else: + bn_scope = arg_scope([]) + with bn_scope: + net = inputs + if include_root_block: + if output_stride is not None: + if output_stride % 4 != 0: + raise ValueError('The output_stride needs to be a multiple of 4.') + output_stride /= 4 + net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') + net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope='pool1') + net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) + if global_pool: + # Global average pooling. 
+ net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) + if num_classes is not None: + net = layers.conv2d( + net, + num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + scope='logits') + # Convert end_points_collection into a dictionary of end_points. + end_points = utils.convert_collection_to_dict(end_points_collection) + if num_classes is not None: + end_points['predictions'] = layers_lib.softmax( + net, scope='predictions') + return net, end_points resnet_v1.default_image_size = 224 @@ -245,6 +255,7 @@ def resnet_v1_block(scope, base_depth, num_units, stride): def resnet_v1_50(inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, reuse=None, @@ -260,6 +271,7 @@ def resnet_v1_50(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, @@ -269,6 +281,7 @@ def resnet_v1_50(inputs, def resnet_v1_101(inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, reuse=None, @@ -284,6 +297,7 @@ def resnet_v1_101(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, @@ -293,6 +307,7 @@ def resnet_v1_101(inputs, def resnet_v1_152(inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, reuse=None, @@ -308,6 +323,7 @@ def resnet_v1_152(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, @@ -317,6 +333,7 @@ def resnet_v1_152(inputs, def resnet_v1_200(inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, reuse=None, @@ -332,6 +349,7 @@ def resnet_v1_200(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py index dffd29f920..d510337fef 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py @@ -219,28 +219,29 @@ class ResnetUtilsTest(test.TestCase): # Test both odd and even input dimensions. height = 30 width = 31 - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - for output_stride in [1, 2, 4, 8, None]: - with ops.Graph().as_default(): - with self.test_session() as sess: - random_seed.set_random_seed(0) - inputs = create_test_input(1, height, width, 3) - # Dense feature extraction followed by subsampling. - output = resnet_utils.stack_blocks_dense(inputs, blocks, - output_stride) - if output_stride is None: - factor = 1 - else: - factor = nominal_stride // output_stride - - output = resnet_utils.subsample(output, factor) - # Make the two networks use the same weights. - variable_scope.get_variable_scope().reuse_variables() - # Feature extraction at the nominal network rate. - expected = self._stack_blocks_nondense(inputs, blocks) - sess.run(variables.global_variables_initializer()) - output, expected = sess.run([output, expected]) - self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4) + with arg_scope(resnet_utils.resnet_arg_scope()): + with arg_scope([layers.batch_norm], is_training=False): + for output_stride in [1, 2, 4, 8, None]: + with ops.Graph().as_default(): + with self.test_session() as sess: + random_seed.set_random_seed(0) + inputs = create_test_input(1, height, width, 3) + # Dense feature extraction followed by subsampling. 
+ output = resnet_utils.stack_blocks_dense(inputs, blocks, + output_stride) + if output_stride is None: + factor = 1 + else: + factor = nominal_stride // output_stride + + output = resnet_utils.subsample(output, factor) + # Make the two networks use the same weights. + variable_scope.get_variable_scope().reuse_variables() + # Feature extraction at the nominal network rate. + expected = self._stack_blocks_nondense(inputs, blocks) + sess.run(variables.global_variables_initializer()) + output, expected = sess.run([output, expected]) + self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4) class ResnetCompleteNetworkTest(test.TestCase): @@ -249,6 +250,7 @@ class ResnetCompleteNetworkTest(test.TestCase): def _resnet_small(self, inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, include_root_block=True, @@ -262,8 +264,9 @@ class ResnetCompleteNetworkTest(test.TestCase): block('block3', base_depth=4, num_units=3, stride=2), block('block4', base_depth=8, num_units=2, stride=1), ] - return resnet_v1.resnet_v1(inputs, blocks, num_classes, global_pool, - output_stride, include_root_block, reuse, scope) + return resnet_v1.resnet_v1(inputs, blocks, num_classes, is_training, + global_pool, output_stride, include_root_block, + reuse, scope) def testClassificationEndPoints(self): global_pool = True @@ -271,7 +274,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(2, 224, 224, 3) with arg_scope(resnet_utils.resnet_arg_scope()): logits, end_points = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') self.assertTrue(logits.op.name.startswith('resnet/logits')) self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes]) self.assertTrue('predictions' in end_points) @@ -284,7 +287,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(2, 224, 224, 3) with arg_scope(resnet_utils.resnet_arg_scope()): _, end_points = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') endpoint_to_shape = { 'resnet/block1': [2, 28, 28, 4], 'resnet/block2': [2, 14, 14, 8], @@ -301,7 +304,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(2, 321, 321, 3) with arg_scope(resnet_utils.resnet_arg_scope()): _, end_points = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') endpoint_to_shape = { 'resnet/block1': [2, 41, 41, 4], 'resnet/block2': [2, 21, 21, 8], @@ -320,7 +323,7 @@ class ResnetCompleteNetworkTest(test.TestCase): _, end_points = self._resnet_small( inputs, num_classes, - global_pool, + global_pool=global_pool, include_root_block=False, scope='resnet') endpoint_to_shape = { @@ -342,7 +345,7 @@ class ResnetCompleteNetworkTest(test.TestCase): _, end_points = self._resnet_small( inputs, num_classes, - global_pool, + global_pool=global_pool, output_stride=output_stride, scope='resnet') endpoint_to_shape = { @@ -359,14 +362,18 @@ class ResnetCompleteNetworkTest(test.TestCase): """Verify dense feature extraction with atrous convolution.""" nominal_stride = 32 for output_stride in [4, 8, 16, 32, None]: - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): + with arg_scope(resnet_utils.resnet_arg_scope()): with ops.Graph().as_default(): with self.test_session() as sess: random_seed.set_random_seed(0) inputs = create_test_input(2, 81, 81, 
3) # Dense feature extraction followed by subsampling. output, _ = self._resnet_small( - inputs, None, global_pool=False, output_stride=output_stride) + inputs, + None, + is_training=False, + global_pool=False, + output_stride=output_stride) if output_stride is None: factor = 1 else: @@ -375,7 +382,8 @@ class ResnetCompleteNetworkTest(test.TestCase): # Make the two networks use the same weights. variable_scope.get_variable_scope().reuse_variables() # Feature extraction at the nominal network rate. - expected, _ = self._resnet_small(inputs, None, global_pool=False) + expected, _ = self._resnet_small( + inputs, None, is_training=False, global_pool=False) sess.run(variables.global_variables_initializer()) self.assertAllClose( output.eval(), expected.eval(), atol=1e-4, rtol=1e-4) @@ -388,7 +396,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(None, height, width, 3) with arg_scope(resnet_utils.resnet_arg_scope()): logits, _ = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') self.assertTrue(logits.op.name.startswith('resnet/logits')) self.assertListEqual(logits.get_shape().as_list(), [None, 1, 1, num_classes]) @@ -404,7 +412,7 @@ class ResnetCompleteNetworkTest(test.TestCase): global_pool = False inputs = create_test_input(batch, None, None, 3) with arg_scope(resnet_utils.resnet_arg_scope()): - output, _ = self._resnet_small(inputs, None, global_pool) + output, _ = self._resnet_small(inputs, None, global_pool=global_pool) self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32]) images = create_test_input(batch, height, width, 3) with self.test_session() as sess: @@ -420,7 +428,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(batch, None, None, 3) with arg_scope(resnet_utils.resnet_arg_scope()): output, _ = self._resnet_small( - inputs, None, global_pool, output_stride=output_stride) + inputs, None, global_pool=global_pool, output_stride=output_stride) self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32]) images = create_test_input(batch, height, width, 3) with self.test_session() as sess: diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py index 7e6fe5dfc2..f260ede348 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py @@ -36,15 +36,16 @@ Typical use: ResNet-101 for image classification into 1000 classes: # inputs has shape [batch, 224, 224, 3] - with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training)): - net, end_points = resnet_v2.resnet_v2_101(inputs, 1000) + with slim.arg_scope(resnet_v2.resnet_arg_scope()): + net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False) ResNet-101 for semantic segmentation into 21 classes: # inputs has shape [batch, 513, 513, 3] - with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training)): + with slim.arg_scope(resnet_v2.resnet_arg_scope()): net, end_points = resnet_v2.resnet_v2_101(inputs, 21, + is_training=False, global_pool=False, output_stride=16) """ @@ -129,6 +130,7 @@ def bottleneck(inputs, def resnet_v2(inputs, blocks, num_classes=None, + is_training=None, global_pool=True, output_stride=None, include_root_block=True, @@ -163,6 +165,8 @@ def resnet_v2(inputs, is a resnet_utils.Block object describing the units in the block. num_classes: Number of predicted classes for classification tasks. 
If None we return the features before the logit layer. + is_training: whether is training or not. If None, the value inherited from + the resnet_arg_scope is used. Specifying value None is deprecated. global_pool: If True, we perform global average pooling before computing the logits. Set to True for image classification, False for dense prediction. output_stride: If None, then the output will be computed at the nominal @@ -196,38 +200,45 @@ def resnet_v2(inputs, with arg_scope( [layers_lib.conv2d, bottleneck, resnet_utils.stack_blocks_dense], outputs_collections=end_points_collection): - net = inputs - if include_root_block: - if output_stride is not None: - if output_stride % 4 != 0: - raise ValueError('The output_stride needs to be a multiple of 4.') - output_stride /= 4 - # We do not include batch normalization or activation functions in conv1 - # because the first ResNet unit will perform these. Cf. Appendix of [2]. - with arg_scope( - [layers_lib.conv2d], activation_fn=None, normalizer_fn=None): - net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') - net = layers.max_pool2d(net, [3, 3], stride=2, scope='pool1') - net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) - # This is needed because the pre-activation variant does not have batch - # normalization or activation functions in the residual unit output. See - # Appendix of [2]. - net = layers.batch_norm(net, activation_fn=nn_ops.relu, scope='postnorm') - if global_pool: - # Global average pooling. - net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) - if num_classes is not None: - net = layers_lib.conv2d( - net, - num_classes, [1, 1], - activation_fn=None, - normalizer_fn=None, - scope='logits') - # Convert end_points_collection into a dictionary of end_points. - end_points = utils.convert_collection_to_dict(end_points_collection) - if num_classes is not None: - end_points['predictions'] = layers.softmax(net, scope='predictions') - return net, end_points + if is_training is not None: + bn_scope = arg_scope([layers.batch_norm], is_training=is_training) + else: + bn_scope = arg_scope([]) + with bn_scope: + net = inputs + if include_root_block: + if output_stride is not None: + if output_stride % 4 != 0: + raise ValueError('The output_stride needs to be a multiple of 4.') + output_stride /= 4 + # We do not include batch normalization or activation functions in + # conv1 because the first ResNet unit will perform these. Cf. + # Appendix of [2]. + with arg_scope( + [layers_lib.conv2d], activation_fn=None, normalizer_fn=None): + net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') + net = layers.max_pool2d(net, [3, 3], stride=2, scope='pool1') + net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) + # This is needed because the pre-activation variant does not have batch + # normalization or activation functions in the residual unit output. See + # Appendix of [2]. + net = layers.batch_norm( + net, activation_fn=nn_ops.relu, scope='postnorm') + if global_pool: + # Global average pooling. + net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) + if num_classes is not None: + net = layers_lib.conv2d( + net, + num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + scope='logits') + # Convert end_points_collection into a dictionary of end_points. 
+ end_points = utils.convert_collection_to_dict(end_points_collection) + if num_classes is not None: + end_points['predictions'] = layers.softmax(net, scope='predictions') + return net, end_points resnet_v2.default_image_size = 224 @@ -257,6 +268,7 @@ def resnet_v2_block(scope, base_depth, num_units, stride): def resnet_v2_50(inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, reuse=None, @@ -272,6 +284,7 @@ def resnet_v2_50(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, @@ -282,6 +295,7 @@ def resnet_v2_50(inputs, def resnet_v2_101(inputs, num_classes=None, global_pool=True, + is_training=None, output_stride=None, reuse=None, scope='resnet_v2_101'): @@ -296,6 +310,7 @@ def resnet_v2_101(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, @@ -305,6 +320,7 @@ def resnet_v2_101(inputs, def resnet_v2_152(inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, reuse=None, @@ -320,6 +336,7 @@ def resnet_v2_152(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, @@ -329,6 +346,7 @@ def resnet_v2_152(inputs, def resnet_v2_200(inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, reuse=None, @@ -344,6 +362,7 @@ def resnet_v2_200(inputs, inputs, blocks, num_classes, + is_training, global_pool, output_stride, include_root_block=True, diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py index 1c09bcbb5a..c4f3b071fd 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py @@ -223,28 +223,29 @@ class ResnetUtilsTest(test.TestCase): # Test both odd and even input dimensions. height = 30 width = 31 - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - for output_stride in [1, 2, 4, 8, None]: - with ops.Graph().as_default(): - with self.test_session() as sess: - random_seed.set_random_seed(0) - inputs = create_test_input(1, height, width, 3) - # Dense feature extraction followed by subsampling. - output = resnet_utils.stack_blocks_dense(inputs, blocks, - output_stride) - if output_stride is None: - factor = 1 - else: - factor = nominal_stride // output_stride - - output = resnet_utils.subsample(output, factor) - # Make the two networks use the same weights. - variable_scope.get_variable_scope().reuse_variables() - # Feature extraction at the nominal network rate. - expected = self._stack_blocks_nondense(inputs, blocks) - sess.run(variables.global_variables_initializer()) - output, expected = sess.run([output, expected]) - self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4) + with arg_scope(resnet_utils.resnet_arg_scope()): + with arg_scope([layers.batch_norm], is_training=False): + for output_stride in [1, 2, 4, 8, None]: + with ops.Graph().as_default(): + with self.test_session() as sess: + random_seed.set_random_seed(0) + inputs = create_test_input(1, height, width, 3) + # Dense feature extraction followed by subsampling. + output = resnet_utils.stack_blocks_dense(inputs, blocks, + output_stride) + if output_stride is None: + factor = 1 + else: + factor = nominal_stride // output_stride + + output = resnet_utils.subsample(output, factor) + # Make the two networks use the same weights. 
+ variable_scope.get_variable_scope().reuse_variables() + # Feature extraction at the nominal network rate. + expected = self._stack_blocks_nondense(inputs, blocks) + sess.run(variables.global_variables_initializer()) + output, expected = sess.run([output, expected]) + self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4) class ResnetCompleteNetworkTest(test.TestCase): @@ -253,6 +254,7 @@ class ResnetCompleteNetworkTest(test.TestCase): def _resnet_small(self, inputs, num_classes=None, + is_training=None, global_pool=True, output_stride=None, include_root_block=True, @@ -266,8 +268,9 @@ class ResnetCompleteNetworkTest(test.TestCase): block('block3', base_depth=4, num_units=3, stride=2), block('block4', base_depth=8, num_units=2, stride=1), ] - return resnet_v2.resnet_v2(inputs, blocks, num_classes, global_pool, - output_stride, include_root_block, reuse, scope) + return resnet_v2.resnet_v2(inputs, blocks, num_classes, is_training, + global_pool, output_stride, include_root_block, + reuse, scope) def testClassificationEndPoints(self): global_pool = True @@ -275,7 +278,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(2, 224, 224, 3) with arg_scope(resnet_utils.resnet_arg_scope()): logits, end_points = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') self.assertTrue(logits.op.name.startswith('resnet/logits')) self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes]) self.assertTrue('predictions' in end_points) @@ -288,7 +291,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(2, 224, 224, 3) with arg_scope(resnet_utils.resnet_arg_scope()): _, end_points = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') endpoint_to_shape = { 'resnet/block1': [2, 28, 28, 4], 'resnet/block2': [2, 14, 14, 8], @@ -305,7 +308,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(2, 321, 321, 3) with arg_scope(resnet_utils.resnet_arg_scope()): _, end_points = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') endpoint_to_shape = { 'resnet/block1': [2, 41, 41, 4], 'resnet/block2': [2, 21, 21, 8], @@ -324,7 +327,7 @@ class ResnetCompleteNetworkTest(test.TestCase): _, end_points = self._resnet_small( inputs, num_classes, - global_pool, + global_pool=global_pool, include_root_block=False, scope='resnet') endpoint_to_shape = { @@ -346,7 +349,7 @@ class ResnetCompleteNetworkTest(test.TestCase): _, end_points = self._resnet_small( inputs, num_classes, - global_pool, + global_pool=global_pool, output_stride=output_stride, scope='resnet') endpoint_to_shape = { @@ -363,14 +366,18 @@ class ResnetCompleteNetworkTest(test.TestCase): """Verify dense feature extraction with atrous convolution.""" nominal_stride = 32 for output_stride in [4, 8, 16, 32, None]: - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): + with arg_scope(resnet_utils.resnet_arg_scope()): with ops.Graph().as_default(): with self.test_session() as sess: random_seed.set_random_seed(0) inputs = create_test_input(2, 81, 81, 3) # Dense feature extraction followed by subsampling. 
output, _ = self._resnet_small( - inputs, None, global_pool=False, output_stride=output_stride) + inputs, + None, + is_training=False, + global_pool=False, + output_stride=output_stride) if output_stride is None: factor = 1 else: @@ -379,7 +386,8 @@ class ResnetCompleteNetworkTest(test.TestCase): # Make the two networks use the same weights. variable_scope.get_variable_scope().reuse_variables() # Feature extraction at the nominal network rate. - expected, _ = self._resnet_small(inputs, None, global_pool=False) + expected, _ = self._resnet_small( + inputs, None, is_training=False, global_pool=False) sess.run(variables.global_variables_initializer()) self.assertAllClose( output.eval(), expected.eval(), atol=1e-4, rtol=1e-4) @@ -392,7 +400,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(None, height, width, 3) with arg_scope(resnet_utils.resnet_arg_scope()): logits, _ = self._resnet_small( - inputs, num_classes, global_pool, scope='resnet') + inputs, num_classes, global_pool=global_pool, scope='resnet') self.assertTrue(logits.op.name.startswith('resnet/logits')) self.assertListEqual(logits.get_shape().as_list(), [None, 1, 1, num_classes]) @@ -408,7 +416,7 @@ class ResnetCompleteNetworkTest(test.TestCase): global_pool = False inputs = create_test_input(batch, None, None, 3) with arg_scope(resnet_utils.resnet_arg_scope()): - output, _ = self._resnet_small(inputs, None, global_pool) + output, _ = self._resnet_small(inputs, None, global_pool=global_pool) self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32]) images = create_test_input(batch, height, width, 3) with self.test_session() as sess: @@ -424,7 +432,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs = create_test_input(batch, None, None, 3) with arg_scope(resnet_utils.resnet_arg_scope()): output, _ = self._resnet_small( - inputs, None, global_pool, output_stride=output_stride) + inputs, None, global_pool=global_pool, output_stride=output_stride) self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32]) images = create_test_input(batch, height, width, 3) with self.test_session() as sess: |
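The override behavior the new resnet_is_training_test.py pins down comes from the bn_scope conditional added to resnet_v1() and resnet_v2(): an explicit is_training argument opens an inner arg_scope over layers.batch_norm that overrides whatever the (deprecated) resnet_arg_scope(is_training=...) set, while is_training=None leaves the inherited value untouched. A standalone sketch of that conditional, written as a hypothetical helper rather than the library code itself:

```python
from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.ops import arg_scope


def _batch_norm_scope(is_training=None):
  """Hypothetical helper mirroring the bn_scope logic added in this change."""
  if is_training is not None:
    # Explicit argument: pin batch_norm's mode, overriding any enclosing scope.
    return arg_scope([layers.batch_norm], is_training=is_training)
  # No argument: an empty arg_scope, so a value set by an enclosing
  # (deprecated) resnet_arg_scope(is_training=...) still applies.
  return arg_scope([])
```

This precedence is what the 'resnet6' and 'resnet8' cases in the new test assert: the is_training value passed to the network wins even when the surrounding resnet_arg_scope specifies the opposite mode.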