aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vincent Vanhoucke <vanhoucke@google.com>2016-02-24 16:27:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-25 09:01:18 -0800
commitbce6216610d57f8f4b1e9e79836737df109c4e42 (patch)
tree05e502a0ec55c4a275e657a458b32105db6e6040
parent2cc5ed87e3308f37cd1eaacc58992367f679e69f (diff)
Switch nn.moments() to using a one-pass stable algorithm.
Helps with: https://github.com/tensorflow/tensorflow/issues/917 Also fixes https://github.com/tensorflow/tensorflow/issues/1162 The main benefit is that the computation of the sufficient statistics is now decoupled of the aggregation of the moments, which means that if you want to perform the accumulation incrementally, you don't have to keep all the inputs around, and can instead keep the much more compact sum and sum-of-squares. Accumulation could also be performed locally if you aggregate across multiple devices. Computing sum and sum-of-squares can also theoretically be performed in parallel now. Tested running inception: same performance, same step time. Batch normalization benchmark is a bit faster on CPU, a bit slower on GPU: Before: cpu shape:4/3 #layers:10 mode:py scale:True train:False - 1.139310 secs gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.021970 secs cpu shape:4/3 #layers:10 mode:py scale:True train:True - 2.767147 secs gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.074531 secs cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.742835 secs gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.013473 secs cpu shape:4/3 #layers:10 mode:py scale:True train:True - 1.738806 secs gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.052777 secs cpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.119180 secs gpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.011201 secs cpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.218297 secs gpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.048526 secs After: cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.998944 secs gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.025828 secs cpu shape:4/3 #layers:10 mode:py scale:True train:True - 2.657428 secs gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.086614 secs cpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.603137 secs gpu shape:4/3 #layers:10 mode:py scale:True train:False - 0.017668 secs cpu shape:4/3 #layers:10 mode:py scale:True train:True - 1.519533 secs gpu shape:4/3 #layers:10 mode:py scale:True train:True - 0.055214 secs cpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.071344 secs gpu shape:2/1 #layers:10 mode:py scale:True train:False - 0.016440 secs cpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.222093 secs gpu shape:2/1 #layers:10 mode:py scale:True train:True - 0.039967 secs Change: 115507032
-rw-r--r--tensorflow/python/ops/nn.py136
-rw-r--r--tensorflow/python/ops/nn_test.py130
2 files changed, 229 insertions, 37 deletions
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index ecff8241a6..bc5ba95348 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -134,6 +134,8 @@ have varying scale, and to aid generalization.
@@l2_normalize
@@local_response_normalization
+@@sufficient_statistics
+@@aggregate_moments
@@moments
## Losses
@@ -495,6 +497,101 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
padding="VALID", name=name)
+def sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None):
+ """Calculate the sufficient statistics for the mean and variance of `x`.
+
+ These sufficient statistics are computed using the one pass algorithm on
+ an input that's optionally shifted using the value of the 1st element in `x`.
+ See:
+ https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data
+
+ Args:
+ x: A `Tensor`.
+ axes: Array of ints. Axes along which to compute mean and variance.
+ shift: If true, shift the data to provide more numerically stable results.
+ keep_dims: produce statistics with the same dimensionality as the input.
+ name: Name used to scope the operations that compute the sufficient stats.
+
+ Returns:
+ Four `Tensor` objects of the same type as `x`:
+ * the count (number of elements to average over).
+ * the (possibly shifted) sum of the elements in the array.
+ * the (possibly shifted) sum of squares of the elements in the array.
+ * the shift by which the mean must be corrected or None if `shift` is False.
+ """
+ with ops.op_scope([x, axes], name, "sufficient_statistics"):
+ x = ops.convert_to_tensor(x, name="x")
+ x_shape = x.get_shape()
+ if x_shape.is_fully_defined():
+ counts = 1
+ m_shape = []
+ for d in xrange(x_shape.ndims):
+ dim = x_shape[d].value
+ if d in set(axes):
+ counts *= dim
+ dim = 1
+ m_shape.append(dim)
+ counts = constant_op.constant(counts, dtype=x.dtype)
+ else: # shape needs to be inferred at runtime.
+ x_shape = array_ops.shape(x)
+ select_axes = sparse_ops.sparse_to_dense(axes, array_ops.shape(x_shape),
+ True, False)
+ m_shape = math_ops.select(select_axes, array_ops.ones_like(x_shape),
+ x_shape)
+ counts = math_ops.cast(
+ math_ops.reduce_prod(x_shape / m_shape),
+ x.dtype,
+ name="count")
+ if shift:
+ shift_value = array_ops.slice(x, array_ops.zeros_like(m_shape), m_shape)
+ m_ss = math_ops.sub(x, shift_value)
+ v_ss = math_ops.squared_difference(x, shift_value)
+ if keep_dims:
+ shift_value = array_ops.identity(shift_value, name="shift")
+ else:
+ shift_value = array_ops.squeeze(shift_value,
+ squeeze_dims=axes,
+ name="shift")
+ else: # not shift.
+ m_ss = x
+ v_ss = math_ops.square(x)
+ shift_value = None
+ m_ss = math_ops.reduce_sum(m_ss, axes, keep_dims=keep_dims, name="mean_ss")
+ v_ss = math_ops.reduce_sum(v_ss, axes, keep_dims=keep_dims, name="var_ss")
+ return counts, m_ss, v_ss, shift_value
+
+
+def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
+ """Calculate the mean and variance of based on the sufficient statistics.
+
+ Args:
+ counts: A `Tensor` containing a the total count of the data (one value).
+ mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
+ shifted) sum of the elements to average over.
+ variance_ss: A `Tensor` containing the variance sufficient statistics: the
+ (possibly shifted) squared sum of the data to compute the variance over.
+ shift: A `Tensor` containing the value by which the data is shifted for
+ numerical stability, or `None` if no shift was performed.
+ name: Name used to scope the operations that compute the moments.
+
+ Returns:
+ Two `Tensor` objects: `mean` and `variance`.
+ """
+ with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "aggregate"):
+ divisor = math_ops.inv(counts, name="divisor")
+ if shift is not None:
+ shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
+ mean = math_ops.add(shifted_mean, shift, name="mean")
+ else: # no shift.
+ shifted_mean = math_ops.mul(mean_ss, divisor, name="mean")
+ mean = shifted_mean
+ variance = math_ops.sub(
+ math_ops.mul(variance_ss, divisor),
+ math_ops.square(shifted_mean),
+ name="variance")
+ return (mean, variance)
+
+
def moments(x, axes, name=None, keep_dims=False):
"""Calculate the mean and variance of `x`.
@@ -519,40 +616,11 @@ def moments(x, axes, name=None, keep_dims=False):
Two `Tensor` objects: `mean` and `variance`.
"""
with ops.op_scope([x, axes], name, "moments"):
- x = ops.convert_to_tensor(x, name="x")
- x_shape = x.get_shape()
- if all(x_shape[d].value is not None for d in axes):
- # The shape is known in the relevant axes, so we can statically
- # compute the divisor.
- divisor = 1.0
- for d in set(axes):
- divisor *= x.get_shape()[d].value
- divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
- else:
- divisor = constant_op.constant(1.0, dtype=x.dtype)
- x_dynamic_shape = array_ops.shape(x)
- for d in set(axes):
- divisor *= math_ops.cast(x_dynamic_shape[d], x.dtype)
- divisor = math_ops.inv(divisor, name="divisor")
- constant_axes = constant_op.constant(axes, name="axes")
- # Note: We do not use Mean here because it is very slow on GPU.
- mean = math_ops.mul(
- math_ops.reduce_sum(x,
- constant_axes,
- keep_dims=True),
- divisor,
- name="mean")
- var = math_ops.mul(
- math_ops.reduce_sum(
- math_ops.squared_difference(x, mean),
- constant_axes,
- keep_dims=keep_dims),
- divisor,
- name="variance")
- if keep_dims:
- return mean, var
- else:
- return array_ops.squeeze(mean, squeeze_dims=axes), var
+ counts, m_ss, v_ss, shift = sufficient_statistics(x,
+ axes,
+ keep_dims=keep_dims,
+ name=name)
+ return aggregate_moments(counts, m_ss, v_ss, shift, name=name)
def batch_normalization(x,
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 30c866e6a4..317a074830 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -476,7 +476,7 @@ class DropoutTest(tf.test.TestCase):
_ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1])
-class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
+class BatchNormalizationTest(tf.test.TestCase):
def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization, shift_after_normalization):
@@ -670,8 +670,7 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
else:
all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
to_check = ["dx", "dm", "dv", "db"]
- for i, n in enumerate(to_check):
- print(n)
+ for i, _ in enumerate(to_check):
self.assertAllClose(
all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
@@ -759,6 +758,117 @@ class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):
atol=0.005)
+class SufficientStatisticsTest(tf.test.TestCase):
+
+ def _npSuffStats(self, x, axes, shift, keep_dims):
+ axis = tuple(axes)
+ if shift:
+ shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
+ for i in xrange(x.ndim)]]
+ m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
+ v_ss = np.sum(
+ (x - shift_value) * (x - shift_value),
+ axis=axis,
+ keepdims=keep_dims)
+ else:
+ shift_value = None
+ m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
+ v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
+ count = 1.0
+ for d in xrange(x.ndim):
+ if d in set(axes):
+ count *= x.shape[d]
+ if not keep_dims:
+ shift_value = np.squeeze(shift_value, axis=axis)
+ return count, m_ss, v_ss, shift_value
+
+ def _opSuffStats(self, x, axes, shift, keep_dims):
+ return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)
+
+ def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
+ x_val = np.random.random_sample(x_shape).astype(np.float32)
+ np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ if has_shape:
+ x = tf.constant(x_val, name="x")
+ x.set_shape(x_shape)
+ op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
+ if shift:
+ tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
+ else:
+ tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
+ else:
+ x = tf.placeholder(dtype=tf.float32,
+ shape=[None] * len(x_shape),
+ name="x")
+ op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
+ if shift:
+ tf_c, tf_m, tf_v, tf_s = sess.run(
+ [op_c, op_m, op_v, op_s],
+ feed_dict={x: x_val})
+ else:
+ tf_c, tf_m, tf_v = sess.run(
+ [op_c, op_m, op_v],
+ feed_dict={x: x_val})
+ self.assertAllClose(np_c, tf_c, atol=0.000001)
+ self.assertAllClose(np_m, tf_m, atol=0.000001)
+ self.assertAllClose(np_v, tf_v, atol=0.000001)
+ if shift:
+ self.assertAllClose(np_s, tf_s, atol=0.000001)
+
+ def testSuffStats(self):
+ for has_shape in [True, False]:
+ for keep_dims in [True, False]:
+ for shift in [True, False]:
+ self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
+ self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
+ self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
+
+
+class AggregateMomentsTest(tf.test.TestCase):
+
+ def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift):
+ mean = mean_ss / counts
+ variance = variance_ss / counts - mean * mean
+ if shift is not None:
+ mean += shift
+ return mean, variance
+
+ def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift):
+ return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift)
+
+ def _testAggregateMoments(self, shape, shift):
+ counts = np.ones([1]).astype(np.float32)
+ mean_ss = np.random.random_sample(shape).astype(np.float32)
+ variance_ss = np.random.random_sample(shape).astype(np.float32)
+ variance_ss *= variance_ss
+ if shift:
+ shift_v = np.random.random_sample(shape).astype(np.float32)
+ else:
+ shift_v = None
+ npm, npv = self._npAggregateMoments(counts, mean_ss, variance_ss, shift_v)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ tf_counts = tf.constant(counts, name="counts")
+ tf_mean_ss = tf.constant(mean_ss, name="mean_ss")
+ tf_variance_ss = tf.constant(variance_ss, name="variance_ss")
+ if shift:
+ tf_shift_v = tf.constant(shift_v, name="shift")
+ else:
+ tf_shift_v = None
+ opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss,
+ tf_variance_ss, tf_shift_v)
+ tfm, tfv = sess.run([opm, opv])
+ self.assertAllClose(npm, tfm, atol=0.000001)
+ self.assertAllClose(npv, tfv, atol=0.000001)
+
+ def testAggregateMoments(self):
+ for shift in [True, False]:
+ self._testAggregateMoments([3], shift)
+ self._testAggregateMoments([2, 3], shift)
+
+
class MomentsTest(tf.test.TestCase):
def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims):
@@ -857,6 +967,20 @@ class MomentsTest(tf.test.TestCase):
def testVarGlobalGradient(self):
self._testGlobalGradient(from_y="var")
+ def testOutputNamesNoKeep(self):
+ """Make sure the output names are stable."""
+ with self.test_session():
+ mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
+ self.assertEquals(mean.op.name, "moments/aggregate/mean")
+ self.assertEquals(var.op.name, "moments/aggregate/variance")
+
+ def testOutputNamesKeep(self):
+ """Make sure the output names are stable."""
+ with self.test_session():
+ mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
+ self.assertEquals(mean.op.name, "moments/aggregate/mean")
+ self.assertEquals(var.op.name, "moments/aggregate/variance")
+
class ComputeSampledLogitsTest(tf.test.TestCase):