Diffstat (limited to 'tensorflow/examples/learn/resnet.py')
 tensorflow/examples/learn/resnet.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)
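This commit threads an explicit `training` flag, derived from the Estimator `mode`, into every `tf.layers.batch_normalization` call in the ResNet example, so batch statistics are used during training and the learned moving averages during evaluation and prediction. For context, a minimal sketch of how the Estimator supplies `mode` to the model function (the input functions here are hypothetical placeholders, not part of this change):

    # Sketch, assuming hypothetical train_input_fn / eval_input_fn.
    classifier = tf.estimator.Estimator(model_fn=res_net_model)
    classifier.train(input_fn=train_input_fn, steps=100)  # mode == ModeKeys.TRAIN
    classifier.evaluate(input_fn=eval_input_fn)           # mode == ModeKeys.EVAL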
diff --git a/tensorflow/examples/learn/resnet.py b/tensorflow/examples/learn/resnet.py
index 9542e55250..c00de932a8 100755
--- a/tensorflow/examples/learn/resnet.py
+++ b/tensorflow/examples/learn/resnet.py
@@ -53,6 +53,8 @@ def res_net_model(features, labels, mode):
   ndim = int(sqrt(input_shape[1]))
   x = tf.reshape(x, [-1, ndim, ndim, 1])
 
+  training = (mode == tf.estimator.ModeKeys.TRAIN)
+
   # First convolution expands to 64 channels
   with tf.variable_scope('conv_layer1'):
     net = tf.layers.conv2d(
@@ -60,7 +62,7 @@ def res_net_model(features, labels, mode):
         filters=64,
         kernel_size=7,
         activation=tf.nn.relu)
-    net = tf.layers.batch_normalization(net)
+    net = tf.layers.batch_normalization(net, training=training)
 
   # Max pool
   net = tf.layers.max_pooling2d(
@@ -88,7 +90,7 @@ def res_net_model(features, labels, mode):
           kernel_size=1,
           padding='valid',
           activation=tf.nn.relu)
-        conv = tf.layers.batch_normalization(conv)
+        conv = tf.layers.batch_normalization(conv, training=training)
 
       with tf.variable_scope(name + '/conv_bottleneck'):
         conv = tf.layers.conv2d(
@@ -97,7 +99,7 @@ def res_net_model(features, labels, mode):
           kernel_size=3,
           padding='same',
           activation=tf.nn.relu)
-        conv = tf.layers.batch_normalization(conv)
+        conv = tf.layers.batch_normalization(conv, training=training)
 
       # 1x1 convolution responsible for restoring dimension
       with tf.variable_scope(name + '/conv_out'):
@@ -108,7 +110,7 @@ def res_net_model(features, labels, mode):
           kernel_size=1,
           padding='valid',
           activation=tf.nn.relu)
-        conv = tf.layers.batch_normalization(conv)
+        conv = tf.layers.batch_normalization(conv, training=training)
 
       # shortcut connections that turn the network into its counterpart
       # residual function (identity shortcut)
@@ -154,7 +156,7 @@ def res_net_model(features, labels, mode):
   loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
 
   # Create training op.
-  if mode == tf.estimator.ModeKeys.TRAIN:
+  if training:
     optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
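A caveat on this change: with `training=True`, `tf.layers.batch_normalization` adds the updates for its moving mean and variance to the `tf.GraphKeys.UPDATE_OPS` collection, and `Optimizer.minimize` does not run those ops on its own. The patch as shown does not wire them in; a minimal sketch of the usual fix inside the training branch, under the same TF1 `tf.layers` API:

    # Run the batch-norm moving-statistic updates with each train step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(
          loss, global_step=tf.train.get_global_step())

Without this dependency the moving averages are never updated, so evaluation and prediction (which rely on them once `training` is False) would see stale statistics.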