author:    2017-08-14 09:52:07 -0700
committer: 2017-08-14 09:55:32 -0700
commit:    030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060
tree:      4a53b39def8d775f3d21461893efd23b92096c32 /tensorflow/compiler/tests/fused_batchnorm_test.py
parent:    aea11343414f4feaf2b2be5845b34b20e419c86f
Several updates and fixes for fused batch norm in XLA.
1. There was a bug in deriving the formulas for the batch norm gradient calculation. This commit corrects the formulas and
updates the implementation.
2. Learning from that mistake, this commit uses gradient_checker to test batch_norm_grad in a more generic way -- the checker computes the "real" (numerical) gradients by evaluating the graph at perturbed inputs, then compares them with the analytical gradients.
RELNOTES: Update formula for fused batch norm in xla
PiperOrigin-RevId: 165190486
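To make the fix concrete, here is a sketch of the standard training-mode batch norm derivation behind the corrected formulas (not quoted from the commit; notation is mine: N is the number of elements reduced per channel, g_i = dL/dy_i, and \hat{x}_i = (x_i - \mu)/\sqrt{\sigma^2 + \epsilon}):

\frac{\partial L}{\partial x_i}
  = \frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}}
    \Bigl(N g_i - \sum_j g_j - \hat{x}_i \sum_j g_j \hat{x}_j\Bigr),
\qquad
\frac{\partial L}{\partial \gamma} = \sum_i g_i \hat{x}_i,
\qquad
\frac{\partial L}{\partial \beta} = \sum_i g_i

Dividing through by N and substituting \hat{x}_i reproduces the grad_x, grad_scale, and grad_offset expressions in _reference_grad in the diff below. The pre-fix reference multiplied by \sqrt{\sigma^2 + \epsilon} where these formulas divide by it, and dropped the two correction terms that arise because \mu and \sigma^2 are themselves functions of x in training mode.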
Diffstat (limited to 'tensorflow/compiler/tests/fused_batchnorm_test.py')
-rw-r--r-- | tensorflow/compiler/tests/fused_batchnorm_test.py | 41
1 file changed, 38 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 437c71b0da..f8e9fc9268 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.compiler.tests.xla_test import XLATestCase
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn
 from tensorflow.python.platform import test
 
@@ -42,15 +43,27 @@ class FusedBatchNormTest(XLATestCase):
     return (normalized * scale + offset), mean, var
 
   def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
+    # Use the following formulas to calculate gradients:
+    # grad_scale =
+    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
+    #
+    # grad_offset = sum(grad_y)
+    #
+    # grad_x =
+    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
+    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
     if data_format != "NHWC":
       raise ValueError("data_format must be NHWC, got %s." % data_format)
-    grad_x = grad_y * scale * np.sqrt(var + epsilon)
+    grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
+                      (x - mean) * np.mean(grad_y *
+                                           (x - mean), axis=(0, 1, 2)) /
+                      (var + epsilon)) / np.sqrt(var + epsilon)
     grad_scale = np.sum(
-        grad_y * (x - mean) * np.sqrt(var + epsilon), axis=(0, 1, 2))
+        grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2))
     grad_offset = np.sum(grad_y, axis=(0, 1, 2))
     return grad_x, grad_scale, grad_offset
 
-  def testBasic(self):
+  def _testLearning(self, use_gradient_checker):
     x_shape = [2, 2, 6, 2]
     scale_shape = [2]
     x_val = np.random.random_sample(x_shape).astype(np.float32)
@@ -75,6 +88,20 @@ class FusedBatchNormTest(XLATestCase):
           epsilon=epsilon,
           data_format=data_format,
           is_training=True)
+      # Check gradient.
+      if use_gradient_checker:
+        err = gradient_checker.compute_gradient_error(
+            t_val,
+            x_shape,
+            y,
+            x_shape,
+            extra_feed_dict={
+                t_val: x_val,
+                scale: scale_val,
+                offset: offset_val
+            })
+        self.assertLess(err, 1e-3)
+
       y_val, mean_val, var_val = sess.run(
           [y, mean, var], {t_val: x_val,
                            scale: scale_val,
@@ -85,6 +112,12 @@ class FusedBatchNormTest(XLATestCase):
       self.assertAllClose(y_val, y_ref, atol=1e-3)
       self.assertAllClose(var_val, var_ref, atol=1e-3)
 
+  def testLearning(self):
+    self._testLearning(False)
+
+  def testLearningWithGradientChecker(self):
+    self._testLearning(True)
+
   def testGradient(self):
     # TODO(b/64270657): Use gradient_checker here in addition to comparing with
     # this reference implementation.
@@ -105,6 +138,7 @@ class FusedBatchNormTest(XLATestCase):
     scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
     grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
         grad, x, scale, mean, var, data_format="NHWC")
+
     grad_x_val, grad_scale_val, grad_offset_val = sess.run(
         [grad_x, grad_scale, grad_offset], {
             grad: grad_val,
@@ -121,5 +155,6 @@ class FusedBatchNormTest(XLATestCase):
     self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
     self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
 
+
 if __name__ == "__main__":
   test.main()
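As a quick sanity check on the corrected grad_x expression, here is a minimal NumPy sketch (not part of the commit; the shapes, epsilon, seed, the scalar loss L = sum(y * grad_y), and the helper names batch_norm and reference_grad_x are illustrative assumptions) that compares the analytical gradient against central finite differences:

import numpy as np

np.random.seed(0)
epsilon = 0.001
x = np.random.random_sample((2, 2, 6, 2))       # NHWC, float64
scale = np.random.random_sample((2,))
offset = np.random.random_sample((2,))
grad_y = np.random.random_sample(x.shape)

def batch_norm(x):
  # Training-mode batch norm: statistics come from the batch itself.
  mean = x.mean(axis=(0, 1, 2))
  var = x.var(axis=(0, 1, 2))  # biased variance, as in training mode
  return (x - mean) / np.sqrt(var + epsilon) * scale + offset

def reference_grad_x(x, grad_y):
  # Corrected formula from this commit: divide by sqrt(var + epsilon)
  # and keep the terms from the dependence of mean and var on x.
  mean = x.mean(axis=(0, 1, 2))
  var = x.var(axis=(0, 1, 2))
  return scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
                  (x - mean) * np.mean(grad_y * (x - mean), axis=(0, 1, 2)) /
                  (var + epsilon)) / np.sqrt(var + epsilon)

# Central finite differences of L(x) = sum(batch_norm(x) * grad_y).
numeric = np.zeros_like(x)
delta = 1e-5
it = np.nditer(x, flags=["multi_index"])
for _ in it:
  i = it.multi_index
  xp, xm = x.copy(), x.copy()
  xp[i] += delta
  xm[i] -= delta
  numeric[i] = np.sum((batch_norm(xp) - batch_norm(xm)) * grad_y) / (2 * delta)

# Should agree to roughly 1e-9; substituting the pre-fix formula
# grad_y * scale * np.sqrt(var + epsilon) gives a large error.
print(np.max(np.abs(numeric - reference_grad_x(x, grad_y))))

This finite-difference comparison is exactly what gradient_checker.compute_gradient_error automates in the new testLearningWithGradientChecker test above.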