aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/skflow/resnet.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/skflow/resnet.py')
-rwxr-xr-x[-rw-r--r--]tensorflow/examples/skflow/resnet.py62
1 files changed, 27 insertions, 35 deletions
diff --git a/tensorflow/examples/skflow/resnet.py b/tensorflow/examples/skflow/resnet.py
index 03a5d5e519..d67022d457 100644..100755
--- a/tensorflow/examples/skflow/resnet.py
+++ b/tensorflow/examples/skflow/resnet.py
@@ -52,13 +52,13 @@ def res_net(x, y, activation=tf.nn.relu):
Predictions and loss tensors.
"""
- # Configurations for each bottleneck block.
- BottleneckBlock = namedtuple(
- 'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
- blocks = [BottleneckBlock(3, 128, 32),
- BottleneckBlock(3, 256, 64),
- BottleneckBlock(3, 512, 128),
- BottleneckBlock(3, 1024, 256)]
+ # Configurations for each bottleneck group.
+ BottleneckGroup = namedtuple(
+ 'BottleneckGroup', ['num_blocks', 'num_filters', 'bottleneck_size'])
+ groups = [BottleneckGroup(3, 128, 32),
+ BottleneckGroup(3, 256, 64),
+ BottleneckGroup(3, 512, 128),
+ BottleneckGroup(3, 1024, 256)]
input_shape = x.get_shape().as_list()
@@ -78,19 +78,19 @@ def res_net(x, y, activation=tf.nn.relu):
# First chain of resnets
with tf.variable_scope('conv_layer2'):
- net = learn.ops.conv2d(net, blocks[0].num_filters,
+ net = learn.ops.conv2d(net, groups[0].num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID', bias=True)
- # Create each bottleneck building block for each layer
- for block_i, block in enumerate(blocks):
- for layer_i in range(block.num_layers):
-
- name = 'block_%d/layer_%d' % (block_i, layer_i)
+ # Create the bottleneck groups, each of which contains `num_blocks`
+ # bottleneck groups.
+ for group_i, group in enumerate(groups):
+ for block_i in range(group.num_blocks):
+ name = 'group_%d/block_%d' % (group_i, block_i)
# 1x1 convolution responsible for reducing dimension
with tf.variable_scope(name + '/conv_in'):
- conv = learn.ops.conv2d(net, block.bottleneck_size,
+ conv = learn.ops.conv2d(net, group.bottleneck_size,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
@@ -98,7 +98,7 @@ def res_net(x, y, activation=tf.nn.relu):
bias=False)
with tf.variable_scope(name + '/conv_bottleneck'):
- conv = learn.ops.conv2d(conv, block.bottleneck_size,
+ conv = learn.ops.conv2d(conv, group.bottleneck_size,
[3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
@@ -107,7 +107,8 @@ def res_net(x, y, activation=tf.nn.relu):
# 1x1 convolution responsible for restoring dimension
with tf.variable_scope(name + '/conv_out'):
- conv = learn.ops.conv2d(conv, block.num_filters,
+ input_dim = net.get_shape()[-1].value
+ conv = learn.ops.conv2d(conv, input_dim,
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
@@ -118,16 +119,16 @@ def res_net(x, y, activation=tf.nn.relu):
# residual function (identity shortcut)
net = conv + net
- try:
- # upscale to the next block size
- next_block = blocks[block_i + 1]
- with tf.variable_scope('block_%d/conv_upscale' % block_i):
- net = learn.ops.conv2d(net, next_block.num_filters,
- [1, 1], [1, 1, 1, 1],
- bias=False,
- padding='SAME')
- except IndexError:
- pass
+ try:
+ # upscale to the next group size
+ next_group = groups[group_i + 1]
+ with tf.variable_scope('block_%d/conv_upscale' % group_i):
+ net = learn.ops.conv2d(net, next_group.num_filters,
+ [1, 1], [1, 1, 1, 1],
+ bias=False,
+ padding='SAME')
+ except IndexError:
+ pass
net_shape = net.get_shape().as_list()
net = tf.nn.avg_pool(net,
@@ -139,18 +140,12 @@ def res_net(x, y, activation=tf.nn.relu):
return learn.models.logistic_regression(net, y)
-
# Download and load MNIST data.
mnist = input_data.read_data_sets('MNIST_data')
# Restore model if graph is saved into a folder.
if os.path.exists('models/resnet/graph.pbtxt'):
classifier = learn.TensorFlowEstimator.restore('models/resnet/')
-else:
- # Create a new resnet classifier.
- classifier = learn.TensorFlowEstimator(
- model_fn=res_net, n_classes=10, batch_size=100, steps=100,
- learning_rate=0.001, continue_training=True)
while True:
# Train model and save summaries into logdir.
@@ -161,6 +156,3 @@ while True:
score = metrics.accuracy_score(
mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
print('Accuracy: {0:f}'.format(score))
-
- # Save model graph and checkpoints.
- classifier.save('models/resnet/')