author     A. Unique TensorFlower <gardener@tensorflow.org>   2016-09-29 14:27:41 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2016-09-29 15:39:40 -0700
commit     ef9f5fee0a079f6bed445064e8e9d18fb7a904d8
tree       b404a7c55a7e3195c1de2f9719520be3b89a23fe
parent     bc0a56da15eed8738e8a53e2dd340030332df28a
Add weighted_moments, and allow batch norm to use it to compute frequency-weighted statistics.
Change: 134717043
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py |  24
-rw-r--r--  tensorflow/python/ops/nn.py                       |  79
-rw-r--r--  tensorflow/python/ops/nn_batchnorm_test.py        | 139
3 files changed, 220 insertions, 22 deletions
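
As a quick illustration of what this commit enables (a hypothetical usage sketch, not taken from the patch; the shapes and placeholder names are made up, and the call assumes the 2016-era tf.contrib.layers API):

import tensorflow as tf

# Hypothetical inputs: a batch of 32 feature maps plus one frequency
# weight per batch item (e.g. inverse sampling probabilities).
inputs = tf.placeholder(tf.float32, [32, 28, 28, 64])
weights = tf.placeholder(tf.float32, [32])

# With the new batch_weights argument, the layer computes weighted
# moments instead of plain nn.moments while training.
normalized = tf.contrib.layers.batch_norm(inputs,
                                          is_training=True,
                                          batch_weights=weights)
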
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 0ed396e453..dc4ee9226a 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -123,6 +123,7 @@ def batch_norm(inputs,
variables_collections=None,
outputs_collections=None,
trainable=True,
+ batch_weights=None,
scope=None):
"""Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
@@ -171,6 +172,11 @@ def batch_norm(inputs,
outputs_collections: collections to add the outputs.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ batch_weights: An optional tensor of shape `[batch_size]`,
+ containing a frequency weight for each batch item. If present,
+ then the batch normalization uses weighted mean and
+ variance. (This can be used to correct for bias in training
+ example selection.)
scope: Optional scope for `variable_scope`.
Returns:
@@ -187,6 +193,14 @@ def batch_norm(inputs,
if inputs_rank is None:
raise ValueError('Inputs %s has undefined rank.' % inputs.name)
dtype = inputs.dtype.base_dtype
+ if batch_weights is not None:
+ batch_weights = ops.convert_to_tensor(batch_weights)
+ inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
+
+ # Reshape batch weight values so they broadcast across inputs.
+ nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
+ batch_weights = array_ops.reshape(batch_weights, nshape)
+
axis = list(range(inputs_rank - 1))
params_shape = inputs_shape[-1:]
if not params_shape.is_fully_defined():
@@ -240,9 +254,13 @@ def batch_norm(inputs,
need_moments = is_training_value is None or is_training_value
if need_moments:
# Calculate the moments based on the individual batch.
- # Use a copy of moving_mean as a shift to compute more reliable moments.
- shift = math_ops.add(moving_mean, 0)
- mean, variance = nn.moments(inputs, axis, shift=shift)
+ if batch_weights is None:
+ # Use a copy of moving_mean as a shift to compute more reliable moments.
+ shift = math_ops.add(moving_mean, 0)
+ mean, variance = nn.moments(inputs, axis, shift=shift)
+ else:
+ mean, variance = nn.weighted_moments(inputs, axis, batch_weights)
+
moving_vars_fn = lambda: (moving_mean, moving_variance)
if updates_collections is None:
def _force_updates():
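
Aside on the reshape in the hunk above: batch_weights has shape [batch_size], so it is reshaped to [-1, 1, ..., 1] purely so that broadcasting lines it up with every non-batch axis of the inputs. A minimal NumPy sketch of the same idea, with made-up shapes rather than anything from the patch:

import numpy as np

x = np.random.normal(size=[4, 2, 3])   # [batch, height, channels]
w = np.array([1.0, 2.0, 1.0, 4.0])     # one frequency weight per batch item

# Same trick as layers.py: -1 for the batch axis, then a 1 for every
# remaining axis, so w broadcasts element-wise against x.
nshape = [-1] + [1 for _ in range(x.ndim - 1)]   # -> [-1, 1, 1]
weighted = w.reshape(nshape) * x                 # broadcasts to x.shape
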
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 958c32f0fc..992e0f6f79 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -188,6 +188,7 @@ have varying scale, and to aid generalization.
@@sufficient_statistics
@@normalize_moments
@@moments
+@@weighted_moments
## Losses
@@ -819,7 +820,7 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
Args:
x: A `Tensor`.
- axes: array of ints. Axes along which to compute mean and
+ axes: Array of ints. Axes along which to compute mean and
variance.
shift: A `Tensor` containing the value by which to shift the data for
numerical stability, or `None` if no shift is to be performed. A shift
@@ -848,6 +849,82 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
return (mean, variance)
+def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
+ """Returns the frequency-weighted mean and variance of `x`.
+
+ Args:
+ x: A tensor.
+ axes: 1-d tensor of int32 values; these are the axes along which
+ to compute mean and variance.
+ frequency_weights: A tensor of positive weights which can be
+ broadcast with x.
+ name: Name used to scope the operation.
+ keep_dims: Produce moments with the same dimensionality as the input.
+
+ Returns:
+ Two tensors: `weighted_mean` and `weighted_variance`.
+ """
+ with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
+ x = ops.convert_to_tensor(x, name="x")
+ frequency_weights = ops.convert_to_tensor(
+ frequency_weights, name="frequency_weights")
+
+ # Unlike moments(), this just uses a simpler two-pass method.
+
+ # See comment in moments() WRT precision; it applies here too.
+ needs_cast = x.dtype == dtypes.float16
+ if needs_cast:
+ x = math_ops.cast(x, dtypes.float32)
+
+ if frequency_weights.dtype != x.dtype:
+ frequency_weights = math_ops.cast(frequency_weights, x.dtype)
+
+ # Note that we use keep_dims=True for our reductions regardless of the arg;
+ # this is so that the results remain broadcast-compatible with the inputs.
+ weighted_input_sum = math_ops.reduce_sum(frequency_weights * x,
+ axes,
+ name="weighted_input_sum",
+ keep_dims=True)
+
+ # The shape of the weights isn't necessarily the same as x's
+ # shape, just broadcast-compatible with it -- so this expression
+ # performs broadcasting to give a per-item weight, with the same
+ # shape as (frequency_weights * x). This avoids having to reason
+ # through all the broadcast logic to compute a correct
+ # sum_of_weights.
+ broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
+
+ sum_of_weights = math_ops.reduce_sum(
+ broadcasted_weights,
+ axes,
+ name="sum_of_weights",
+ keep_dims=True)
+
+ divisor = math_ops.inv(sum_of_weights, name="inv_weight_sum")
+
+ weighted_mean = math_ops.mul(weighted_input_sum, divisor)
+
+ # Have the weighted mean; now on to variance:
+ weighted_distsq = math_ops.reduce_sum(
+ frequency_weights * math_ops.squared_difference(x, weighted_mean),
+ axes,
+ name="weighted_distsq",
+ keep_dims=True)
+
+ weighted_variance = math_ops.mul(weighted_distsq, divisor)
+
+ if not keep_dims:
+ weighted_mean = array_ops.squeeze(weighted_mean, squeeze_dims=axes)
+ weighted_variance = array_ops.squeeze(weighted_variance,
+ squeeze_dims=axes)
+
+ if needs_cast:
+ weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
+ weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
+
+ return weighted_mean, weighted_variance
+
+
def batch_normalization(x,
mean,
variance,
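
For reference, a small end-to-end check of what weighted_moments computes: sum(w * x) / sum(w) for the mean and sum(w * (x - mean)^2) / sum(w) for the variance. This is a hedged sketch with made-up shapes, written against the 2016-era session API rather than copied from the patch:

import numpy as np
import tensorflow as tf

x = np.random.normal(size=[8, 3]).astype(np.float32)
w = np.abs(np.random.normal(loc=1.0, size=[8, 1])).astype(np.float32)

# NumPy version of the two-pass formulas used in weighted_moments.
sum_w = (w * np.ones_like(x)).sum(axis=0)
np_mean = (w * x).sum(axis=0) / sum_w
np_var = (w * (x - np_mean) ** 2).sum(axis=0) / sum_w

with tf.Session() as sess:
    tf_mean, tf_var = sess.run(tf.nn.weighted_moments(x, [0], w))
# tf_mean and tf_var should match np_mean and np_var up to float32 noise.
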
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index 9ccf331c48..5e928fba56 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -420,6 +420,16 @@ class NormalizeMomentsTest(tf.test.TestCase):
class MomentsTest(tf.test.TestCase):
+ def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+ # Method to compute moments of `x` wrt `axes`.
+ #
+ # This is exposed so WeightedMomentsTest can inherit the tests and
+ # assertions from MomentsTest; the extra_out_grads argument allows
+ # its inherited gradient tests to assert gradients against the
+ # weights as well as the input values.
+
+ return tf.nn.moments(x, axes, keep_dims=keep_dims)
+
def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
with self.test_session():
# shape = [batch, width, height, depth]
@@ -428,7 +438,7 @@ class MomentsTest(tf.test.TestCase):
x_numpy = np.random.normal(size=shape).astype(np.float32)
x = tf.placeholder(dtype, shape=[None] * len(shape))
- mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+ mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
num_elements = np.prod([shape[i] for i in axes])
@@ -456,7 +466,11 @@ class MomentsTest(tf.test.TestCase):
x_numpy = np.random.normal(size=shape).astype(np.float32)
x = tf.cast(tf.constant(x_numpy), dtype=dtype)
- mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)
+ # Compute the expected values at high precision since the method
+ # is prone to catastrophic cancellation:
+ x_numpy = x_numpy.astype(np.float128)
+
+ mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)
num_elements = np.prod([shape[i] for i in axes])
@@ -519,14 +533,21 @@ class MomentsTest(tf.test.TestCase):
axes = [0, 1, 2]
y_shape = [2] # Depth of x
- out_mean, out_var = tf.nn.moments(x, axes)
+
+ inputs_to_compute_gradients_for = [x]
+
+ out_mean, out_var = self._unweighted_moments(
+ x, axes, extra_out_grads=inputs_to_compute_gradients_for)
if from_y == "mean":
y = out_mean
elif from_y == "var":
y = out_var
- err = tf.test.compute_gradient_error(x, x_shape, y, y_shape)
- print("Moments %s gradient err = %g" % (from_y, err))
- self.assertLess(err, 1e-11)
+
+ for (i, v) in enumerate(inputs_to_compute_gradients_for):
+ err = tf.test.compute_gradient_error(v, v.get_shape().as_list(),
+ y, y_shape)
+ print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
+ self.assertLess(err, 1e-11)
def testMeanGlobalGradient(self):
self._testGlobalGradient(from_y="mean")
@@ -534,19 +555,101 @@ class MomentsTest(tf.test.TestCase):
def testVarGlobalGradient(self):
self._testGlobalGradient(from_y="var")
- def testOutputNamesNoKeep(self):
- """Make sure the output names are stable."""
- with self.test_session():
- mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
- self.assertEquals(mean.op.name, "moments/normalize/mean")
- self.assertEquals(var.op.name, "moments/normalize/variance")
- def testOutputNamesKeep(self):
- """Make sure the output names are stable."""
- with self.test_session():
- mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
- self.assertEquals(mean.op.name, "moments/normalize/mean")
- self.assertEquals(var.op.name, "moments/normalize/variance")
+class WeightedMomentsTest(MomentsTest):
+ """Tests for nn.weighted_moments.
+
+ Note that this test inherits from MomentsTest, inheriting all its
+ test methods!
+
+ It modifies MomentsTest in two ways:
+
+ a) By overriding _unweighted_moments, all the codepaths in
+ MomentsTest are executed, but with calls to tf.nn.moments()
+ replaced by calls to tf.nn.weighted_moments() with a constant
+ weight of 1.
+
+ b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
+ this test adds multiple additional calls to
+ RunWeightedMomentTest() to exercise correctness with
+ non-constant weights and varying broadcasting situations. (It
+ also continues to call the inherited MomentsTest.RunMomentTest.)
+
+ """
+
+ def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
+ weights = tf.constant(1, dtype=x.dtype)
+ if extra_out_grads is not None:
+ # We want to assert gradients WRT weights as well as X!
+ extra_out_grads.append(weights)
+ return tf.nn.weighted_moments(
+ x, axes, weights, keep_dims=keep_dims)
+
+ def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
+ if not dynshapes:
+ super(WeightedMomentsTest, self).RunMomentTest(
+ shape, axes, keep_dims, dtype)
+ else:
+ super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(
+ shape, axes, keep_dims, dtype)
+
+ # 1:1 weights and inputs
+ self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)
+
+ # Various broadcasting combinations
+ for idx in range(len(shape)):
+ # try broadcasting weights in all positions
+ weight_shape = [1] * len(shape)
+ weight_shape[idx] = shape[idx]
+
+ self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)
+
+ # Also try broadcasting with a suffix of length n
+ weight_shape = shape[-(idx+1):]
+ self.RunWeightedMomentTest(
+ shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)
+
+ def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
+ self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)
+
+ def RunWeightedMomentTest(
+ self, shape, weights_shape, axes, keep_dims, dtype, dynshapes=False):
+ with self.test_session() as s:
+ x_numpy = np.random.normal(size=shape).astype(np.float32)
+ weights_numpy = np.absolute( # weights must be positive
+ np.random.normal(size=weights_shape, loc=1.0).astype(np.float32))
+
+ # Expand the numpy version to higher precision
+ x_numpy = x_numpy.astype(np.float128)
+ weights_numpy = weights_numpy.astype(np.float128)
+
+ x_shape = [None] * len(shape) if dynshapes else shape
+ weights_shape = (
+ [None] * len(weights_shape) if dynshapes else weights_shape)
+
+ x = tf.placeholder(dtype, shape=x_shape)
+ weights = tf.placeholder(dtype, shape=weights_shape)
+
+ mean, var = tf.nn.weighted_moments(x, axes, weights, keep_dims=keep_dims)
+
+ ax = tuple(axes)
+
+ def _np_weighted_sum(v):
+ return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)
+
+ weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
+ expected_mean = _np_weighted_sum(x_numpy) / weight_sum
+ expected_mean_squared = np.multiply(expected_mean, expected_mean)
+ expected_x_squared = (
+ _np_weighted_sum(np.multiply(x_numpy, x_numpy)) / weight_sum)
+ expected_variance = expected_x_squared - expected_mean_squared
+
+ mean_v, var_v = s.run([mean, var],
+ feed_dict={x: x_numpy, weights: weights_numpy})
+
+ self.assertAllCloseAccordingToType(expected_mean, mean_v)
+ self.assertAllCloseAccordingToType(expected_variance, var_v)
if __name__ == "__main__":