author    Nathan Silberman <nsilberman@google.com>    2016-11-22 10:27:54 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-11-22 10:43:35 -0800
commit    cff53408cb8bb5072ba5d5847a50b1cde31af82f (patch)
tree      3d156a129f0c9b33715b0018f8b00d12ad8a5f61
parent    824c9f11fcd1f3b1b341a5bb830d2b463c8fbb15 (diff)
Fixing a bug in which the 'trainable' argument wasn't being passed to the bias in convolution2d_transpose.
Change: 139925626
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py       24
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py  66
2 files changed, 53 insertions, 37 deletions
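
For context, a minimal sketch (not part of the commit) of the behavior being fixed, assuming a TensorFlow build contemporary with this change where tf.contrib.layers.conv2d_transpose and tf.contrib.framework.get_model_variables are available; the variable names 'Conv2d_transpose/weights' and 'Conv2d_transpose/biases' are the layer's default scope names:

    import tensorflow as tf

    with tf.Graph().as_default():
        # Build a transposed convolution with trainable=False.
        images = tf.random_uniform([5, 10, 12, 3], seed=1)
        tf.contrib.layers.conv2d_transpose(
            images, num_outputs=32, kernel_size=[3, 3], stride=1,
            trainable=False)

        model_names = [v.op.name
                       for v in tf.contrib.framework.get_model_variables()]
        trainable_names = [v.op.name for v in tf.trainable_variables()]

        # Before this change the bias variable was created without forwarding
        # the 'trainable' argument, so it always landed in the
        # TRAINABLE_VARIABLES collection; the weights already honored the flag.
        # With the fix, neither variable is trainable here.
        assert 'Conv2d_transpose/biases' in model_names
        assert 'Conv2d_transpose/biases' not in trainable_names

The new test testTrainableFlagIsPassedOn in layers_test.py (see the second file in the diff below) checks the same invariant for both trainable=True and trainable=False.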
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 123cdb1d97..d09b57840d 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -220,14 +220,16 @@ def _fused_batch_norm(
if original_rank is None:
raise ValueError('Inputs %s has undefined rank' % inputs.name)
elif original_rank not in [2, 4]:
- raise ValueError('Inputs %s has unsupported rank. \
- Expected 2 or 4 but got %d' % (inputs.name, original_rank))
+ raise ValueError('Inputs %s has unsupported rank.'
+ ' Expected 2 or 4 but got %d' % (
+ inputs.name, original_rank))
if original_rank == 2:
channels = inputs.get_shape()[-1].value
if channels is None:
raise ValueError('`C` dimension must be known but is None')
- new_shape = [-1, channels, 1, 1] if data_format == DATA_FORMAT_NCHW else \
- [-1, 1, 1, channels]
+ new_shape = [-1, 1, 1, channels]
+ if data_format == DATA_FORMAT_NCHW:
+ new_shape = [-1, channels, 1, 1]
inputs = array_ops.reshape(inputs, new_shape)
inputs_shape = inputs.get_shape()
dtype = inputs.dtype.base_dtype
@@ -316,7 +318,7 @@ def _fused_batch_norm(
need_updates = is_training_value is None or is_training_value
if need_updates:
if updates_collections is None:
- _no_updates = lambda: outputs
+ no_updates = lambda: outputs
def _force_updates():
"""Internal function forces updates moving_vars if is_training."""
update_moving_mean = moving_averages.assign_moving_average(
@@ -326,7 +328,7 @@ def _fused_batch_norm(
with ops.control_dependencies(
[update_moving_mean, update_moving_variance]):
return array_ops.identity(outputs)
- outputs = utils.smart_cond(is_training, _force_updates, _no_updates)
+ outputs = utils.smart_cond(is_training, _force_updates, no_updates)
else:
moving_vars_fn = lambda: (moving_mean, moving_variance)
def _delay_updates():
@@ -684,7 +686,7 @@ def bias_add(inputs,
raise ValueError('Dims of shape must be known but is None')
elif inputs_rank != 4 and data_format == DATA_FORMAT_NCHW:
raise ValueError('Data format NCHW only supports 4D Tensor')
- axis = 1 if data_format==DATA_FORMAT_NCHW else -1
+ axis = 1 if data_format == DATA_FORMAT_NCHW else -1
num_features = inputs_shape[axis].value
if num_features is None:
raise ValueError('`C` dimension must be known but is None')
@@ -1081,7 +1083,6 @@ def convolution2d_transpose(
output_shape = [batch_size, num_outputs, out_height, out_width]
strides = [1, 1, stride_h, stride_w]
-
output_shape = array_ops.pack(output_shape)
outputs = nn.conv2d_transpose(inputs, weights, output_shape,
strides,
@@ -1091,8 +1092,10 @@ def convolution2d_transpose(
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = num_outputs
- out_shape[h_axis] = get_deconv_dim(out_shape[h_axis], stride_h, kernel_h, padding)
- out_shape[w_axis] = get_deconv_dim(out_shape[w_axis], stride_w, kernel_w, padding)
+ out_shape[h_axis] = get_deconv_dim(
+ out_shape[h_axis], stride_h, kernel_h, padding)
+ out_shape[w_axis] = get_deconv_dim(
+ out_shape[w_axis], stride_w, kernel_w, padding)
outputs.set_shape(out_shape)
if normalizer_fn is not None:
@@ -1107,6 +1110,7 @@ def convolution2d_transpose(
dtype=dtype,
initializer=biases_initializer,
regularizer=biases_regularizer,
+ trainable=trainable,
collections=biases_collections)
outputs = nn.bias_add(outputs, biases, data_format=data_format)
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index fb96f19744..9f3e4e936d 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -23,9 +23,8 @@ import numpy as np
import tensorflow as tf
# TODO(sguada) Expose tf.with_dependencies
-from tensorflow.python.ops import control_flow_ops
from tensorflow.contrib.layers.python.layers import layers as _layers
-from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import control_flow_ops
class AvgPool2DTest(tf.test.TestCase):
@@ -181,7 +180,8 @@ class PoolTest(tf.test.TestCase):
height, width = 5, 8
images = tf.random_uniform((5, 3, height, width), seed=1)
output = tf.contrib.layers.pool(
- images, [2, 3], dilation_rate=[1, 2], pooling_type='AVG', data_format='NCHW')
+ images, [2, 3], dilation_rate=[1, 2], pooling_type='AVG',
+ data_format='NCHW')
self.assertEqual(output.get_shape().as_list(), [5, 3, 4, 4])
@@ -370,7 +370,7 @@ class ConvolutionTest(tf.test.TestCase):
tf.contrib.framework.get_variables_by_name('weights')[0])
wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEqual(wd.op.name,
- 'Conv/weights/Regularizer/l2_regularizer')
+ 'Conv/weights/Regularizer/l2_regularizer')
sess.run(tf.global_variables_initializer())
self.assertAlmostEqual(sess.run(wd), weight_decay * l2_loss.eval())
@@ -588,6 +588,20 @@ class ConvolutionTest(tf.test.TestCase):
class Convolution2dTransposeTests(tf.test.TestCase):
+ def testTrainableFlagIsPassedOn(self):
+ for trainable in [True, False]:
+ with tf.Graph().as_default():
+ num_filters = 32
+ input_size = [5, 10, 12, 3]
+
+ images = tf.random_uniform(input_size, seed=1)
+ tf.contrib.layers.conv2d_transpose(
+ images, num_filters, [3, 3], stride=1, trainable=trainable)
+ model_variables = tf.contrib.framework.get_model_variables()
+ trainable_variables = tf.trainable_variables()
+ for model_variable in model_variables:
+ self.assertEqual(trainable, model_variable in trainable_variables)
+
def testInvalidDataFormat(self):
height, width = 7, 9
with self.test_session():
@@ -597,7 +611,6 @@ class Convolution2dTransposeTests(tf.test.TestCase):
tf.contrib.layers.convolution2d_transpose(
images, 32, 3, data_format='CHWN')
-
def testOutputSizeWithStrideOneSamePaddingNCHW(self):
# `NCHW` data format is only supported for `GPU` device.
if tf.test.is_gpu_available():
@@ -615,7 +628,6 @@ class Convolution2dTransposeTests(tf.test.TestCase):
sess.run(tf.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
-
def testOutputSizeWithStrideOneValidPaddingNCHW(self):
if tf.test.is_gpu_available():
with self.test_session(use_gpu=True) as sess:
@@ -756,7 +768,6 @@ class Convolution2dTransposeTests(tf.test.TestCase):
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
-
def testOutputSizeWithStrideOneSamePadding(self):
num_filters = 32
input_size = [5, 10, 12, 3]
@@ -1284,7 +1295,7 @@ class FlattenTest(tf.test.TestCase):
images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
output = tf.contrib.layers.flatten(images)
self.assertEqual(output.get_shape().num_elements(),
- images.get_shape().num_elements())
+ images.get_shape().num_elements())
self.assertEqual(output.get_shape()[0], images.get_shape()[0])
def testFlatten3D(self):
@@ -1293,7 +1304,7 @@ class FlattenTest(tf.test.TestCase):
images = tf.random_uniform((5, height, width), seed=1, name='images')
output = tf.contrib.layers.flatten(images)
self.assertEqual(output.get_shape().num_elements(),
- images.get_shape().num_elements())
+ images.get_shape().num_elements())
self.assertEqual(output.get_shape()[0], images.get_shape()[0])
def testFlattenBatchSize(self):
@@ -1303,10 +1314,10 @@ class FlattenTest(tf.test.TestCase):
inputs = tf.placeholder(tf.int32, (None, height, width, 3))
output = tf.contrib.layers.flatten(inputs)
self.assertEqual(output.get_shape().as_list(),
- [None, height * width * 3])
+ [None, height * width * 3])
output = sess.run(output, {inputs: images.eval()})
self.assertEqual(output.size,
- images.get_shape().num_elements())
+ images.get_shape().num_elements())
self.assertEqual(output.shape[0], images.get_shape()[0])
@@ -1463,7 +1474,7 @@ class FCTest(tf.test.TestCase):
weights_regularizer=weight_decay)
wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
self.assertEqual(wd.op.name,
- 'fully_connected/weights/Regularizer/l2_regularizer')
+ 'fully_connected/weights/Regularizer/l2_regularizer')
sess.run(tf.global_variables_initializer())
self.assertLess(sess.run(wd), 0.4)
@@ -1620,9 +1631,9 @@ class BatchNormTest(tf.test.TestCase):
update_moving_mean = update_layers[0]
update_moving_variance = update_layers[1]
self.assertEqual(update_moving_mean.op.name,
- 'BatchNorm/AssignMovingAvg')
+ 'BatchNorm/AssignMovingAvg')
self.assertEqual(update_moving_variance.op.name,
- 'BatchNorm/AssignMovingAvg_1')
+ 'BatchNorm/AssignMovingAvg_1')
def testReuseVariables(self):
height, width = 3, 3
@@ -1774,8 +1785,8 @@ class BatchNormTest(tf.test.TestCase):
if fused:
# Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor
- correct_moving_variance = state_ops.assign(moving_variance,
- moving_variance_corrected)
+ correct_moving_variance = tf.assign(moving_variance,
+ moving_variance_corrected)
sess.run(correct_moving_variance)
self.assertAllClose(variance, expected_var)
@@ -1888,8 +1899,8 @@ class BatchNormTest(tf.test.TestCase):
if fused:
# Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor
- correct_moving_variance = state_ops.assign(moving_variance,
- moving_variance_corrected)
+ correct_moving_variance = tf.assign(moving_variance,
+ moving_variance_corrected)
sess.run(correct_moving_variance)
self.assertAllClose(variance, expected_var)
# After convergence output_train and output_eval should be the same.
@@ -1961,8 +1972,8 @@ class BatchNormTest(tf.test.TestCase):
if fused:
# Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor
- correct_moving_variance = state_ops.assign(moving_variance,
- moving_variance_corrected)
+ correct_moving_variance = tf.assign(moving_variance,
+ moving_variance_corrected)
sess.run(correct_moving_variance)
output_false = sess.run([output], {is_training: False})
self.assertAllClose(output_true, output_false)
@@ -2100,8 +2111,8 @@ class BatchNormTest(tf.test.TestCase):
if fused:
# Add Bessel's correction
moving_variance_corrected = moving_variance / correction_factor
- correct_moving_variance = state_ops.assign(moving_variance,
- moving_variance_corrected)
+ correct_moving_variance = tf.assign(moving_variance,
+ moving_variance_corrected)
sess.run(correct_moving_variance)
output_false = sess.run([output], {is_training: False})
self.assertTrue(np.allclose(output_true, output_false))
@@ -2212,10 +2223,10 @@ class BatchNormTest(tf.test.TestCase):
scale=True,
epsilon=0.0,
param_initializers={
- 'beta': beta,
- 'gamma': gamma,
- 'moving_mean': mean,
- 'moving_variance': variance,
+ 'beta': beta,
+ 'gamma': gamma,
+ 'moving_mean': mean,
+ 'moving_variance': variance,
})
sess.run(tf.global_variables_initializer())
outs = sess.run(output)
@@ -2358,6 +2369,7 @@ class LayerNormTest(tf.test.TestCase):
def testOutput4DInput(self):
self.doOutputTest((100, 10, 10, 3))
+
class MaxPool2DTest(tf.test.TestCase):
def testInvalidDataFormat(self):
@@ -2974,7 +2986,7 @@ class LegacyFullyConnectedTest(tf.test.TestCase):
self.assertEqual(1, len(tf.get_collection('unbiased')))
self.assertEqual(1, len(tf.get_collection('biased')))
self.assertEqual(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
- tf.get_collection('all'))
+ tf.get_collection('all'))
def test_no_bias(self):
tf.contrib.layers.legacy_relu(self.input, 2, bias_init=None)