diff options
author | Sergio Guadarrama <sguada@google.com> | 2016-07-19 16:14:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-19 17:18:36 -0700 |
commit | 15afaf7b2ed6aa77b62dde0007afffaffc6a051d (patch) | |
tree | daded573d0cee9df2fdb20a64c41ba98416debbb | |
parent | 496f3a48d900b84006e8dba697be6252328cbf98 (diff) |
Add common nets to tf.slim
Change: 127893189
-rw-r--r-- | tensorflow/contrib/slim/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/contrib/slim/__init__.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/model_analyzer.py | 104 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/BUILD | 70 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/__init__.py | 23 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/alexnet.py | 119 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/alexnet_test.py | 144 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/overfeat.py | 112 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/overfeat_test.py | 145 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/vgg.py | 159 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/nets/vgg_test.py | 300 |
11 files changed, 1185 insertions, 1 deletions
diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD index 96dbe98158..7c4a1d036d 100644 --- a/tensorflow/contrib/slim/BUILD +++ b/tensorflow/contrib/slim/BUILD @@ -55,6 +55,12 @@ py_test( ) py_library( + name = "model_analyzer", + srcs = ["python/slim/model_analyzer.py"], + srcs_version = "PY2AND3", +) + +py_library( name = "queues", srcs = ["python/slim/queues.py"], srcs_version = "PY2AND3", @@ -74,6 +80,7 @@ py_library( deps = [ ":evaluation", ":learning", + ":model_analyzer", ":queues", "//tensorflow/contrib/slim/python/slim/data", ], diff --git a/tensorflow/contrib/slim/__init__.py b/tensorflow/contrib/slim/__init__.py index 7f7bc8302b..30e69456e4 100644 --- a/tensorflow/contrib/slim/__init__.py +++ b/tensorflow/contrib/slim/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 Google Inc. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ from tensorflow.contrib.layers.python.layers.initializers import * from tensorflow.contrib.layers.python.layers.regularizers import * from tensorflow.contrib.slim.python.slim import evaluation from tensorflow.contrib.slim.python.slim import learning +from tensorflow.contrib.slim.python.slim import model_analyzer from tensorflow.contrib.slim.python.slim import queues from tensorflow.contrib.slim.python.slim.data import data_decoder from tensorflow.contrib.slim.python.slim.data import data_provider diff --git a/tensorflow/contrib/slim/python/slim/model_analyzer.py b/tensorflow/contrib/slim/python/slim/model_analyzer.py new file mode 100644 index 0000000000..73659be4a7 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/model_analyzer.py @@ -0,0 +1,104 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tools for analyzing the operations and variables in a TensorFlow graph. + +To analyze the operations in a graph: + + images, labels = LoadData(...) + predictions = MyModel(images) + + slim.model_analyzer.analyze_ops(tf.get_default_graph(), print_info=True) + +To analyze the model variables in a graph: + + variables = tf.model_variables() + slim.model_analyzer.analyze_vars(variables, print_info=False) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def tensor_description(var): + """Returns a compact and informative string about a tensor. + + Args: + var: A tensor variable. + + Returns: + a string with type and size, e.g.: (float32 1x8x8x1024). + """ + description = '(' + str(var.dtype.name) + ' ' + sizes = var.get_shape() + for i, size in enumerate(sizes): + description += str(size) + if i < len(sizes) - 1: + description += 'x' + description += ')' + return description + + +def analyze_ops(graph, print_info=False): + """Compute the estimated size of the ops.outputs in the graph. + + Args: + graph: the graph containing the operations. + print_info: Optional, if true print ops and their outputs. + + Returns: + total size of the ops.outputs + """ + if print_info: + print('---------') + print('Operations: name -> (type shapes) [size]') + print('---------') + total_size = 0 + for op in graph.get_operations(): + op_size = 0 + shapes = [] + for output in op.outputs: + # if output.num_elements() is None or [] assume size 0. + output_size = output.get_shape().num_elements() or 0 + if output.get_shape(): + shapes.append(tensor_description(output)) + op_size += output_size + if print_info: + print(op.name, '\t->', ', '.join(shapes), '[' + str(op_size) + ']') + total_size += op_size + return total_size + + +def analyze_vars(variables, print_info=False): + """Prints the names and shapes of the variables. + + Args: + variables: list of variables, for example tf.all_variables(). + print_info: Optional, if true print variables and their shape. + + Returns: + total size of the variables. + """ + if print_info: + print('---------') + print('Variables: name (type shape) [size]') + print('---------') + total_size = 0 + for var in variables: + # if var.num_elements() is None or [] assume size 0. + var_size = var.get_shape().num_elements() or 0 + total_size += var_size + if print_info: + print(var.name, tensor_description(var), '[' + str(var_size) + ']') + return total_size diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD new file mode 100644 index 0000000000..7a70de6db6 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -0,0 +1,70 @@ +# Description: +# Contains typical networks definitions. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +py_library( + name = "nets", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":alexnet", + ":overfeat", + ":vgg", + ], +) + +py_library( + name = "alexnet", + srcs = ["alexnet.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "alexnet_test", + size = "medium", + srcs = ["alexnet_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":alexnet", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "overfeat", + srcs = ["overfeat.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "overfeat_test", + size = "medium", + srcs = ["overfeat_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":overfeat", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "vgg", + srcs = ["vgg.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "vgg_test", + size = "medium", + srcs = ["vgg_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":vgg", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/slim/python/slim/nets/__init__.py b/tensorflow/contrib/slim/python/slim/nets/__init__.py new file mode 100644 index 0000000000..f04f805b00 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TF-Slim Nets.""" +# pylint: disable=unused-import, +# pylint: disable=wildcard-import + +# Collapse nets into a single namespace. +from tensorflow.contrib.slim.python.slim.nets import alexnet +from tensorflow.contrib.slim.python.slim.nets import overfeat +from tensorflow.contrib.slim.python.slim.nets import vgg +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/slim/python/slim/nets/alexnet.py b/tensorflow/contrib/slim/python/slim/nets/alexnet.py new file mode 100644 index 0000000000..14497cd9a1 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/alexnet.py @@ -0,0 +1,119 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains a model definition for AlexNet. + +This work was first described in: + ImageNet Classification with Deep Convolutional Neural Networks + Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton + +and later refined in: + One weird trick for parallelizing convolutional neural networks + Alex Krizhevsky, 2014 + +Here we provide the implementation proposed in "One weird trick" and not +"ImageNet Classification", as per the paper, the LRN layers have been removed. + +Usage: + with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): + outputs, end_points = alexnet.alexnet_v2(inputs) + +""" + +import tensorflow as tf + +slim = tf.contrib.slim +trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) + + +def alexnet_v2_arg_scope(weight_decay=0.0005): + with slim.arg_scope([slim.conv2d, slim.fully_connected], + activation_fn=tf.nn.relu, + biases_initializer=tf.constant_initializer(0.1), + weights_regularizer=slim.l2_regularizer(weight_decay)): + with slim.arg_scope([slim.conv2d], padding='SAME'): + with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: + return arg_sc + + +def alexnet_v2(inputs, + num_classes=1000, + dropout_keep_prob=0.5, + is_training=True, + spatial_squeeze=True, + scope='alexnet_v2'): + """AlexNet version 2. + + Described in: http://arxiv.org/pdf/1404.5997v2.pdf + Parameters from: + github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ + layers-imagenet-1gpu.cfg + + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 224x224. To use in fully + convolutional mode, set spatial_squeeze to false. + The LRN layers have been removed and change the initializers from + random_normal_initializer to xavier_initializer. + + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + is_training: whether or not the model is being trained. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + + Returns: + the last op containing the log predictions and end_points dict. + """ + with tf.variable_op_scope([inputs], scope, 'alexnet_v2') as sc: + end_points_collection = sc.name + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d. + with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], + outputs_collections=[end_points_collection]): + net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', + scope='conv1') + net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') + net = slim.conv2d(net, 192, [5, 5], scope='conv2') + net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') + net = slim.conv2d(net, 384, [3, 3], scope='conv3') + net = slim.conv2d(net, 384, [3, 3], scope='conv4') + net = slim.conv2d(net, 256, [3, 3], scope='conv5') + net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') + + # Use conv2d instead of fully_connected layers. + with slim.arg_scope([slim.conv2d], + weights_initializer=trunc_normal(0.005), + biases_initializer=tf.constant_initializer(0.1)): + net = slim.conv2d(net, 4096, [5, 5], padding='VALID', + scope='fc6') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout6') + net = slim.conv2d(net, 4096, [1, 1], scope='fc7') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout7') + net = slim.conv2d(net, num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + biases_initializer=tf.zeros_initializer, + scope='fc8') + + # Convert end_points_collection into a end_point dict. + end_points = dict(tf.get_collection(end_points_collection)) + if spatial_squeeze: + net = tf.squeeze(net, [1, 2], name='fc8/squeezed') + end_points[sc.name + '/fc8'] = net + return net, end_points diff --git a/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py b/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py new file mode 100644 index 0000000000..346ffa09f2 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py @@ -0,0 +1,144 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slim.nets.alexnet.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.slim.python.slim.nets import alexnet +slim = tf.contrib.slim + + +class AlexnetV2Test(tf.test.TestCase): + + def testBuild(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = alexnet.alexnet_v2(inputs, num_classes) + self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + + def testFullyConvolutional(self): + batch_size = 1 + height, width = 300, 400 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) + self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, 4, 7, num_classes]) + + def testEndPoints(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + _, end_points = alexnet.alexnet_v2(inputs, num_classes) + expected_names = ['alexnet_v2/conv1', + 'alexnet_v2/pool1', + 'alexnet_v2/conv2', + 'alexnet_v2/pool2', + 'alexnet_v2/conv3', + 'alexnet_v2/conv4', + 'alexnet_v2/conv5', + 'alexnet_v2/pool5', + 'alexnet_v2/fc6', + 'alexnet_v2/fc7', + 'alexnet_v2/fc8' + ] + self.assertSetEqual(set(end_points.keys()), set(expected_names)) + + def testModelVariables(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + alexnet.alexnet_v2(inputs, num_classes) + expected_names = ['alexnet_v2/conv1/weights', + 'alexnet_v2/conv1/biases', + 'alexnet_v2/conv2/weights', + 'alexnet_v2/conv2/biases', + 'alexnet_v2/conv3/weights', + 'alexnet_v2/conv3/biases', + 'alexnet_v2/conv4/weights', + 'alexnet_v2/conv4/biases', + 'alexnet_v2/conv5/weights', + 'alexnet_v2/conv5/biases', + 'alexnet_v2/fc6/weights', + 'alexnet_v2/fc6/biases', + 'alexnet_v2/fc7/weights', + 'alexnet_v2/fc7/biases', + 'alexnet_v2/fc8/weights', + 'alexnet_v2/fc8/biases', + ] + model_variables = [v.op.name for v in slim.get_model_variables()] + self.assertSetEqual(set(model_variables), set(expected_names)) + + def testEvaluation(self): + batch_size = 2 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + eval_inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + predictions = tf.argmax(logits, 1) + self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) + + def testTrainEvalWithReuse(self): + train_batch_size = 2 + eval_batch_size = 1 + train_height, train_width = 224, 224 + eval_height, eval_width = 300, 400 + num_classes = 1000 + with self.test_session(): + train_inputs = tf.random_uniform( + (train_batch_size, train_height, train_width, 3)) + logits, _ = alexnet.alexnet_v2(train_inputs) + self.assertListEqual(logits.get_shape().as_list(), + [train_batch_size, num_classes]) + tf.get_variable_scope().reuse_variables() + eval_inputs = tf.random_uniform( + (eval_batch_size, eval_height, eval_width, 3)) + logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, + spatial_squeeze=False) + self.assertListEqual(logits.get_shape().as_list(), + [eval_batch_size, 4, 7, num_classes]) + logits = tf.reduce_mean(logits, [1, 2]) + predictions = tf.argmax(logits, 1) + self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) + + def testForward(self): + batch_size = 1 + height, width = 224, 224 + with self.test_session() as sess: + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = alexnet.alexnet_v2(inputs) + sess.run(tf.initialize_all_variables()) + output = sess.run(logits) + self.assertTrue(output.any()) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/slim/python/slim/nets/overfeat.py b/tensorflow/contrib/slim/python/slim/nets/overfeat.py new file mode 100644 index 0000000000..a0c2033934 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/overfeat.py @@ -0,0 +1,112 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains the model definition for the OverFeat network. + +The definition for the network was obtained from: + OverFeat: Integrated Recognition, Localization and Detection using + Convolutional Networks + Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and + Yann LeCun, 2014 + http://arxiv.org/abs/1312.6229 + +Usage: + with slim.arg_scope(overfeat.overfeat_arg_scope()): + outputs, end_points = overfeat.overfeat(inputs) +""" + +import tensorflow as tf + +slim = tf.contrib.slim +trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) + + +def overfeat_arg_scope(weight_decay=0.0005): + with slim.arg_scope([slim.conv2d, slim.fully_connected], + activation_fn=tf.nn.relu, + weights_regularizer=slim.l2_regularizer(weight_decay), + biases_initializer=tf.zeros_initializer): + with slim.arg_scope([slim.conv2d], padding='SAME'): + with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: + return arg_sc + + +def overfeat(inputs, + num_classes=1000, + dropout_keep_prob=0.5, + is_training=True, + spatial_squeeze=True, + scope='overfeat'): + """Contains the model definition for the OverFeat network. + + The definition for the network was obtained from: + OverFeat: Integrated Recognition, Localization and Detection using + Convolutional Networks + Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and + Yann LeCun, 2014 + http://arxiv.org/abs/1312.6229 + + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 231x231. To use in fully + convolutional mode, set spatial_squeeze to false. + + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + is_training: whether or not the model is being trained. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + + Returns: + the last op containing the log predictions and end_points dict. + + """ + with tf.variable_op_scope([inputs], scope, 'overfeat') as sc: + end_points_collection = sc.name + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d + with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], + outputs_collections=end_points_collection): + net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', + scope='conv1') + net = slim.max_pool2d(net, [2, 2], scope='pool1') + net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') + net = slim.max_pool2d(net, [2, 2], scope='pool2') + net = slim.conv2d(net, 512, [3, 3], scope='conv3') + net = slim.conv2d(net, 1024, [3, 3], scope='conv4') + net = slim.conv2d(net, 1024, [3, 3], scope='conv5') + net = slim.max_pool2d(net, [2, 2], scope='pool5') + with slim.arg_scope([slim.conv2d], + weights_initializer=trunc_normal(0.005), + biases_initializer=tf.constant_initializer(0.1)): + # Use conv2d instead of fully_connected layers. + net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout6') + net = slim.conv2d(net, 4096, [1, 1], scope='fc7') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout7') + net = slim.conv2d(net, num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + biases_initializer=tf.zeros_initializer, + scope='fc8') + # Convert end_points_collection into a end_point dict. + end_points = dict(tf.get_collection(end_points_collection)) + if spatial_squeeze: + net = tf.squeeze(net, [1, 2], name='fc8/squeezed') + end_points[sc.name + '/fc8'] = net + return net, end_points diff --git a/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py b/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py new file mode 100644 index 0000000000..f26e03c00e --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/overfeat_test.py @@ -0,0 +1,145 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slim.nets.overfeat.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.slim.python.slim.nets import overfeat + +slim = tf.contrib.slim + + +class OverFeatTest(tf.test.TestCase): + + def testBuild(self): + batch_size = 5 + height, width = 231, 231 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = overfeat.overfeat(inputs, num_classes) + self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + + def testFullyConvolutional(self): + batch_size = 1 + height, width = 281, 281 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) + self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, 2, 2, num_classes]) + + def testEndPoints(self): + batch_size = 5 + height, width = 231, 231 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + _, end_points = overfeat.overfeat(inputs, num_classes) + expected_names = ['overfeat/conv1', + 'overfeat/pool1', + 'overfeat/conv2', + 'overfeat/pool2', + 'overfeat/conv3', + 'overfeat/conv4', + 'overfeat/conv5', + 'overfeat/pool5', + 'overfeat/fc6', + 'overfeat/fc7', + 'overfeat/fc8' + ] + self.assertSetEqual(set(end_points.keys()), set(expected_names)) + + def testModelVariables(self): + batch_size = 5 + height, width = 231, 231 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + overfeat.overfeat(inputs, num_classes) + expected_names = ['overfeat/conv1/weights', + 'overfeat/conv1/biases', + 'overfeat/conv2/weights', + 'overfeat/conv2/biases', + 'overfeat/conv3/weights', + 'overfeat/conv3/biases', + 'overfeat/conv4/weights', + 'overfeat/conv4/biases', + 'overfeat/conv5/weights', + 'overfeat/conv5/biases', + 'overfeat/fc6/weights', + 'overfeat/fc6/biases', + 'overfeat/fc7/weights', + 'overfeat/fc7/biases', + 'overfeat/fc8/weights', + 'overfeat/fc8/biases', + ] + model_variables = [v.op.name for v in slim.get_model_variables()] + self.assertSetEqual(set(model_variables), set(expected_names)) + + def testEvaluation(self): + batch_size = 2 + height, width = 231, 231 + num_classes = 1000 + with self.test_session(): + eval_inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = overfeat.overfeat(eval_inputs, is_training=False) + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + predictions = tf.argmax(logits, 1) + self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) + + def testTrainEvalWithReuse(self): + train_batch_size = 2 + eval_batch_size = 1 + train_height, train_width = 231, 231 + eval_height, eval_width = 281, 281 + num_classes = 1000 + with self.test_session(): + train_inputs = tf.random_uniform( + (train_batch_size, train_height, train_width, 3)) + logits, _ = overfeat.overfeat(train_inputs) + self.assertListEqual(logits.get_shape().as_list(), + [train_batch_size, num_classes]) + tf.get_variable_scope().reuse_variables() + eval_inputs = tf.random_uniform( + (eval_batch_size, eval_height, eval_width, 3)) + logits, _ = overfeat.overfeat(eval_inputs, is_training=False, + spatial_squeeze=False) + self.assertListEqual(logits.get_shape().as_list(), + [eval_batch_size, 2, 2, num_classes]) + logits = tf.reduce_mean(logits, [1, 2]) + predictions = tf.argmax(logits, 1) + self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) + + def testForward(self): + batch_size = 1 + height, width = 231, 231 + with self.test_session() as sess: + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = overfeat.overfeat(inputs) + sess.run(tf.initialize_all_variables()) + output = sess.run(logits) + self.assertTrue(output.any()) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/slim/python/slim/nets/vgg.py b/tensorflow/contrib/slim/python/slim/nets/vgg.py new file mode 100644 index 0000000000..892064b051 --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/vgg.py @@ -0,0 +1,159 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains model definitions for versions of the VGG network. + +These model definitions were introduced in the technical report: + Very Deep Convolutional Networks For Large-Scale Image Recognition + Karen Simonyan and Andrew Zisserman, ICLR 2015 + +More information can be obtained from the VGG website: +www.robots.ox.ac.uk/~vgg/research/very_deep/ + +Usage: + with slim.arg_scope(vgg.vgg_arg_scope()): + outputs, end_points = vgg.vgg_a(inputs) + + with slim.arg_scope(vgg.vgg_arg_scope()): + outputs, end_points = vgg.vgg_16(inputs) +""" + +import tensorflow as tf + +slim = tf.contrib.slim + + +def vgg_arg_scope(weight_decay=0.0005): + with slim.arg_scope([slim.conv2d, slim.fully_connected], + activation_fn=tf.nn.relu, + weights_regularizer=slim.l2_regularizer(weight_decay), + biases_initializer=tf.zeros_initializer): + with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: + return arg_sc + + +def vgg_a(inputs, + num_classes=1000, + dropout_keep_prob=0.5, + is_training=True, + spatial_squeeze=True, + scope='vgg_a'): + """Oxford Net VGG 11-Layers version A Example. + + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 224x224. + + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + is_training: whether or not the model is being trained. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + + Returns: + the last op containing the log predictions and end_points dict. + """ + with tf.variable_op_scope([inputs], scope, 'vgg_a') as sc: + end_points_collection = sc.name + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d. + with slim.arg_scope([slim.conv2d, slim.max_pool2d], + outputs_collections=end_points_collection): + net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') + net = slim.max_pool2d(net, [2, 2], scope='pool1') + net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2') + net = slim.max_pool2d(net, [2, 2], scope='pool2') + net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3') + net = slim.max_pool2d(net, [2, 2], scope='pool3') + net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4') + net = slim.max_pool2d(net, [2, 2], scope='pool4') + net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') + net = slim.max_pool2d(net, [2, 2], scope='pool5') + # Use conv2d instead of fully_connected layers. + net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout6') + net = slim.conv2d(net, 4096, [1, 1], scope='fc7') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout7') + net = slim.conv2d(net, num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + scope='fc8') + # Convert end_points_collection into a end_point dict. + end_points = dict(tf.get_collection(end_points_collection)) + if spatial_squeeze: + net = tf.squeeze(net, [1, 2], name='fc8/squeezed') + end_points[sc.name + '/fc8'] = net + return net, end_points + + +def vgg_16(inputs, + num_classes=1000, + dropout_keep_prob=0.5, + is_training=True, + spatial_squeeze=True, + scope='vgg_16'): + """Oxford Net VGG 16-Layers version D Example. + + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 224x224. + + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + is_training: whether or not the model is being trained. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + + Returns: + the last op containing the log predictions and end_points dict. + """ + with tf.variable_op_scope([inputs], scope, 'vgg_16') as sc: + end_points_collection = sc.name + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d. + with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], + outputs_collections=end_points_collection): + net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') + net = slim.max_pool2d(net, [2, 2], scope='pool1') + net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') + net = slim.max_pool2d(net, [2, 2], scope='pool2') + net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') + net = slim.max_pool2d(net, [2, 2], scope='pool3') + net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') + net = slim.max_pool2d(net, [2, 2], scope='pool4') + net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') + net = slim.max_pool2d(net, [2, 2], scope='pool5') + # Use conv2d instead of fully_connected layers. + net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout6') + net = slim.conv2d(net, 4096, [1, 1], scope='fc7') + net = slim.dropout(net, dropout_keep_prob, is_training=is_training, + scope='dropout7') + net = slim.conv2d(net, num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + scope='fc8') + # Convert end_points_collection into a end_point dict. + end_points = dict(tf.get_collection(end_points_collection)) + if spatial_squeeze: + net = tf.squeeze(net, [1, 2], name='fc8/squeezed') + end_points[sc.name + '/fc8'] = net + return net, end_points diff --git a/tensorflow/contrib/slim/python/slim/nets/vgg_test.py b/tensorflow/contrib/slim/python/slim/nets/vgg_test.py new file mode 100644 index 0000000000..81e4d47d4d --- /dev/null +++ b/tensorflow/contrib/slim/python/slim/nets/vgg_test.py @@ -0,0 +1,300 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slim.nets.vgg.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.slim.python.slim.nets import vgg +slim = tf.contrib.slim + + +class VGGATest(tf.test.TestCase): + + def testBuild(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_a(inputs, num_classes) + self.assertEquals(logits.op.name, 'vgg_a/fc8/squeezed') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + + def testFullyConvolutional(self): + batch_size = 1 + height, width = 256, 256 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_a(inputs, num_classes, spatial_squeeze=False) + self.assertEquals(logits.op.name, 'vgg_a/fc8/BiasAdd') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, 2, 2, num_classes]) + + def testEndPoints(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + _, end_points = vgg.vgg_a(inputs, num_classes) + expected_names = ['vgg_a/conv1/conv1_1', + 'vgg_a/pool1', + 'vgg_a/conv2/conv2_1', + 'vgg_a/pool2', + 'vgg_a/conv3/conv3_1', + 'vgg_a/conv3/conv3_2', + 'vgg_a/pool3', + 'vgg_a/conv4/conv4_1', + 'vgg_a/conv4/conv4_2', + 'vgg_a/pool4', + 'vgg_a/conv5/conv5_1', + 'vgg_a/conv5/conv5_2', + 'vgg_a/pool5', + 'vgg_a/fc6', + 'vgg_a/fc7', + 'vgg_a/fc8' + ] + self.assertSetEqual(set(end_points.keys()), set(expected_names)) + + def testModelVariables(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + vgg.vgg_a(inputs, num_classes) + expected_names = ['vgg_a/conv1/conv1_1/weights', + 'vgg_a/conv1/conv1_1/biases', + 'vgg_a/conv2/conv2_1/weights', + 'vgg_a/conv2/conv2_1/biases', + 'vgg_a/conv3/conv3_1/weights', + 'vgg_a/conv3/conv3_1/biases', + 'vgg_a/conv3/conv3_2/weights', + 'vgg_a/conv3/conv3_2/biases', + 'vgg_a/conv4/conv4_1/weights', + 'vgg_a/conv4/conv4_1/biases', + 'vgg_a/conv4/conv4_2/weights', + 'vgg_a/conv4/conv4_2/biases', + 'vgg_a/conv5/conv5_1/weights', + 'vgg_a/conv5/conv5_1/biases', + 'vgg_a/conv5/conv5_2/weights', + 'vgg_a/conv5/conv5_2/biases', + 'vgg_a/fc6/weights', + 'vgg_a/fc6/biases', + 'vgg_a/fc7/weights', + 'vgg_a/fc7/biases', + 'vgg_a/fc8/weights', + 'vgg_a/fc8/biases', + ] + model_variables = [v.op.name for v in slim.get_model_variables()] + self.assertSetEqual(set(model_variables), set(expected_names)) + + def testEvaluation(self): + batch_size = 2 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + eval_inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_a(eval_inputs, is_training=False) + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + predictions = tf.argmax(logits, 1) + self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) + + def testTrainEvalWithReuse(self): + train_batch_size = 2 + eval_batch_size = 1 + train_height, train_width = 224, 224 + eval_height, eval_width = 256, 256 + num_classes = 1000 + with self.test_session(): + train_inputs = tf.random_uniform( + (train_batch_size, train_height, train_width, 3)) + logits, _ = vgg.vgg_a(train_inputs) + self.assertListEqual(logits.get_shape().as_list(), + [train_batch_size, num_classes]) + tf.get_variable_scope().reuse_variables() + eval_inputs = tf.random_uniform( + (eval_batch_size, eval_height, eval_width, 3)) + logits, _ = vgg.vgg_a(eval_inputs, is_training=False, + spatial_squeeze=False) + self.assertListEqual(logits.get_shape().as_list(), + [eval_batch_size, 2, 2, num_classes]) + logits = tf.reduce_mean(logits, [1, 2]) + predictions = tf.argmax(logits, 1) + self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) + + def testForward(self): + batch_size = 1 + height, width = 224, 224 + with self.test_session() as sess: + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_a(inputs) + sess.run(tf.initialize_all_variables()) + output = sess.run(logits) + self.assertTrue(output.any()) + + +class VGG16Test(tf.test.TestCase): + + def testBuild(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_16(inputs, num_classes) + self.assertEquals(logits.op.name, 'vgg_16/fc8/squeezed') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + + def testFullyConvolutional(self): + batch_size = 1 + height, width = 256, 256 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_16(inputs, num_classes, spatial_squeeze=False) + self.assertEquals(logits.op.name, 'vgg_16/fc8/BiasAdd') + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, 2, 2, num_classes]) + + def testEndPoints(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + _, end_points = vgg.vgg_16(inputs, num_classes) + expected_names = ['vgg_16/conv1/conv1_1', + 'vgg_16/conv1/conv1_2', + 'vgg_16/pool1', + 'vgg_16/conv2/conv2_1', + 'vgg_16/conv2/conv2_2', + 'vgg_16/pool2', + 'vgg_16/conv3/conv3_1', + 'vgg_16/conv3/conv3_2', + 'vgg_16/conv3/conv3_3', + 'vgg_16/pool3', + 'vgg_16/conv4/conv4_1', + 'vgg_16/conv4/conv4_2', + 'vgg_16/conv4/conv4_3', + 'vgg_16/pool4', + 'vgg_16/conv5/conv5_1', + 'vgg_16/conv5/conv5_2', + 'vgg_16/conv5/conv5_3', + 'vgg_16/pool5', + 'vgg_16/fc6', + 'vgg_16/fc7', + 'vgg_16/fc8' + ] + print(end_points.keys()) + self.assertSetEqual(set(end_points.keys()), set(expected_names)) + + def testModelVariables(self): + batch_size = 5 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + inputs = tf.random_uniform((batch_size, height, width, 3)) + vgg.vgg_16(inputs, num_classes) + expected_names = ['vgg_16/conv1/conv1_1/weights', + 'vgg_16/conv1/conv1_1/biases', + 'vgg_16/conv1/conv1_2/weights', + 'vgg_16/conv1/conv1_2/biases', + 'vgg_16/conv2/conv2_1/weights', + 'vgg_16/conv2/conv2_1/biases', + 'vgg_16/conv2/conv2_2/weights', + 'vgg_16/conv2/conv2_2/biases', + 'vgg_16/conv3/conv3_1/weights', + 'vgg_16/conv3/conv3_1/biases', + 'vgg_16/conv3/conv3_2/weights', + 'vgg_16/conv3/conv3_2/biases', + 'vgg_16/conv3/conv3_3/weights', + 'vgg_16/conv3/conv3_3/biases', + 'vgg_16/conv4/conv4_1/weights', + 'vgg_16/conv4/conv4_1/biases', + 'vgg_16/conv4/conv4_2/weights', + 'vgg_16/conv4/conv4_2/biases', + 'vgg_16/conv4/conv4_3/weights', + 'vgg_16/conv4/conv4_3/biases', + 'vgg_16/conv5/conv5_1/weights', + 'vgg_16/conv5/conv5_1/biases', + 'vgg_16/conv5/conv5_2/weights', + 'vgg_16/conv5/conv5_2/biases', + 'vgg_16/conv5/conv5_3/weights', + 'vgg_16/conv5/conv5_3/biases', + 'vgg_16/fc6/weights', + 'vgg_16/fc6/biases', + 'vgg_16/fc7/weights', + 'vgg_16/fc7/biases', + 'vgg_16/fc8/weights', + 'vgg_16/fc8/biases', + ] + model_variables = [v.op.name for v in slim.get_model_variables()] + self.assertSetEqual(set(model_variables), set(expected_names)) + + def testEvaluation(self): + batch_size = 2 + height, width = 224, 224 + num_classes = 1000 + with self.test_session(): + eval_inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_16(eval_inputs, is_training=False) + self.assertListEqual(logits.get_shape().as_list(), + [batch_size, num_classes]) + predictions = tf.argmax(logits, 1) + self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) + + def testTrainEvalWithReuse(self): + train_batch_size = 2 + eval_batch_size = 1 + train_height, train_width = 224, 224 + eval_height, eval_width = 256, 256 + num_classes = 1000 + with self.test_session(): + train_inputs = tf.random_uniform( + (train_batch_size, train_height, train_width, 3)) + logits, _ = vgg.vgg_16(train_inputs) + self.assertListEqual(logits.get_shape().as_list(), + [train_batch_size, num_classes]) + tf.get_variable_scope().reuse_variables() + eval_inputs = tf.random_uniform( + (eval_batch_size, eval_height, eval_width, 3)) + logits, _ = vgg.vgg_16(eval_inputs, is_training=False, + spatial_squeeze=False) + self.assertListEqual(logits.get_shape().as_list(), + [eval_batch_size, 2, 2, num_classes]) + logits = tf.reduce_mean(logits, [1, 2]) + predictions = tf.argmax(logits, 1) + self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) + + def testForward(self): + batch_size = 1 + height, width = 224, 224 + with self.test_session() as sess: + inputs = tf.random_uniform((batch_size, height, width, 3)) + logits, _ = vgg.vgg_16(inputs) + sess.run(tf.initialize_all_variables()) + output = sess.run(logits) + self.assertTrue(output.any()) + +if __name__ == '__main__': + tf.test.main() |