diff options
author | 2016-10-13 16:21:42 -0800 | |
---|---|---|
committer | 2016-10-13 17:34:11 -0700 | |
commit | e12d18cfe417ae610575a8c63a9ff6fd226d5888 (patch) | |
tree | 69db12b64297c77c021b5ee692ed090ac5370580 | |
parent | 1aff68cf7e8beddfd3b2ba9b1156d8def2138ed0 (diff) |
Disable partitioning of variable in batch_norm.
assign_moving_average is not supported for partitioned variables, and they're small anyway, so they don't need partitioning.
Change: 136107025
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 53 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers_test.py | 8 |
2 files changed, 38 insertions, 23 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 29c163877d..095de031b9 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -513,29 +513,36 @@ def batch_norm( trainable=trainable) # Create moving_mean and moving_variance variables and add them to the - # appropiate collections. - moving_mean_collections = utils.get_variable_collections( - variables_collections, 'moving_mean') - moving_mean_initializer = param_initializers.get('moving_mean', - init_ops.zeros_initializer) - moving_mean = variables.model_variable( - 'moving_mean', - shape=params_shape, - dtype=dtype, - initializer=moving_mean_initializer, - trainable=False, - collections=moving_mean_collections) - moving_variance_collections = utils.get_variable_collections( - variables_collections, 'moving_variance') - moving_variance_initializer = param_initializers.get( - 'moving_variance', init_ops.ones_initializer) - moving_variance = variables.model_variable( - 'moving_variance', - shape=params_shape, - dtype=dtype, - initializer=moving_variance_initializer, - trainable=False, - collections=moving_variance_collections) + # appropiate collections. We disable variable partitioning while creating + # them, because assign_moving_average is not yet supported for partitioned + # variables. 
+ partitioner = variable_scope.get_variable_scope().partitioner + try: + variable_scope.get_variable_scope().set_partitioner(None) + moving_mean_collections = utils.get_variable_collections( + variables_collections, 'moving_mean') + moving_mean_initializer = param_initializers.get( + 'moving_mean', init_ops.zeros_initializer) + moving_mean = variables.model_variable( + 'moving_mean', + shape=params_shape, + dtype=dtype, + initializer=moving_mean_initializer, + trainable=False, + collections=moving_mean_collections) + moving_variance_collections = utils.get_variable_collections( + variables_collections, 'moving_variance') + moving_variance_initializer = param_initializers.get( + 'moving_variance', init_ops.ones_initializer) + moving_variance = variables.model_variable( + 'moving_variance', + shape=params_shape, + dtype=dtype, + initializer=moving_variance_initializer, + trainable=False, + collections=moving_variance_collections) + finally: + variable_scope.get_variable_scope().set_partitioner(partitioner) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 91375f17af..3cc1acff48 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1550,6 +1550,14 @@ class BatchNormTest(tf.test.TestCase): self.assertAllClose(mean, expected_mean) self.assertAllClose(variance, expected_var) + def testEvalMovingVarsWithPartitioner(self): + # This test makes sure that the moving-mean and moving-variance logic works + # when `batch_norm` is called within a variable-scope that has a variable + # partitioner. 
+ partitioner = tf.fixed_size_partitioner(2, axis=0) + with tf.variable_scope(tf.get_variable_scope(), partitioner=partitioner): + self.testEvalMovingVars() + def _testReuseVars(self, fused): height, width = 3, 3 batch_size = 10 |