author    Jonathan Shen <jonathanasdf@google.com> 2017-04-25 14:03:45 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-04-25 15:24:32 -0700
commit    fa7f79541d1ba1fa6f54ea3474cbf3ce1411bfb4 (patch)
tree      76f02110b46ae420aed18f344ebfe548132cdb81 /tensorflow
parent    1d8ac37322983c71f02d0300766740172ad00485 (diff)
Deprecate the use of is_training in resnet_arg_scope. is_training should be
passed directly to the network instead.

Change: 154226768
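
A minimal before/after sketch of the migration this commit asks for (TF-Slim
under tf.contrib, as of this commit; inputs is assumed to be a
[batch, 224, 224, 3] float tensor):

  from tensorflow.contrib.framework.python.ops import arg_scope
  from tensorflow.contrib.slim.python.slim.nets import resnet_utils
  from tensorflow.contrib.slim.python.slim.nets import resnet_v1

  # Deprecated: is_training is baked into the arg_scope and silently
  # applied to every batch_norm in the network.
  with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
    net, end_points = resnet_v1.resnet_v1_101(inputs, 1000)

  # Preferred after this change: pass is_training to the network itself.
  with arg_scope(resnet_utils.resnet_arg_scope()):
    net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)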
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/slim/python/slim/nets/BUILD                      |  20
-rw-r--r--  tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py | 154
-rw-r--r--  tensorflow/contrib/slim/python/slim/nets/resnet_utils.py            |  10
-rw-r--r--  tensorflow/contrib/slim/python/slim/nets/resnet_v1.py               |  72
-rw-r--r--  tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py          |  78
-rw-r--r--  tensorflow/contrib/slim/python/slim/nets/resnet_v2.py               |  89
-rw-r--r--  tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py          |  78
7 files changed, 366 insertions, 135 deletions
diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD
index 8077818216..737bbbe57b 100644
--- a/tensorflow/contrib/slim/python/slim/nets/BUILD
+++ b/tensorflow/contrib/slim/python/slim/nets/BUILD
@@ -286,6 +286,26 @@ py_test(
],
)
+py_test(
+ name = "resnet_is_training_test",
+ size = "medium",
+ srcs = ["resnet_is_training_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":resnet_utils",
+ ":resnet_v1",
+ ":resnet_v2",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
py_library(
name = "vgg",
srcs = ["vgg.py"],
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py
new file mode 100644
index 0000000000..9a165577b6
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py
@@ -0,0 +1,154 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Specifying is_training in resnet_arg_scope is being deprecated.
+
+Test that everything behaves as expected in the meantime.
+
+Note: This test modifies the layers.batch_norm function.
+Other tests that use layers.batch_norm may not work if added to this file.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import layers
+from tensorflow.contrib.framework.python.ops import add_arg_scope
+from tensorflow.contrib.framework.python.ops import arg_scope
+from tensorflow.contrib.slim.python.slim.nets import resnet_utils
+from tensorflow.contrib.slim.python.slim.nets import resnet_v1
+from tensorflow.contrib.slim.python.slim.nets import resnet_v2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+def create_test_input(batch, height, width, channels):
+ """Create test input tensor."""
+ if None in [batch, height, width, channels]:
+ return array_ops.placeholder(dtypes.float32, (batch, height, width,
+ channels))
+ else:
+ return math_ops.to_float(
+ np.tile(
+ np.reshape(
+ np.reshape(np.arange(height), [height, 1]) +
+ np.reshape(np.arange(width), [1, width]),
+ [1, height, width, 1]),
+ [batch, 1, 1, channels]))
+
+
+class ResnetIsTrainingTest(test.TestCase):
+
+ def _testDeprecatingIsTraining(self, network_fn):
+ batch_norm_fn = layers.batch_norm
+
+ @add_arg_scope
+ def batch_norm_expect_is_training(*args, **kwargs):
+ assert kwargs['is_training']
+ return batch_norm_fn(*args, **kwargs)
+
+ @add_arg_scope
+ def batch_norm_expect_is_not_training(*args, **kwargs):
+ assert not kwargs['is_training']
+ return batch_norm_fn(*args, **kwargs)
+
+ global_pool = True
+ num_classes = 10
+ inputs = create_test_input(2, 224, 224, 3)
+
+ # Default argument for resnet_arg_scope
+ layers.batch_norm = batch_norm_expect_is_training
+ with arg_scope(resnet_utils.resnet_arg_scope()):
+ network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet1')
+
+ layers.batch_norm = batch_norm_expect_is_training
+ with arg_scope(resnet_utils.resnet_arg_scope()):
+ network_fn(
+ inputs,
+ num_classes,
+ is_training=True,
+ global_pool=global_pool,
+ scope='resnet2')
+
+ layers.batch_norm = batch_norm_expect_is_not_training
+ with arg_scope(resnet_utils.resnet_arg_scope()):
+ network_fn(
+ inputs,
+ num_classes,
+ is_training=False,
+ global_pool=global_pool,
+ scope='resnet3')
+
+ # resnet_arg_scope with is_training set to True (deprecated)
+ layers.batch_norm = batch_norm_expect_is_training
+ with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)):
+ network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet4')
+
+ layers.batch_norm = batch_norm_expect_is_training
+ with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)):
+ network_fn(
+ inputs,
+ num_classes,
+ is_training=True,
+ global_pool=global_pool,
+ scope='resnet5')
+
+ layers.batch_norm = batch_norm_expect_is_not_training
+ with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)):
+ network_fn(
+ inputs,
+ num_classes,
+ is_training=False,
+ global_pool=global_pool,
+ scope='resnet6')
+
+ # resnet_arg_scope with is_training set to False (deprecated)
+ layers.batch_norm = batch_norm_expect_is_not_training
+ with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
+ network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet7')
+
+ layers.batch_norm = batch_norm_expect_is_training
+ with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
+ network_fn(
+ inputs,
+ num_classes,
+ is_training=True,
+ global_pool=global_pool,
+ scope='resnet8')
+
+ layers.batch_norm = batch_norm_expect_is_not_training
+ with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
+ network_fn(
+ inputs,
+ num_classes,
+ is_training=False,
+ global_pool=global_pool,
+ scope='resnet9')
+
+ layers.batch_norm = batch_norm_fn
+
+ def testDeprecatingIsTrainingResnetV1(self):
+ self._testDeprecatingIsTraining(resnet_v1.resnet_v1_50)
+
+ def testDeprecatingIsTrainingResnetV2(self):
+ self._testDeprecatingIsTraining(resnet_v2.resnet_v2_50)
+
+
+if __name__ == '__main__':
+ test.main()
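
The test above works by temporarily monkey-patching layers.batch_norm with an
asserting wrapper (re-registered via add_arg_scope so arg_scope can still
target it) and restoring the original afterwards, which is why its docstring
warns against adding other batch_norm tests to the file. The core of that
pattern, sketched in isolation (build_network is a hypothetical stand-in for
any code path that calls batch_norm):

  from tensorflow.contrib import layers
  from tensorflow.contrib.framework.python.ops import add_arg_scope

  original_batch_norm = layers.batch_norm

  def asserting_batch_norm(expected):
    @add_arg_scope  # keep the wrapper usable as an arg_scope target
    def wrapper(*args, **kwargs):
      # Fails loudly if the network resolved is_training to the wrong value.
      assert kwargs['is_training'] == expected
      return original_batch_norm(*args, **kwargs)
    return wrapper

  layers.batch_norm = asserting_batch_norm(expected=False)
  try:
    build_network()  # hypothetical; exercises every batch_norm call
  finally:
    layers.batch_norm = original_batch_norm  # always undo the patch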
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py
index 89d27438e5..58614a998a 100644
--- a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py
@@ -41,6 +41,7 @@ from __future__ import print_function
import collections
from tensorflow.contrib import layers as layers_lib
+from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import arg_scope
from tensorflow.contrib.layers.python.layers import initializers
@@ -222,6 +223,10 @@ def stack_blocks_dense(net,
return net
+@deprecated_args(
+ '2017-08-01',
+ 'Pass is_training directly to the network instead of the arg_scope.',
+ 'is_training')
def resnet_arg_scope(is_training=True,
weight_decay=0.0001,
batch_norm_decay=0.997,
@@ -236,7 +241,7 @@ def resnet_arg_scope(is_training=True,
Args:
is_training: Whether or not we are training the parameters in the batch
- normalization layers of the model.
+ normalization layers of the model. (deprecated)
weight_decay: The weight decay to use for regularizing the model.
batch_norm_decay: The moving average decay when estimating layer activation
statistics in batch normalization.
@@ -261,8 +266,7 @@ def resnet_arg_scope(is_training=True,
weights_regularizer=regularizers.l2_regularizer(weight_decay),
weights_initializer=initializers.variance_scaling_initializer(),
activation_fn=nn_ops.relu,
- normalizer_fn=layers.batch_norm,
- normalizer_params=batch_norm_params):
+ normalizer_fn=layers.batch_norm):
with arg_scope([layers.batch_norm], **batch_norm_params):
# The following implies padding='SAME' for pool1, which makes feature
# alignment easier for dense prediction tasks. This is also used in
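
The deprecated_args decorator imported above (from
tensorflow.contrib.framework) wraps resnet_arg_scope so that callers who
still pass is_training get a deprecation warning until the stated removal
date. A rough standalone sketch of that warn-on-argument pattern, simplified
from the real decorator (which also detects positional use and deduplicates
warnings):

  import functools
  import warnings

  def deprecated_args(date, instructions, *deprecated_names):
    def decorator(func):
      @functools.wraps(func)
      def wrapper(*args, **kwargs):
        for name in deprecated_names:
          if name in kwargs:  # only keyword use is detected in this sketch
            warnings.warn(
                '%s: argument %s is deprecated and will be removed after %s. '
                '%s' % (func.__name__, name, date, instructions),
                DeprecationWarning)
        return func(*args, **kwargs)
      return wrapper
    return decorator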
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py
index fe13ce1b0e..90f93d46e3 100644
--- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py
@@ -40,15 +40,16 @@ Typical use:
ResNet-101 for image classification into 1000 classes:
# inputs has shape [batch, 224, 224, 3]
- with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training)):
- net, end_points = resnet_v1.resnet_v1_101(inputs, 1000)
+ with slim.arg_scope(resnet_v1.resnet_arg_scope()):
+ net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
ResNet-101 for semantic segmentation into 21 classes:
# inputs has shape [batch, 513, 513, 3]
- with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training)):
+ with slim.arg_scope(resnet_v1.resnet_arg_scope()):
net, end_points = resnet_v1.resnet_v1_101(inputs,
21,
+ is_training=False,
global_pool=False,
output_stride=16)
"""
@@ -127,6 +128,7 @@ def bottleneck(inputs,
def resnet_v1(inputs,
blocks,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
include_root_block=True,
@@ -161,6 +163,8 @@ def resnet_v1(inputs,
is a resnet_utils.Block object describing the units in the block.
num_classes: Number of predicted classes for classification tasks. If None
we return the features before the logit layer.
+ is_training: Whether batch_norm layers are in training mode. If None, the
+ value is inherited from an enclosing resnet_arg_scope; relying on that
+ default is deprecated.
global_pool: If True, we perform global average pooling before computing the
logits. Set to True for image classification, False for dense prediction.
output_stride: If None, then the output will be computed at the nominal
@@ -192,30 +196,36 @@ def resnet_v1(inputs,
with arg_scope(
[layers.conv2d, bottleneck, resnet_utils.stack_blocks_dense],
outputs_collections=end_points_collection):
- net = inputs
- if include_root_block:
- if output_stride is not None:
- if output_stride % 4 != 0:
- raise ValueError('The output_stride needs to be a multiple of 4.')
- output_stride /= 4
- net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
- net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope='pool1')
- net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
- if global_pool:
- # Global average pooling.
- net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
- if num_classes is not None:
- net = layers.conv2d(
- net,
- num_classes, [1, 1],
- activation_fn=None,
- normalizer_fn=None,
- scope='logits')
- # Convert end_points_collection into a dictionary of end_points.
- end_points = utils.convert_collection_to_dict(end_points_collection)
- if num_classes is not None:
- end_points['predictions'] = layers_lib.softmax(net, scope='predictions')
- return net, end_points
+ if is_training is not None:
+ bn_scope = arg_scope([layers.batch_norm], is_training=is_training)
+ else:
+ bn_scope = arg_scope([])
+ with bn_scope:
+ net = inputs
+ if include_root_block:
+ if output_stride is not None:
+ if output_stride % 4 != 0:
+ raise ValueError('The output_stride needs to be a multiple of 4.')
+ output_stride /= 4
+ net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
+ net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope='pool1')
+ net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
+ if global_pool:
+ # Global average pooling.
+ net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
+ if num_classes is not None:
+ net = layers.conv2d(
+ net,
+ num_classes, [1, 1],
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='logits')
+ # Convert end_points_collection into a dictionary of end_points.
+ end_points = utils.convert_collection_to_dict(end_points_collection)
+ if num_classes is not None:
+ end_points['predictions'] = layers_lib.softmax(
+ net, scope='predictions')
+ return net, end_points
resnet_v1.default_image_size = 224
@@ -245,6 +255,7 @@ def resnet_v1_block(scope, base_depth, num_units, stride):
def resnet_v1_50(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
@@ -260,6 +271,7 @@ def resnet_v1_50(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
@@ -269,6 +281,7 @@ def resnet_v1_50(inputs,
def resnet_v1_101(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
@@ -284,6 +297,7 @@ def resnet_v1_101(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
@@ -293,6 +307,7 @@ def resnet_v1_101(inputs,
def resnet_v1_152(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
@@ -308,6 +323,7 @@ def resnet_v1_152(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
@@ -317,6 +333,7 @@ def resnet_v1_152(inputs,
def resnet_v1_200(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
@@ -332,6 +349,7 @@ def resnet_v1_200(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
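
The bn_scope idiom introduced in resnet_v1 above is worth a note: arg_scope([])
is a no-op context manager, so when is_training is None the (possibly
deprecated) value from an enclosing resnet_arg_scope still applies, while an
explicit is_training overrides it for every batch_norm in the network. The
same pattern in isolation (build_net is a hypothetical stand-in):

  from tensorflow.contrib import layers
  from tensorflow.contrib.framework.python.ops import arg_scope

  def run_with_optional_override(build_net, is_training=None):
    if is_training is not None:
      # Explicit value: override batch_norm's is_training everywhere inside.
      bn_scope = arg_scope([layers.batch_norm], is_training=is_training)
    else:
      # No value given: arg_scope([]) changes nothing, so whatever an
      # enclosing arg_scope set (the deprecated path) still wins.
      bn_scope = arg_scope([])
    with bn_scope:
      return build_net()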
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py
index dffd29f920..d510337fef 100644
--- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py
@@ -219,28 +219,29 @@ class ResnetUtilsTest(test.TestCase):
# Test both odd and even input dimensions.
height = 30
width = 31
- with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
- for output_stride in [1, 2, 4, 8, None]:
- with ops.Graph().as_default():
- with self.test_session() as sess:
- random_seed.set_random_seed(0)
- inputs = create_test_input(1, height, width, 3)
- # Dense feature extraction followed by subsampling.
- output = resnet_utils.stack_blocks_dense(inputs, blocks,
- output_stride)
- if output_stride is None:
- factor = 1
- else:
- factor = nominal_stride // output_stride
-
- output = resnet_utils.subsample(output, factor)
- # Make the two networks use the same weights.
- variable_scope.get_variable_scope().reuse_variables()
- # Feature extraction at the nominal network rate.
- expected = self._stack_blocks_nondense(inputs, blocks)
- sess.run(variables.global_variables_initializer())
- output, expected = sess.run([output, expected])
- self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
+ with arg_scope(resnet_utils.resnet_arg_scope()):
+ with arg_scope([layers.batch_norm], is_training=False):
+ for output_stride in [1, 2, 4, 8, None]:
+ with ops.Graph().as_default():
+ with self.test_session() as sess:
+ random_seed.set_random_seed(0)
+ inputs = create_test_input(1, height, width, 3)
+ # Dense feature extraction followed by subsampling.
+ output = resnet_utils.stack_blocks_dense(inputs, blocks,
+ output_stride)
+ if output_stride is None:
+ factor = 1
+ else:
+ factor = nominal_stride // output_stride
+
+ output = resnet_utils.subsample(output, factor)
+ # Make the two networks use the same weights.
+ variable_scope.get_variable_scope().reuse_variables()
+ # Feature extraction at the nominal network rate.
+ expected = self._stack_blocks_nondense(inputs, blocks)
+ sess.run(variables.global_variables_initializer())
+ output, expected = sess.run([output, expected])
+ self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
class ResnetCompleteNetworkTest(test.TestCase):
@@ -249,6 +250,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
def _resnet_small(self,
inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
include_root_block=True,
@@ -262,8 +264,9 @@ class ResnetCompleteNetworkTest(test.TestCase):
block('block3', base_depth=4, num_units=3, stride=2),
block('block4', base_depth=8, num_units=2, stride=1),
]
- return resnet_v1.resnet_v1(inputs, blocks, num_classes, global_pool,
- output_stride, include_root_block, reuse, scope)
+ return resnet_v1.resnet_v1(inputs, blocks, num_classes, is_training,
+ global_pool, output_stride, include_root_block,
+ reuse, scope)
def testClassificationEndPoints(self):
global_pool = True
@@ -271,7 +274,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(2, 224, 224, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
logits, end_points = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
self.assertTrue('predictions' in end_points)
@@ -284,7 +287,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(2, 224, 224, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 28, 28, 4],
'resnet/block2': [2, 14, 14, 8],
@@ -301,7 +304,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(2, 321, 321, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 41, 41, 4],
'resnet/block2': [2, 21, 21, 8],
@@ -320,7 +323,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
_, end_points = self._resnet_small(
inputs,
num_classes,
- global_pool,
+ global_pool=global_pool,
include_root_block=False,
scope='resnet')
endpoint_to_shape = {
@@ -342,7 +345,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
_, end_points = self._resnet_small(
inputs,
num_classes,
- global_pool,
+ global_pool=global_pool,
output_stride=output_stride,
scope='resnet')
endpoint_to_shape = {
@@ -359,14 +362,18 @@ class ResnetCompleteNetworkTest(test.TestCase):
"""Verify dense feature extraction with atrous convolution."""
nominal_stride = 32
for output_stride in [4, 8, 16, 32, None]:
- with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
+ with arg_scope(resnet_utils.resnet_arg_scope()):
with ops.Graph().as_default():
with self.test_session() as sess:
random_seed.set_random_seed(0)
inputs = create_test_input(2, 81, 81, 3)
# Dense feature extraction followed by subsampling.
output, _ = self._resnet_small(
- inputs, None, global_pool=False, output_stride=output_stride)
+ inputs,
+ None,
+ is_training=False,
+ global_pool=False,
+ output_stride=output_stride)
if output_stride is None:
factor = 1
else:
@@ -375,7 +382,8 @@ class ResnetCompleteNetworkTest(test.TestCase):
# Make the two networks use the same weights.
variable_scope.get_variable_scope().reuse_variables()
# Feature extraction at the nominal network rate.
- expected, _ = self._resnet_small(inputs, None, global_pool=False)
+ expected, _ = self._resnet_small(
+ inputs, None, is_training=False, global_pool=False)
sess.run(variables.global_variables_initializer())
self.assertAllClose(
output.eval(), expected.eval(), atol=1e-4, rtol=1e-4)
@@ -388,7 +396,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(None, height, width, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
logits, _ = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(),
[None, 1, 1, num_classes])
@@ -404,7 +412,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
global_pool = False
inputs = create_test_input(batch, None, None, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
- output, _ = self._resnet_small(inputs, None, global_pool)
+ output, _ = self._resnet_small(inputs, None, global_pool=global_pool)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
@@ -420,7 +428,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(batch, None, None, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
output, _ = self._resnet_small(
- inputs, None, global_pool, output_stride=output_stride)
+ inputs, None, global_pool=global_pool, output_stride=output_stride)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
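
A pattern worth noting in the test updates above: call sites switch
global_pool from positional to keyword form. Because is_training was inserted
into the signatures ahead of global_pool, an unchanged positional call would
now silently bind its value to is_training instead. A toy illustration
(resnet_fn is hypothetical):

  def resnet_fn(inputs, num_classes=None, is_training=None, global_pool=True):
    return num_classes, is_training, global_pool

  resnet_fn('x', 10, True)              # True lands in is_training, not global_pool
  resnet_fn('x', 10, global_pool=True)  # keyword form survives the insertion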
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py
index 7e6fe5dfc2..f260ede348 100644
--- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py
@@ -36,15 +36,16 @@ Typical use:
ResNet-101 for image classification into 1000 classes:
# inputs has shape [batch, 224, 224, 3]
- with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training)):
- net, end_points = resnet_v2.resnet_v2_101(inputs, 1000)
+ with slim.arg_scope(resnet_v2.resnet_arg_scope()):
+ net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)
ResNet-101 for semantic segmentation into 21 classes:
# inputs has shape [batch, 513, 513, 3]
- with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training)):
+ with slim.arg_scope(resnet_v2.resnet_arg_scope()):
net, end_points = resnet_v2.resnet_v2_101(inputs,
21,
+ is_training=False,
global_pool=False,
output_stride=16)
"""
@@ -129,6 +130,7 @@ def bottleneck(inputs,
def resnet_v2(inputs,
blocks,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
include_root_block=True,
@@ -163,6 +165,8 @@ def resnet_v2(inputs,
is a resnet_utils.Block object describing the units in the block.
num_classes: Number of predicted classes for classification tasks. If None
we return the features before the logit layer.
+ is_training: Whether batch_norm layers are in training mode. If None, the
+ value is inherited from an enclosing resnet_arg_scope; relying on that
+ default is deprecated.
global_pool: If True, we perform global average pooling before computing the
logits. Set to True for image classification, False for dense prediction.
output_stride: If None, then the output will be computed at the nominal
@@ -196,38 +200,45 @@ def resnet_v2(inputs,
with arg_scope(
[layers_lib.conv2d, bottleneck, resnet_utils.stack_blocks_dense],
outputs_collections=end_points_collection):
- net = inputs
- if include_root_block:
- if output_stride is not None:
- if output_stride % 4 != 0:
- raise ValueError('The output_stride needs to be a multiple of 4.')
- output_stride /= 4
- # We do not include batch normalization or activation functions in conv1
- # because the first ResNet unit will perform these. Cf. Appendix of [2].
- with arg_scope(
- [layers_lib.conv2d], activation_fn=None, normalizer_fn=None):
- net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
- net = layers.max_pool2d(net, [3, 3], stride=2, scope='pool1')
- net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
- # This is needed because the pre-activation variant does not have batch
- # normalization or activation functions in the residual unit output. See
- # Appendix of [2].
- net = layers.batch_norm(net, activation_fn=nn_ops.relu, scope='postnorm')
- if global_pool:
- # Global average pooling.
- net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
- if num_classes is not None:
- net = layers_lib.conv2d(
- net,
- num_classes, [1, 1],
- activation_fn=None,
- normalizer_fn=None,
- scope='logits')
- # Convert end_points_collection into a dictionary of end_points.
- end_points = utils.convert_collection_to_dict(end_points_collection)
- if num_classes is not None:
- end_points['predictions'] = layers.softmax(net, scope='predictions')
- return net, end_points
+ if is_training is not None:
+ bn_scope = arg_scope([layers.batch_norm], is_training=is_training)
+ else:
+ bn_scope = arg_scope([])
+ with bn_scope:
+ net = inputs
+ if include_root_block:
+ if output_stride is not None:
+ if output_stride % 4 != 0:
+ raise ValueError('The output_stride needs to be a multiple of 4.')
+ output_stride /= 4
+ # We do not include batch normalization or activation functions in
+ # conv1 because the first ResNet unit will perform these. Cf.
+ # Appendix of [2].
+ with arg_scope(
+ [layers_lib.conv2d], activation_fn=None, normalizer_fn=None):
+ net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
+ net = layers.max_pool2d(net, [3, 3], stride=2, scope='pool1')
+ net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
+ # This is needed because the pre-activation variant does not have batch
+ # normalization or activation functions in the residual unit output. See
+ # Appendix of [2].
+ net = layers.batch_norm(
+ net, activation_fn=nn_ops.relu, scope='postnorm')
+ if global_pool:
+ # Global average pooling.
+ net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
+ if num_classes is not None:
+ net = layers_lib.conv2d(
+ net,
+ num_classes, [1, 1],
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='logits')
+ # Convert end_points_collection into a dictionary of end_points.
+ end_points = utils.convert_collection_to_dict(end_points_collection)
+ if num_classes is not None:
+ end_points['predictions'] = layers.softmax(net, scope='predictions')
+ return net, end_points
resnet_v2.default_image_size = 224
@@ -257,6 +268,7 @@ def resnet_v2_block(scope, base_depth, num_units, stride):
def resnet_v2_50(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
@@ -272,6 +284,7 @@ def resnet_v2_50(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
@@ -282,6 +295,7 @@ def resnet_v2_50(inputs,
def resnet_v2_101(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
scope='resnet_v2_101'):
@@ -296,6 +310,7 @@ def resnet_v2_101(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
@@ -305,6 +320,7 @@ def resnet_v2_101(inputs,
def resnet_v2_152(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
@@ -320,6 +336,7 @@ def resnet_v2_152(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
@@ -329,6 +346,7 @@ def resnet_v2_152(inputs,
def resnet_v2_200(inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
reuse=None,
@@ -344,6 +362,7 @@ def resnet_v2_200(inputs,
inputs,
blocks,
num_classes,
+ is_training,
global_pool,
output_stride,
include_root_block=True,
diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py
index 1c09bcbb5a..c4f3b071fd 100644
--- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py
+++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py
@@ -223,28 +223,29 @@ class ResnetUtilsTest(test.TestCase):
# Test both odd and even input dimensions.
height = 30
width = 31
- with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
- for output_stride in [1, 2, 4, 8, None]:
- with ops.Graph().as_default():
- with self.test_session() as sess:
- random_seed.set_random_seed(0)
- inputs = create_test_input(1, height, width, 3)
- # Dense feature extraction followed by subsampling.
- output = resnet_utils.stack_blocks_dense(inputs, blocks,
- output_stride)
- if output_stride is None:
- factor = 1
- else:
- factor = nominal_stride // output_stride
-
- output = resnet_utils.subsample(output, factor)
- # Make the two networks use the same weights.
- variable_scope.get_variable_scope().reuse_variables()
- # Feature extraction at the nominal network rate.
- expected = self._stack_blocks_nondense(inputs, blocks)
- sess.run(variables.global_variables_initializer())
- output, expected = sess.run([output, expected])
- self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
+ with arg_scope(resnet_utils.resnet_arg_scope()):
+ with arg_scope([layers.batch_norm], is_training=False):
+ for output_stride in [1, 2, 4, 8, None]:
+ with ops.Graph().as_default():
+ with self.test_session() as sess:
+ random_seed.set_random_seed(0)
+ inputs = create_test_input(1, height, width, 3)
+ # Dense feature extraction followed by subsampling.
+ output = resnet_utils.stack_blocks_dense(inputs, blocks,
+ output_stride)
+ if output_stride is None:
+ factor = 1
+ else:
+ factor = nominal_stride // output_stride
+
+ output = resnet_utils.subsample(output, factor)
+ # Make the two networks use the same weights.
+ variable_scope.get_variable_scope().reuse_variables()
+ # Feature extraction at the nominal network rate.
+ expected = self._stack_blocks_nondense(inputs, blocks)
+ sess.run(variables.global_variables_initializer())
+ output, expected = sess.run([output, expected])
+ self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
class ResnetCompleteNetworkTest(test.TestCase):
@@ -253,6 +254,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
def _resnet_small(self,
inputs,
num_classes=None,
+ is_training=None,
global_pool=True,
output_stride=None,
include_root_block=True,
@@ -266,8 +268,9 @@ class ResnetCompleteNetworkTest(test.TestCase):
block('block3', base_depth=4, num_units=3, stride=2),
block('block4', base_depth=8, num_units=2, stride=1),
]
- return resnet_v2.resnet_v2(inputs, blocks, num_classes, global_pool,
- output_stride, include_root_block, reuse, scope)
+ return resnet_v2.resnet_v2(inputs, blocks, num_classes, is_training,
+ global_pool, output_stride, include_root_block,
+ reuse, scope)
def testClassificationEndPoints(self):
global_pool = True
@@ -275,7 +278,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(2, 224, 224, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
logits, end_points = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
self.assertTrue('predictions' in end_points)
@@ -288,7 +291,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(2, 224, 224, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 28, 28, 4],
'resnet/block2': [2, 14, 14, 8],
@@ -305,7 +308,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(2, 321, 321, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
_, end_points = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
endpoint_to_shape = {
'resnet/block1': [2, 41, 41, 4],
'resnet/block2': [2, 21, 21, 8],
@@ -324,7 +327,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
_, end_points = self._resnet_small(
inputs,
num_classes,
- global_pool,
+ global_pool=global_pool,
include_root_block=False,
scope='resnet')
endpoint_to_shape = {
@@ -346,7 +349,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
_, end_points = self._resnet_small(
inputs,
num_classes,
- global_pool,
+ global_pool=global_pool,
output_stride=output_stride,
scope='resnet')
endpoint_to_shape = {
@@ -363,14 +366,18 @@ class ResnetCompleteNetworkTest(test.TestCase):
"""Verify dense feature extraction with atrous convolution."""
nominal_stride = 32
for output_stride in [4, 8, 16, 32, None]:
- with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)):
+ with arg_scope(resnet_utils.resnet_arg_scope()):
with ops.Graph().as_default():
with self.test_session() as sess:
random_seed.set_random_seed(0)
inputs = create_test_input(2, 81, 81, 3)
# Dense feature extraction followed by subsampling.
output, _ = self._resnet_small(
- inputs, None, global_pool=False, output_stride=output_stride)
+ inputs,
+ None,
+ is_training=False,
+ global_pool=False,
+ output_stride=output_stride)
if output_stride is None:
factor = 1
else:
@@ -379,7 +386,8 @@ class ResnetCompleteNetworkTest(test.TestCase):
# Make the two networks use the same weights.
variable_scope.get_variable_scope().reuse_variables()
# Feature extraction at the nominal network rate.
- expected, _ = self._resnet_small(inputs, None, global_pool=False)
+ expected, _ = self._resnet_small(
+ inputs, None, is_training=False, global_pool=False)
sess.run(variables.global_variables_initializer())
self.assertAllClose(
output.eval(), expected.eval(), atol=1e-4, rtol=1e-4)
@@ -392,7 +400,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(None, height, width, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
logits, _ = self._resnet_small(
- inputs, num_classes, global_pool, scope='resnet')
+ inputs, num_classes, global_pool=global_pool, scope='resnet')
self.assertTrue(logits.op.name.startswith('resnet/logits'))
self.assertListEqual(logits.get_shape().as_list(),
[None, 1, 1, num_classes])
@@ -408,7 +416,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
global_pool = False
inputs = create_test_input(batch, None, None, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
- output, _ = self._resnet_small(inputs, None, global_pool)
+ output, _ = self._resnet_small(inputs, None, global_pool=global_pool)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess:
@@ -424,7 +432,7 @@ class ResnetCompleteNetworkTest(test.TestCase):
inputs = create_test_input(batch, None, None, 3)
with arg_scope(resnet_utils.resnet_arg_scope()):
output, _ = self._resnet_small(
- inputs, None, global_pool, output_stride=output_stride)
+ inputs, None, global_pool=global_pool, output_stride=output_stride)
self.assertListEqual(output.get_shape().as_list(), [batch, None, None, 32])
images = create_test_input(batch, height, width, 3)
with self.test_session() as sess: