 tensorflow/contrib/layers/python/layers/layers.py      | 20 ++++++++++++++++++--
 tensorflow/contrib/layers/python/layers/layers_test.py | 20 ++++++++++++++++++++
 2 files changed, 38 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 29ab281b1a..deeafdf300 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -463,7 +463,8 @@ def batch_norm(inputs,
                scope=None,
                renorm=False,
                renorm_clipping=None,
-               renorm_decay=0.99):
+               renorm_decay=0.99,
+               adjustment=None):
   """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
 
   "Batch Normalization: Accelerating Deep Network Training by Reducing
@@ -546,6 +547,17 @@ def batch_norm(inputs,
       and should be neither too small (which would add noise) nor too large
       (which would give stale estimates). Note that `decay` is still applied
       to get the means and variances for inference.
+    adjustment: A function taking the `Tensor` containing the (dynamic) shape of
+      the input tensor and returning a pair (scale, bias) to apply to the
+      normalized values (before gamma and beta), only during training. For
+      example,
+        `adjustment = lambda shape: (
+          tf.random_uniform(shape[-1:], 0.93, 1.07),
+          tf.random_uniform(shape[-1:], -0.1, 0.1))`
+      will scale the normalized value by up to 7% up or down, then shift the
+      result by up to 0.1 (with independent scaling and bias for each feature
+      but shared across all examples), and finally apply gamma and/or beta. If
+      `None`, no adjustment is applied.
 
   Returns:
     A `Tensor` representing the output of the operation.
@@ -569,7 +581,10 @@ def batch_norm(inputs,
   #   implementation in normalization_layers.BatchNormalization.
   inputs = ops.convert_to_tensor(inputs)
   rank = inputs.get_shape().ndims
-  possible_to_fuse = batch_weights is None and not renorm and rank in [2, 4]
+  possible_to_fuse = (batch_weights is None and
+                      not renorm and
+                      rank in [2, 4] and
+                      adjustment is None)
   if fused and possible_to_fuse and (
       zero_debias_moving_mean or rank == 2 or
       updates_collections is not ops.GraphKeys.UPDATE_OPS):
@@ -636,6 +651,7 @@ def batch_norm(inputs,
       renorm=renorm,
       renorm_clipping=renorm_clipping,
       renorm_momentum=renorm_decay,
+      adjustment=adjustment,
       name=sc.name,
       _scope=sc,
       _reuse=reuse,
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 1040ad3ca7..7c77e905f7 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -2644,6 +2644,26 @@ class BatchNormTest(test.TestCase):
         zero_debias_moving_mean=True)
     sess.run(variables_lib.global_variables_initializer())
 
+  def testAdjustmentCreated(self):
+    # Tests that the adjustment is appropriately passed to and used by the core
+    # BN layer.
+    all_adjustments = []
+    def _create_adjustment(shape):
+      adjustments = [array_ops.ones(shape[-1:]), array_ops.zeros(shape[-1:])]
+      all_adjustments.extend(adjustments)
+      return adjustments
+    depth = 8
+    images = array_ops.zeros([10, 5, 5, depth])
+    output = _layers.batch_norm(
+        images,
+        is_training=True,
+        adjustment=_create_adjustment)
+    self.assertListEqual(output.shape.as_list(), images.shape.as_list())
+    self.assertEqual(len(all_adjustments), 2)
+    self.assertListEqual(all_adjustments[0].shape.as_list(), [depth])
+    self.assertListEqual(all_adjustments[1].shape.as_list(), [depth])
+
+
 class LayerNormTest(test.TestCase):
 
   def testUnknownShape(self):
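
For context, a minimal usage sketch (not part of the commit) of the new `adjustment` argument from the contrib API, following the `tf.random_uniform` example in the docstring above. It assumes a TF 1.x graph-mode setup; the placeholder shape and perturbation ranges are illustrative only:

    import tensorflow as tf

    def _random_adjustment(shape):
      # Per-feature random scale in [0.93, 1.07] and bias in [-0.1, 0.1],
      # shared across all examples in the batch (shape[-1:] is the feature
      # axis, so one scale and one bias per channel).
      scale = tf.random_uniform(shape[-1:], 0.93, 1.07)
      bias = tf.random_uniform(shape[-1:], -0.1, 0.1)
      return scale, bias

    # NHWC feature map; batch_norm normalizes over the last axis by default.
    images = tf.placeholder(tf.float32, [None, 28, 28, 64])
    net = tf.contrib.layers.batch_norm(
        images,
        is_training=True,  # the adjustment is only applied during training
        adjustment=_random_adjustment)

Note that, per the `possible_to_fuse` change above, a non-`None` `adjustment` disables the fused kernel path, so such a layer always routes through the unfused `normalization_layers.BatchNormalization` implementation.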
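To make the ordering concrete, here is my reading of the training-time computation as a self-contained NumPy sketch (an illustration of the docstring, not code from this commit; `eps`, `gamma`, and `beta` are the usual batch-norm parameters, and `adj_scale`/`adj_bias` are the pair returned by the `adjustment` function):

    import numpy as np

    def batch_norm_with_adjustment(x, gamma, beta, adj_scale, adj_bias,
                                   eps=1e-3):
      # Normalize over all axes except the last (the NHWC feature axis).
      axes = tuple(range(x.ndim - 1))
      mean = x.mean(axis=axes)
      var = x.var(axis=axes)
      normalized = (x - mean) / np.sqrt(var + eps)
      # The adjustment scales and shifts the *normalized* values before
      # gamma and beta are applied, matching "(before gamma and beta)".
      return (normalized * adj_scale + adj_bias) * gamma + beta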