aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sergio Guadarrama <sguada@google.com>2016-07-19 16:14:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-19 17:18:36 -0700
commit15afaf7b2ed6aa77b62dde0007afffaffc6a051d (patch)
treedaded573d0cee9df2fdb20a64c41ba98416debbb
parent496f3a48d900b84006e8dba697be6252328cbf98 (diff)
Add common nets to tf.slim
Change: 127893189
-rw-r--r--tensorflow/contrib/slim/BUILD7
-rw-r--r--tensorflow/contrib/slim/__init__.py3
-rw-r--r--tensorflow/contrib/slim/python/slim/model_analyzer.py104
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/BUILD70
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/__init__.py23
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/alexnet.py119
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/alexnet_test.py144
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/overfeat.py112
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/overfeat_test.py145
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/vgg.py159
-rw-r--r--tensorflow/contrib/slim/python/slim/nets/vgg_test.py300
11 files changed, 1185 insertions, 1 deletions
diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD
index 96dbe98158..7c4a1d036d 100644
--- a/tensorflow/contrib/slim/BUILD
+++ b/tensorflow/contrib/slim/BUILD
@@ -55,6 +55,12 @@ py_test(
)
py_library(
+ name = "model_analyzer",
+ srcs = ["python/slim/model_analyzer.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
name = "queues",
srcs = ["python/slim/queues.py"],
srcs_version = "PY2AND3",
@@ -74,6 +80,7 @@ py_library(
deps = [
":evaluation",
":learning",
+ ":model_analyzer",
":queues",
"//tensorflow/contrib/slim/python/slim/data",
],
diff --git a/tensorflow/contrib/slim/__init__.py b/tensorflow/contrib/slim/__init__.py
index 7f7bc8302b..30e69456e4 100644
--- a/tensorflow/contrib/slim/__init__.py
+++ b/tensorflow/contrib/slim/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@ from tensorflow.contrib.layers.python.layers.initializers import *
from tensorflow.contrib.layers.python.layers.regularizers import *
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.slim.python.slim import learning
+from tensorflow.contrib.slim.python.slim import model_analyzer
from tensorflow.contrib.slim.python.slim import queues
from tensorflow.contrib.slim.python.slim.data import data_decoder
from tensorflow.contrib.slim.python.slim.data import data_provider
diff --git a/tensorflow/contrib/slim/python/slim/model_analyzer.py b/tensorflow/contrib/slim/python/slim/model_analyzer.py
new file mode 100644
index 0000000000..73659be4a7
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/model_analyzer.py
@@ -0,0 +1,104 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tools for analyzing the operations and variables in a TensorFlow graph.
+
+To analyze the operations in a graph:
+
+ images, labels = LoadData(...)
+ predictions = MyModel(images)
+
+ slim.model_analyzer.analyze_ops(tf.get_default_graph(), print_info=True)
+
+To analyze the model variables in a graph:
+
+ variables = tf.model_variables()
+ slim.model_analyzer.analyze_vars(variables, print_info=False)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def tensor_description(var):
+ """Returns a compact and informative string about a tensor.
+
+ Args:
+ var: A tensor variable.
+
+ Returns:
+ a string with type and size, e.g.: (float32 1x8x8x1024).
+ """
+ description = '(' + str(var.dtype.name) + ' '
+ sizes = var.get_shape()
+ for i, size in enumerate(sizes):
+ description += str(size)
+ if i < len(sizes) - 1:
+ description += 'x'
+ description += ')'
+ return description
+
+
+def analyze_ops(graph, print_info=False):
+ """Compute the estimated size of the ops.outputs in the graph.
+
+ Args:
+ graph: the graph containing the operations.
+ print_info: Optional, if true print ops and their outputs.
+
+ Returns:
+ total size of the ops.outputs
+ """
+ if print_info:
+ print('---------')
+ print('Operations: name -> (type shapes) [size]')
+ print('---------')
+ total_size = 0
+ for op in graph.get_operations():
+ op_size = 0
+ shapes = []
+ for output in op.outputs:
+ # if output.num_elements() is None or [] assume size 0.
+ output_size = output.get_shape().num_elements() or 0
+ if output.get_shape():
+ shapes.append(tensor_description(output))
+ op_size += output_size
+ if print_info:
+ print(op.name, '\t->', ', '.join(shapes), '[' + str(op_size) + ']')
+ total_size += op_size
+ return total_size
+
+
+def analyze_vars(variables, print_info=False):
+ """Prints the names and shapes of the variables.
+
+ Args:
+ variables: list of variables, for example tf.all_variables().
+ print_info: Optional, if true print variables and their shape.
+
+ Returns:
+ total size of the variables.
+ """
+ if print_info:
+ print('---------')
+ print('Variables: name (type shape) [size]')
+ print('---------')
+ total_size = 0
+ for var in variables:
+ # if var.num_elements() is None or [] assume size 0.
+ var_size = var.get_shape().num_elements() or 0
+ total_size += var_size
+ if print_info:
+ print(var.name, tensor_description(var), '[' + str(var_size) + ']')
+ return total_size
diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD
new file mode 100644
index 0000000000..7a70de6db6
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/BUILD
@@ -0,0 +1,70 @@
+# Description:
+# Contains typical networks definitions.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+py_library(
+ name = "nets",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":alexnet",
+ ":overfeat",
+ ":vgg",
+ ],
+)
+
+py_library(
+ name = "alexnet",
+ srcs = ["alexnet.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "alexnet_test",
+ size = "medium",
+ srcs = ["alexnet_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":alexnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "overfeat",
+ srcs = ["overfeat.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "overfeat_test",
+ size = "medium",
+ srcs = ["overfeat_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":overfeat",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "vgg",
+ srcs = ["vgg.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "vgg_test",
+ size = "medium",
+ srcs = ["vgg_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":vgg",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/slim/python/slim/nets/__init__.py b/tensorflow/contrib/slim/python/slim/nets/__init__.py
new file mode 100644
index 0000000000..f04f805b00
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TF-Slim Nets."""
+# pylint: disable=unused-import,
+# pylint: disable=wildcard-import
+
+# Collapse nets into a single namespace.
+from tensorflow.contrib.slim.python.slim.nets import alexnet
+from tensorflow.contrib.slim.python.slim.nets import overfeat
+from tensorflow.contrib.slim.python.slim.nets import vgg
+# pylint: enable=unused-import,wildcard-import
diff --git a/tensorflow/contrib/slim/python/slim/nets/alexnet.py b/tensorflow/contrib/slim/python/slim/nets/alexnet.py
new file mode 100644
index 0000000000..14497cd9a1
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/alexnet.py
@@ -0,0 +1,119 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains a model definition for AlexNet.
+
+This work was first described in:
+ ImageNet Classification with Deep Convolutional Neural Networks
+ Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
+
+and later refined in:
+ One weird trick for parallelizing convolutional neural networks
+ Alex Krizhevsky, 2014
+
+Here we provide the implementation proposed in "One weird trick" and not
+"ImageNet Classification", as per the paper, the LRN layers have been removed.
+
+Usage:
+ with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
+ outputs, end_points = alexnet.alexnet_v2(inputs)
+
+"""
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
+
+
+def alexnet_v2_arg_scope(weight_decay=0.0005):
+ with slim.arg_scope([slim.conv2d, slim.fully_connected],
+ activation_fn=tf.nn.relu,
+ biases_initializer=tf.constant_initializer(0.1),
+ weights_regularizer=slim.l2_regularizer(weight_decay)):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+ with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
+ return arg_sc
+
+
+def alexnet_v2(inputs,
+ num_classes=1000,
+ dropout_keep_prob=0.5,
+ is_training=True,
+ spatial_squeeze=True,
+ scope='alexnet_v2'):
+ """AlexNet version 2.
+
+ Described in: http://arxiv.org/pdf/1404.5997v2.pdf
+ Parameters from:
+ github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
+ layers-imagenet-1gpu.cfg
+
+ Note: All the fully_connected layers have been transformed to conv2d layers.
+ To use in classification mode, resize input to 224x224. To use in fully
+ convolutional mode, set spatial_squeeze to false.
+ The LRN layers have been removed and change the initializers from
+ random_normal_initializer to xavier_initializer.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels].
+ num_classes: number of predicted classes.
+ dropout_keep_prob: the probability that activations are kept in the dropout
+ layers during training.
+ is_training: whether or not the model is being trained.
+ spatial_squeeze: whether or not should squeeze the spatial dimensions of the
+ outputs. Useful to remove unnecessary dimensions for classification.
+ scope: Optional scope for the variables.
+
+ Returns:
+ the last op containing the log predictions and end_points dict.
+ """
+ with tf.variable_op_scope([inputs], scope, 'alexnet_v2') as sc:
+ end_points_collection = sc.name + '_end_points'
+ # Collect outputs for conv2d, fully_connected and max_pool2d.
+ with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+ outputs_collections=[end_points_collection]):
+ net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
+ scope='conv1')
+ net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
+ net = slim.conv2d(net, 192, [5, 5], scope='conv2')
+ net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
+ net = slim.conv2d(net, 384, [3, 3], scope='conv3')
+ net = slim.conv2d(net, 384, [3, 3], scope='conv4')
+ net = slim.conv2d(net, 256, [3, 3], scope='conv5')
+ net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')
+
+ # Use conv2d instead of fully_connected layers.
+ with slim.arg_scope([slim.conv2d],
+ weights_initializer=trunc_normal(0.005),
+ biases_initializer=tf.constant_initializer(0.1)):
+ net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
+ scope='fc6')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout6')
+ net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout7')
+ net = slim.conv2d(net, num_classes, [1, 1],
+ activation_fn=None,
+ normalizer_fn=None,
+ biases_initializer=tf.zeros_initializer,
+ scope='fc8')
+
+ # Convert end_points_collection into a end_point dict.
+ end_points = dict(tf.get_collection(end_points_collection))
+ if spatial_squeeze:
+ net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+ end_points[sc.name + '/fc8'] = net
+ return net, end_points
diff --git a/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py b/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py
new file mode 100644
index 0000000000..346ffa09f2
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py
@@ -0,0 +1,144 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.alexnet."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim.python.slim.nets import alexnet
+slim = tf.contrib.slim
+
+
+class AlexnetV2Test(tf.test.TestCase):
+
+ def testBuild(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = alexnet.alexnet_v2(inputs, num_classes)
+ self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+
+ def testFullyConvolutional(self):
+ batch_size = 1
+ height, width = 300, 400
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False)
+ self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, 4, 7, num_classes])
+
+ def testEndPoints(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ _, end_points = alexnet.alexnet_v2(inputs, num_classes)
+ expected_names = ['alexnet_v2/conv1',
+ 'alexnet_v2/pool1',
+ 'alexnet_v2/conv2',
+ 'alexnet_v2/pool2',
+ 'alexnet_v2/conv3',
+ 'alexnet_v2/conv4',
+ 'alexnet_v2/conv5',
+ 'alexnet_v2/pool5',
+ 'alexnet_v2/fc6',
+ 'alexnet_v2/fc7',
+ 'alexnet_v2/fc8'
+ ]
+ self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+ def testModelVariables(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ alexnet.alexnet_v2(inputs, num_classes)
+ expected_names = ['alexnet_v2/conv1/weights',
+ 'alexnet_v2/conv1/biases',
+ 'alexnet_v2/conv2/weights',
+ 'alexnet_v2/conv2/biases',
+ 'alexnet_v2/conv3/weights',
+ 'alexnet_v2/conv3/biases',
+ 'alexnet_v2/conv4/weights',
+ 'alexnet_v2/conv4/biases',
+ 'alexnet_v2/conv5/weights',
+ 'alexnet_v2/conv5/biases',
+ 'alexnet_v2/fc6/weights',
+ 'alexnet_v2/fc6/biases',
+ 'alexnet_v2/fc7/weights',
+ 'alexnet_v2/fc7/biases',
+ 'alexnet_v2/fc8/weights',
+ 'alexnet_v2/fc8/biases',
+ ]
+ model_variables = [v.op.name for v in slim.get_model_variables()]
+ self.assertSetEqual(set(model_variables), set(expected_names))
+
+ def testEvaluation(self):
+ batch_size = 2
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+ predictions = tf.argmax(logits, 1)
+ self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+ def testTrainEvalWithReuse(self):
+ train_batch_size = 2
+ eval_batch_size = 1
+ train_height, train_width = 224, 224
+ eval_height, eval_width = 300, 400
+ num_classes = 1000
+ with self.test_session():
+ train_inputs = tf.random_uniform(
+ (train_batch_size, train_height, train_width, 3))
+ logits, _ = alexnet.alexnet_v2(train_inputs)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [train_batch_size, num_classes])
+ tf.get_variable_scope().reuse_variables()
+ eval_inputs = tf.random_uniform(
+ (eval_batch_size, eval_height, eval_width, 3))
+ logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False,
+ spatial_squeeze=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [eval_batch_size, 4, 7, num_classes])
+ logits = tf.reduce_mean(logits, [1, 2])
+ predictions = tf.argmax(logits, 1)
+ self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
+
+ def testForward(self):
+ batch_size = 1
+ height, width = 224, 224
+ with self.test_session() as sess:
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = alexnet.alexnet_v2(inputs)
+ sess.run(tf.initialize_all_variables())
+ output = sess.run(logits)
+ self.assertTrue(output.any())
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/slim/python/slim/nets/overfeat.py b/tensorflow/contrib/slim/python/slim/nets/overfeat.py
new file mode 100644
index 0000000000..a0c2033934
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/overfeat.py
@@ -0,0 +1,112 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the model definition for the OverFeat network.
+
+The definition for the network was obtained from:
+ OverFeat: Integrated Recognition, Localization and Detection using
+ Convolutional Networks
+ Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
+ Yann LeCun, 2014
+ http://arxiv.org/abs/1312.6229
+
+Usage:
+ with slim.arg_scope(overfeat.overfeat_arg_scope()):
+ outputs, end_points = overfeat.overfeat(inputs)
+"""
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
+
+
+def overfeat_arg_scope(weight_decay=0.0005):
+ with slim.arg_scope([slim.conv2d, slim.fully_connected],
+ activation_fn=tf.nn.relu,
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ biases_initializer=tf.zeros_initializer):
+ with slim.arg_scope([slim.conv2d], padding='SAME'):
+ with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
+ return arg_sc
+
+
+def overfeat(inputs,
+ num_classes=1000,
+ dropout_keep_prob=0.5,
+ is_training=True,
+ spatial_squeeze=True,
+ scope='overfeat'):
+ """Contains the model definition for the OverFeat network.
+
+ The definition for the network was obtained from:
+ OverFeat: Integrated Recognition, Localization and Detection using
+ Convolutional Networks
+ Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
+ Yann LeCun, 2014
+ http://arxiv.org/abs/1312.6229
+
+ Note: All the fully_connected layers have been transformed to conv2d layers.
+ To use in classification mode, resize input to 231x231. To use in fully
+ convolutional mode, set spatial_squeeze to false.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels].
+ num_classes: number of predicted classes.
+ dropout_keep_prob: the probability that activations are kept in the dropout
+ layers during training.
+ is_training: whether or not the model is being trained.
+ spatial_squeeze: whether or not should squeeze the spatial dimensions of the
+ outputs. Useful to remove unnecessary dimensions for classification.
+ scope: Optional scope for the variables.
+
+ Returns:
+ the last op containing the log predictions and end_points dict.
+
+ """
+ with tf.variable_op_scope([inputs], scope, 'overfeat') as sc:
+ end_points_collection = sc.name + '_end_points'
+ # Collect outputs for conv2d, fully_connected and max_pool2d
+ with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+ outputs_collections=end_points_collection):
+ net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
+ scope='conv1')
+ net = slim.max_pool2d(net, [2, 2], scope='pool1')
+ net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2')
+ net = slim.max_pool2d(net, [2, 2], scope='pool2')
+ net = slim.conv2d(net, 512, [3, 3], scope='conv3')
+ net = slim.conv2d(net, 1024, [3, 3], scope='conv4')
+ net = slim.conv2d(net, 1024, [3, 3], scope='conv5')
+ net = slim.max_pool2d(net, [2, 2], scope='pool5')
+ with slim.arg_scope([slim.conv2d],
+ weights_initializer=trunc_normal(0.005),
+ biases_initializer=tf.constant_initializer(0.1)):
+ # Use conv2d instead of fully_connected layers.
+ net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout6')
+ net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout7')
+ net = slim.conv2d(net, num_classes, [1, 1],
+ activation_fn=None,
+ normalizer_fn=None,
+ biases_initializer=tf.zeros_initializer,
+ scope='fc8')
+ # Convert end_points_collection into a end_point dict.
+ end_points = dict(tf.get_collection(end_points_collection))
+ if spatial_squeeze:
+ net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+ end_points[sc.name + '/fc8'] = net
+ return net, end_points
diff --git a/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py b/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py
new file mode 100644
index 0000000000..f26e03c00e
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py
@@ -0,0 +1,145 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.overfeat."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim.python.slim.nets import overfeat
+
+slim = tf.contrib.slim
+
+
+class OverFeatTest(tf.test.TestCase):
+
+ def testBuild(self):
+ batch_size = 5
+ height, width = 231, 231
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = overfeat.overfeat(inputs, num_classes)
+ self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+
+ def testFullyConvolutional(self):
+ batch_size = 1
+ height, width = 281, 281
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
+ self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, 2, 2, num_classes])
+
+ def testEndPoints(self):
+ batch_size = 5
+ height, width = 231, 231
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ _, end_points = overfeat.overfeat(inputs, num_classes)
+ expected_names = ['overfeat/conv1',
+ 'overfeat/pool1',
+ 'overfeat/conv2',
+ 'overfeat/pool2',
+ 'overfeat/conv3',
+ 'overfeat/conv4',
+ 'overfeat/conv5',
+ 'overfeat/pool5',
+ 'overfeat/fc6',
+ 'overfeat/fc7',
+ 'overfeat/fc8'
+ ]
+ self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+ def testModelVariables(self):
+ batch_size = 5
+ height, width = 231, 231
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ overfeat.overfeat(inputs, num_classes)
+ expected_names = ['overfeat/conv1/weights',
+ 'overfeat/conv1/biases',
+ 'overfeat/conv2/weights',
+ 'overfeat/conv2/biases',
+ 'overfeat/conv3/weights',
+ 'overfeat/conv3/biases',
+ 'overfeat/conv4/weights',
+ 'overfeat/conv4/biases',
+ 'overfeat/conv5/weights',
+ 'overfeat/conv5/biases',
+ 'overfeat/fc6/weights',
+ 'overfeat/fc6/biases',
+ 'overfeat/fc7/weights',
+ 'overfeat/fc7/biases',
+ 'overfeat/fc8/weights',
+ 'overfeat/fc8/biases',
+ ]
+ model_variables = [v.op.name for v in slim.get_model_variables()]
+ self.assertSetEqual(set(model_variables), set(expected_names))
+
+ def testEvaluation(self):
+ batch_size = 2
+ height, width = 231, 231
+ num_classes = 1000
+ with self.test_session():
+ eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+ predictions = tf.argmax(logits, 1)
+ self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+ def testTrainEvalWithReuse(self):
+ train_batch_size = 2
+ eval_batch_size = 1
+ train_height, train_width = 231, 231
+ eval_height, eval_width = 281, 281
+ num_classes = 1000
+ with self.test_session():
+ train_inputs = tf.random_uniform(
+ (train_batch_size, train_height, train_width, 3))
+ logits, _ = overfeat.overfeat(train_inputs)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [train_batch_size, num_classes])
+ tf.get_variable_scope().reuse_variables()
+ eval_inputs = tf.random_uniform(
+ (eval_batch_size, eval_height, eval_width, 3))
+ logits, _ = overfeat.overfeat(eval_inputs, is_training=False,
+ spatial_squeeze=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [eval_batch_size, 2, 2, num_classes])
+ logits = tf.reduce_mean(logits, [1, 2])
+ predictions = tf.argmax(logits, 1)
+ self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
+
+ def testForward(self):
+ batch_size = 1
+ height, width = 231, 231
+ with self.test_session() as sess:
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = overfeat.overfeat(inputs)
+ sess.run(tf.initialize_all_variables())
+ output = sess.run(logits)
+ self.assertTrue(output.any())
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/slim/python/slim/nets/vgg.py b/tensorflow/contrib/slim/python/slim/nets/vgg.py
new file mode 100644
index 0000000000..892064b051
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/vgg.py
@@ -0,0 +1,159 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains model definitions for versions of the VGG network.
+
+These model definitions were introduced in the technical report:
+ Very Deep Convolutional Networks For Large-Scale Image Recognition
+ Karen Simonyan and Andrew Zisserman, ICLR 2015
+
+More information can be obtained from the VGG website:
+www.robots.ox.ac.uk/~vgg/research/very_deep/
+
+Usage:
+ with slim.arg_scope(vgg.vgg_arg_scope()):
+ outputs, end_points = vgg.vgg_a(inputs)
+
+ with slim.arg_scope(vgg.vgg_arg_scope()):
+ outputs, end_points = vgg.vgg_16(inputs)
+"""
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def vgg_arg_scope(weight_decay=0.0005):
+ with slim.arg_scope([slim.conv2d, slim.fully_connected],
+ activation_fn=tf.nn.relu,
+ weights_regularizer=slim.l2_regularizer(weight_decay),
+ biases_initializer=tf.zeros_initializer):
+ with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
+ return arg_sc
+
+
+def vgg_a(inputs,
+ num_classes=1000,
+ dropout_keep_prob=0.5,
+ is_training=True,
+ spatial_squeeze=True,
+ scope='vgg_a'):
+ """Oxford Net VGG 11-Layers version A Example.
+
+ Note: All the fully_connected layers have been transformed to conv2d layers.
+ To use in classification mode, resize input to 224x224.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels].
+ num_classes: number of predicted classes.
+ dropout_keep_prob: the probability that activations are kept in the dropout
+ layers during training.
+ is_training: whether or not the model is being trained.
+ spatial_squeeze: whether or not should squeeze the spatial dimensions of the
+ outputs. Useful to remove unnecessary dimensions for classification.
+ scope: Optional scope for the variables.
+
+ Returns:
+ the last op containing the log predictions and end_points dict.
+ """
+ with tf.variable_op_scope([inputs], scope, 'vgg_a') as sc:
+ end_points_collection = sc.name + '_end_points'
+ # Collect outputs for conv2d, fully_connected and max_pool2d.
+ with slim.arg_scope([slim.conv2d, slim.max_pool2d],
+ outputs_collections=end_points_collection):
+ net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1')
+ net = slim.max_pool2d(net, [2, 2], scope='pool1')
+ net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2')
+ net = slim.max_pool2d(net, [2, 2], scope='pool2')
+ net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
+ net = slim.max_pool2d(net, [2, 2], scope='pool3')
+ net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
+ net = slim.max_pool2d(net, [2, 2], scope='pool4')
+ net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
+ net = slim.max_pool2d(net, [2, 2], scope='pool5')
+ # Use conv2d instead of fully_connected layers.
+ net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout6')
+ net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout7')
+ net = slim.conv2d(net, num_classes, [1, 1],
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='fc8')
+ # Convert end_points_collection into a end_point dict.
+ end_points = dict(tf.get_collection(end_points_collection))
+ if spatial_squeeze:
+ net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+ end_points[sc.name + '/fc8'] = net
+ return net, end_points
+
+
+def vgg_16(inputs,
+ num_classes=1000,
+ dropout_keep_prob=0.5,
+ is_training=True,
+ spatial_squeeze=True,
+ scope='vgg_16'):
+ """Oxford Net VGG 16-Layers version D Example.
+
+ Note: All the fully_connected layers have been transformed to conv2d layers.
+ To use in classification mode, resize input to 224x224.
+
+ Args:
+ inputs: a tensor of size [batch_size, height, width, channels].
+ num_classes: number of predicted classes.
+ dropout_keep_prob: the probability that activations are kept in the dropout
+ layers during training.
+ is_training: whether or not the model is being trained.
+ spatial_squeeze: whether or not should squeeze the spatial dimensions of the
+ outputs. Useful to remove unnecessary dimensions for classification.
+ scope: Optional scope for the variables.
+
+ Returns:
+ the last op containing the log predictions and end_points dict.
+ """
+ with tf.variable_op_scope([inputs], scope, 'vgg_16') as sc:
+ end_points_collection = sc.name + '_end_points'
+ # Collect outputs for conv2d, fully_connected and max_pool2d.
+ with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+ outputs_collections=end_points_collection):
+ net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
+ net = slim.max_pool2d(net, [2, 2], scope='pool1')
+ net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
+ net = slim.max_pool2d(net, [2, 2], scope='pool2')
+ net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
+ net = slim.max_pool2d(net, [2, 2], scope='pool3')
+ net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
+ net = slim.max_pool2d(net, [2, 2], scope='pool4')
+ net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
+ net = slim.max_pool2d(net, [2, 2], scope='pool5')
+ # Use conv2d instead of fully_connected layers.
+ net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout6')
+ net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+ net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+ scope='dropout7')
+ net = slim.conv2d(net, num_classes, [1, 1],
+ activation_fn=None,
+ normalizer_fn=None,
+ scope='fc8')
+ # Convert end_points_collection into a end_point dict.
+ end_points = dict(tf.get_collection(end_points_collection))
+ if spatial_squeeze:
+ net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+ end_points[sc.name + '/fc8'] = net
+ return net, end_points
diff --git a/tensorflow/contrib/slim/python/slim/nets/vgg_test.py b/tensorflow/contrib/slim/python/slim/nets/vgg_test.py
new file mode 100644
index 0000000000..81e4d47d4d
--- /dev/null
+++ b/tensorflow/contrib/slim/python/slim/nets/vgg_test.py
@@ -0,0 +1,300 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.vgg."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim.python.slim.nets import vgg
+slim = tf.contrib.slim
+
+
+class VGGATest(tf.test.TestCase):
+
+ def testBuild(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_a(inputs, num_classes)
+ self.assertEquals(logits.op.name, 'vgg_a/fc8/squeezed')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+
+ def testFullyConvolutional(self):
+ batch_size = 1
+ height, width = 256, 256
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_a(inputs, num_classes, spatial_squeeze=False)
+ self.assertEquals(logits.op.name, 'vgg_a/fc8/BiasAdd')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, 2, 2, num_classes])
+
+ def testEndPoints(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ _, end_points = vgg.vgg_a(inputs, num_classes)
+ expected_names = ['vgg_a/conv1/conv1_1',
+ 'vgg_a/pool1',
+ 'vgg_a/conv2/conv2_1',
+ 'vgg_a/pool2',
+ 'vgg_a/conv3/conv3_1',
+ 'vgg_a/conv3/conv3_2',
+ 'vgg_a/pool3',
+ 'vgg_a/conv4/conv4_1',
+ 'vgg_a/conv4/conv4_2',
+ 'vgg_a/pool4',
+ 'vgg_a/conv5/conv5_1',
+ 'vgg_a/conv5/conv5_2',
+ 'vgg_a/pool5',
+ 'vgg_a/fc6',
+ 'vgg_a/fc7',
+ 'vgg_a/fc8'
+ ]
+ self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+ def testModelVariables(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ vgg.vgg_a(inputs, num_classes)
+ expected_names = ['vgg_a/conv1/conv1_1/weights',
+ 'vgg_a/conv1/conv1_1/biases',
+ 'vgg_a/conv2/conv2_1/weights',
+ 'vgg_a/conv2/conv2_1/biases',
+ 'vgg_a/conv3/conv3_1/weights',
+ 'vgg_a/conv3/conv3_1/biases',
+ 'vgg_a/conv3/conv3_2/weights',
+ 'vgg_a/conv3/conv3_2/biases',
+ 'vgg_a/conv4/conv4_1/weights',
+ 'vgg_a/conv4/conv4_1/biases',
+ 'vgg_a/conv4/conv4_2/weights',
+ 'vgg_a/conv4/conv4_2/biases',
+ 'vgg_a/conv5/conv5_1/weights',
+ 'vgg_a/conv5/conv5_1/biases',
+ 'vgg_a/conv5/conv5_2/weights',
+ 'vgg_a/conv5/conv5_2/biases',
+ 'vgg_a/fc6/weights',
+ 'vgg_a/fc6/biases',
+ 'vgg_a/fc7/weights',
+ 'vgg_a/fc7/biases',
+ 'vgg_a/fc8/weights',
+ 'vgg_a/fc8/biases',
+ ]
+ model_variables = [v.op.name for v in slim.get_model_variables()]
+ self.assertSetEqual(set(model_variables), set(expected_names))
+
+ def testEvaluation(self):
+ batch_size = 2
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_a(eval_inputs, is_training=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+ predictions = tf.argmax(logits, 1)
+ self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+ def testTrainEvalWithReuse(self):
+ train_batch_size = 2
+ eval_batch_size = 1
+ train_height, train_width = 224, 224
+ eval_height, eval_width = 256, 256
+ num_classes = 1000
+ with self.test_session():
+ train_inputs = tf.random_uniform(
+ (train_batch_size, train_height, train_width, 3))
+ logits, _ = vgg.vgg_a(train_inputs)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [train_batch_size, num_classes])
+ tf.get_variable_scope().reuse_variables()
+ eval_inputs = tf.random_uniform(
+ (eval_batch_size, eval_height, eval_width, 3))
+ logits, _ = vgg.vgg_a(eval_inputs, is_training=False,
+ spatial_squeeze=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [eval_batch_size, 2, 2, num_classes])
+ logits = tf.reduce_mean(logits, [1, 2])
+ predictions = tf.argmax(logits, 1)
+ self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
+
+ def testForward(self):
+ batch_size = 1
+ height, width = 224, 224
+ with self.test_session() as sess:
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_a(inputs)
+ sess.run(tf.initialize_all_variables())
+ output = sess.run(logits)
+ self.assertTrue(output.any())
+
+
+class VGG16Test(tf.test.TestCase):
+
+ def testBuild(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_16(inputs, num_classes)
+ self.assertEquals(logits.op.name, 'vgg_16/fc8/squeezed')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+
+ def testFullyConvolutional(self):
+ batch_size = 1
+ height, width = 256, 256
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_16(inputs, num_classes, spatial_squeeze=False)
+ self.assertEquals(logits.op.name, 'vgg_16/fc8/BiasAdd')
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, 2, 2, num_classes])
+
+ def testEndPoints(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ _, end_points = vgg.vgg_16(inputs, num_classes)
+ expected_names = ['vgg_16/conv1/conv1_1',
+ 'vgg_16/conv1/conv1_2',
+ 'vgg_16/pool1',
+ 'vgg_16/conv2/conv2_1',
+ 'vgg_16/conv2/conv2_2',
+ 'vgg_16/pool2',
+ 'vgg_16/conv3/conv3_1',
+ 'vgg_16/conv3/conv3_2',
+ 'vgg_16/conv3/conv3_3',
+ 'vgg_16/pool3',
+ 'vgg_16/conv4/conv4_1',
+ 'vgg_16/conv4/conv4_2',
+ 'vgg_16/conv4/conv4_3',
+ 'vgg_16/pool4',
+ 'vgg_16/conv5/conv5_1',
+ 'vgg_16/conv5/conv5_2',
+ 'vgg_16/conv5/conv5_3',
+ 'vgg_16/pool5',
+ 'vgg_16/fc6',
+ 'vgg_16/fc7',
+ 'vgg_16/fc8'
+ ]
+ print(end_points.keys())
+ self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+ def testModelVariables(self):
+ batch_size = 5
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ vgg.vgg_16(inputs, num_classes)
+ expected_names = ['vgg_16/conv1/conv1_1/weights',
+ 'vgg_16/conv1/conv1_1/biases',
+ 'vgg_16/conv1/conv1_2/weights',
+ 'vgg_16/conv1/conv1_2/biases',
+ 'vgg_16/conv2/conv2_1/weights',
+ 'vgg_16/conv2/conv2_1/biases',
+ 'vgg_16/conv2/conv2_2/weights',
+ 'vgg_16/conv2/conv2_2/biases',
+ 'vgg_16/conv3/conv3_1/weights',
+ 'vgg_16/conv3/conv3_1/biases',
+ 'vgg_16/conv3/conv3_2/weights',
+ 'vgg_16/conv3/conv3_2/biases',
+ 'vgg_16/conv3/conv3_3/weights',
+ 'vgg_16/conv3/conv3_3/biases',
+ 'vgg_16/conv4/conv4_1/weights',
+ 'vgg_16/conv4/conv4_1/biases',
+ 'vgg_16/conv4/conv4_2/weights',
+ 'vgg_16/conv4/conv4_2/biases',
+ 'vgg_16/conv4/conv4_3/weights',
+ 'vgg_16/conv4/conv4_3/biases',
+ 'vgg_16/conv5/conv5_1/weights',
+ 'vgg_16/conv5/conv5_1/biases',
+ 'vgg_16/conv5/conv5_2/weights',
+ 'vgg_16/conv5/conv5_2/biases',
+ 'vgg_16/conv5/conv5_3/weights',
+ 'vgg_16/conv5/conv5_3/biases',
+ 'vgg_16/fc6/weights',
+ 'vgg_16/fc6/biases',
+ 'vgg_16/fc7/weights',
+ 'vgg_16/fc7/biases',
+ 'vgg_16/fc8/weights',
+ 'vgg_16/fc8/biases',
+ ]
+ model_variables = [v.op.name for v in slim.get_model_variables()]
+ self.assertSetEqual(set(model_variables), set(expected_names))
+
+ def testEvaluation(self):
+ batch_size = 2
+ height, width = 224, 224
+ num_classes = 1000
+ with self.test_session():
+ eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_16(eval_inputs, is_training=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [batch_size, num_classes])
+ predictions = tf.argmax(logits, 1)
+ self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+ def testTrainEvalWithReuse(self):
+ train_batch_size = 2
+ eval_batch_size = 1
+ train_height, train_width = 224, 224
+ eval_height, eval_width = 256, 256
+ num_classes = 1000
+ with self.test_session():
+ train_inputs = tf.random_uniform(
+ (train_batch_size, train_height, train_width, 3))
+ logits, _ = vgg.vgg_16(train_inputs)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [train_batch_size, num_classes])
+ tf.get_variable_scope().reuse_variables()
+ eval_inputs = tf.random_uniform(
+ (eval_batch_size, eval_height, eval_width, 3))
+ logits, _ = vgg.vgg_16(eval_inputs, is_training=False,
+ spatial_squeeze=False)
+ self.assertListEqual(logits.get_shape().as_list(),
+ [eval_batch_size, 2, 2, num_classes])
+ logits = tf.reduce_mean(logits, [1, 2])
+ predictions = tf.argmax(logits, 1)
+ self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
+
+ def testForward(self):
+ batch_size = 1
+ height, width = 224, 224
+ with self.test_session() as sess:
+ inputs = tf.random_uniform((batch_size, height, width, 3))
+ logits, _ = vgg.vgg_16(inputs)
+ sess.run(tf.initialize_all_variables())
+ output = sess.run(logits)
+ self.assertTrue(output.any())
+
+if __name__ == '__main__':
+ tf.test.main()