aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-13 16:21:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-13 17:34:11 -0700
commite12d18cfe417ae610575a8c63a9ff6fd226d5888 (patch)
tree69db12b64297c77c021b5ee692ed090ac5370580
parent1aff68cf7e8beddfd3b2ba9b1156d8def2138ed0 (diff)
Disable partitioning of variable in batch_norm.
assign_moving_average is not supported for partitioned variables, and they are small anyway, so they don't need partitioning. Change: 136107025
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py53
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py8
2 files changed, 38 insertions, 23 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 29c163877d..095de031b9 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -513,29 +513,36 @@ def batch_norm(
trainable=trainable)
# Create moving_mean and moving_variance variables and add them to the
- # appropiate collections.
- moving_mean_collections = utils.get_variable_collections(
- variables_collections, 'moving_mean')
- moving_mean_initializer = param_initializers.get('moving_mean',
- init_ops.zeros_initializer)
- moving_mean = variables.model_variable(
- 'moving_mean',
- shape=params_shape,
- dtype=dtype,
- initializer=moving_mean_initializer,
- trainable=False,
- collections=moving_mean_collections)
- moving_variance_collections = utils.get_variable_collections(
- variables_collections, 'moving_variance')
- moving_variance_initializer = param_initializers.get(
- 'moving_variance', init_ops.ones_initializer)
- moving_variance = variables.model_variable(
- 'moving_variance',
- shape=params_shape,
- dtype=dtype,
- initializer=moving_variance_initializer,
- trainable=False,
- collections=moving_variance_collections)
+  # appropriate collections. We disable variable partitioning while creating
+ # them, because assign_moving_average is not yet supported for partitioned
+ # variables.
+ partitioner = variable_scope.get_variable_scope().partitioner
+ try:
+ variable_scope.get_variable_scope().set_partitioner(None)
+ moving_mean_collections = utils.get_variable_collections(
+ variables_collections, 'moving_mean')
+ moving_mean_initializer = param_initializers.get(
+ 'moving_mean', init_ops.zeros_initializer)
+ moving_mean = variables.model_variable(
+ 'moving_mean',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=moving_mean_initializer,
+ trainable=False,
+ collections=moving_mean_collections)
+ moving_variance_collections = utils.get_variable_collections(
+ variables_collections, 'moving_variance')
+ moving_variance_initializer = param_initializers.get(
+ 'moving_variance', init_ops.ones_initializer)
+ moving_variance = variables.model_variable(
+ 'moving_variance',
+ shape=params_shape,
+ dtype=dtype,
+ initializer=moving_variance_initializer,
+ trainable=False,
+ collections=moving_variance_collections)
+ finally:
+ variable_scope.get_variable_scope().set_partitioner(partitioner)
# If `is_training` doesn't have a constant value, because it is a `Tensor`,
# a `Variable` or `Placeholder` then is_training_value will be None and
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 91375f17af..3cc1acff48 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1550,6 +1550,14 @@ class BatchNormTest(tf.test.TestCase):
self.assertAllClose(mean, expected_mean)
self.assertAllClose(variance, expected_var)
+ def testEvalMovingVarsWithPartitioner(self):
+ # This test makes sure that the moving-mean and moving-variance logic works
+ # when `batch_norm` is called within a variable-scope that has a variable
+ # partitioner.
+ partitioner = tf.fixed_size_partitioner(2, axis=0)
+ with tf.variable_scope(tf.get_variable_scope(), partitioner=partitioner):
+ self.testEvalMovingVars()
+
def _testReuseVars(self, fused):
height, width = 3, 3
batch_size = 10