diff options
Diffstat (limited to 'tensorflow/examples/skflow/resnet.py')
-rwxr-xr-x[-rw-r--r--] | tensorflow/examples/skflow/resnet.py | 62 |
1 files changed, 27 insertions, 35 deletions
diff --git a/tensorflow/examples/skflow/resnet.py b/tensorflow/examples/skflow/resnet.py index 03a5d5e519..d67022d457 100644..100755 --- a/tensorflow/examples/skflow/resnet.py +++ b/tensorflow/examples/skflow/resnet.py @@ -52,13 +52,13 @@ def res_net(x, y, activation=tf.nn.relu): Predictions and loss tensors. """ - # Configurations for each bottleneck block. - BottleneckBlock = namedtuple( - 'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size']) - blocks = [BottleneckBlock(3, 128, 32), - BottleneckBlock(3, 256, 64), - BottleneckBlock(3, 512, 128), - BottleneckBlock(3, 1024, 256)] + # Configurations for each bottleneck group. + BottleneckGroup = namedtuple( + 'BottleneckGroup', ['num_blocks', 'num_filters', 'bottleneck_size']) + groups = [BottleneckGroup(3, 128, 32), + BottleneckGroup(3, 256, 64), + BottleneckGroup(3, 512, 128), + BottleneckGroup(3, 1024, 256)] input_shape = x.get_shape().as_list() @@ -78,19 +78,19 @@ def res_net(x, y, activation=tf.nn.relu): # First chain of resnets with tf.variable_scope('conv_layer2'): - net = learn.ops.conv2d(net, blocks[0].num_filters, + net = learn.ops.conv2d(net, groups[0].num_filters, [1, 1], [1, 1, 1, 1], padding='VALID', bias=True) - # Create each bottleneck building block for each layer - for block_i, block in enumerate(blocks): - for layer_i in range(block.num_layers): - - name = 'block_%d/layer_%d' % (block_i, layer_i) + # Create the bottleneck groups, each of which contains `num_blocks` + # bottleneck groups. + for group_i, group in enumerate(groups): + for block_i in range(group.num_blocks): + name = 'group_%d/block_%d' % (group_i, block_i) # 1x1 convolution responsible for reducing dimension with tf.variable_scope(name + '/conv_in'): - conv = learn.ops.conv2d(net, block.bottleneck_size, + conv = learn.ops.conv2d(net, group.bottleneck_size, [1, 1], [1, 1, 1, 1], padding='VALID', activation=activation, @@ -98,7 +98,7 @@ def res_net(x, y, activation=tf.nn.relu): bias=False) with tf.variable_scope(name + '/conv_bottleneck'): - conv = learn.ops.conv2d(conv, block.bottleneck_size, + conv = learn.ops.conv2d(conv, group.bottleneck_size, [3, 3], [1, 1, 1, 1], padding='SAME', activation=activation, @@ -107,7 +107,8 @@ def res_net(x, y, activation=tf.nn.relu): # 1x1 convolution responsible for restoring dimension with tf.variable_scope(name + '/conv_out'): - conv = learn.ops.conv2d(conv, block.num_filters, + input_dim = net.get_shape()[-1].value + conv = learn.ops.conv2d(conv, input_dim, [1, 1], [1, 1, 1, 1], padding='VALID', activation=activation, @@ -118,16 +119,16 @@ def res_net(x, y, activation=tf.nn.relu): # residual function (identity shortcut) net = conv + net - try: - # upscale to the next block size - next_block = blocks[block_i + 1] - with tf.variable_scope('block_%d/conv_upscale' % block_i): - net = learn.ops.conv2d(net, next_block.num_filters, - [1, 1], [1, 1, 1, 1], - bias=False, - padding='SAME') - except IndexError: - pass + try: + # upscale to the next group size + next_group = groups[group_i + 1] + with tf.variable_scope('block_%d/conv_upscale' % group_i): + net = learn.ops.conv2d(net, next_group.num_filters, + [1, 1], [1, 1, 1, 1], + bias=False, + padding='SAME') + except IndexError: + pass net_shape = net.get_shape().as_list() net = tf.nn.avg_pool(net, @@ -139,18 +140,12 @@ def res_net(x, y, activation=tf.nn.relu): return learn.models.logistic_regression(net, y) - # Download and load MNIST data. mnist = input_data.read_data_sets('MNIST_data') # Restore model if graph is saved into a folder. if os.path.exists('models/resnet/graph.pbtxt'): classifier = learn.TensorFlowEstimator.restore('models/resnet/') -else: - # Create a new resnet classifier. - classifier = learn.TensorFlowEstimator( - model_fn=res_net, n_classes=10, batch_size=100, steps=100, - learning_rate=0.001, continue_training=True) while True: # Train model and save summaries into logdir. @@ -161,6 +156,3 @@ while True: score = metrics.accuracy_score( mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64)) print('Accuracy: {0:f}'.format(score)) - - # Save model graph and checkpoints. - classifier.save('models/resnet/') |