path: root/tensorflow/compiler/tests/fused_batchnorm_test.py
author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-08-14 09:52:07 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-08-14 09:55:32 -0700
commit    030ddd98c4d0a5bb654e52fd0c54ba2a0bd51060 (patch)
tree      4a53b39def8d775f3d21461893efd23b92096c32 /tensorflow/compiler/tests/fused_batchnorm_test.py
parent    aea11343414f4feaf2b2be5845b34b20e419c86f (diff)
Several updates and fixes for fused batchnorm in xla.
1. There is a bug in deriving the formulas for the batch norm gradient calculation. This commit corrects the formula and updates the implementation.
2. Learning from this mistake, this commit uses gradient_checker to test batch_norm_grad in a more generic way -- this method calculates the "real" gradients by evaluating the graph twice, and then compares them with the computed gradients.
RELNOTES: Update formula for fused batch norm in xla
PiperOrigin-RevId: 165190486
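To make the corrected math easy to inspect outside the patch, here is a standalone NumPy sketch of the reference gradients this commit introduces. The function name and the NHWC reduction axes are ours, not the commit's; it mirrors the formulas in the diff below:

    import numpy as np

    def reference_batchnorm_grad(x, grad_y, scale, mean, var, epsilon):
      # Reduce over batch and spatial axes; the channel axis (NHWC) is kept.
      axis = (0, 1, 2)
      inv_std = 1.0 / np.sqrt(var + epsilon)
      # grad_scale: upstream gradient correlated with the normalized input.
      grad_scale = np.sum(grad_y * (x - mean) * inv_std, axis=axis)
      # grad_offset: plain sum of the upstream gradient.
      grad_offset = np.sum(grad_y, axis=axis)
      # grad_x: includes the indirect terms through mean and var, which the
      # buggy formula omitted (it also multiplied by sqrt instead of rsqrt).
      grad_x = scale * inv_std * (
          grad_y - np.mean(grad_y, axis=axis) -
          (x - mean) * np.mean(grad_y * (x - mean), axis=axis) / (var + epsilon))
      return grad_x, grad_scale, grad_offset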
Diffstat (limited to 'tensorflow/compiler/tests/fused_batchnorm_test.py')
-rw-r--r--  tensorflow/compiler/tests/fused_batchnorm_test.py  41
1 file changed, 38 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 437c71b0da..f8e9fc9268 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.compiler.tests.xla_test import XLATestCase
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn
 from tensorflow.python.platform import test
 
@@ -42,15 +43,27 @@ class FusedBatchNormTest(XLATestCase):
     return (normalized * scale + offset), mean, var
 
   def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
+    # Use the following formulas to calculate gradients:
+    # grad_scale =
+    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
+    #
+    # grad_offset = sum(grad_y)
+    #
+    # grad_x =
+    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
+    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
     if data_format != "NHWC":
       raise ValueError("data_format must be NHWC, got %s." % data_format)
-    grad_x = grad_y * scale * np.sqrt(var + epsilon)
+    grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
+                      (x - mean) * np.mean(grad_y *
+                                           (x - mean), axis=(0, 1, 2)) /
+                      (var + epsilon)) / np.sqrt(var + epsilon)
     grad_scale = np.sum(
-        grad_y * (x - mean) * np.sqrt(var + epsilon), axis=(0, 1, 2))
+        grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2))
     grad_offset = np.sum(grad_y, axis=(0, 1, 2))
     return grad_x, grad_scale, grad_offset
 
-  def testBasic(self):
+  def _testLearning(self, use_gradient_checker):
     x_shape = [2, 2, 6, 2]
     scale_shape = [2]
     x_val = np.random.random_sample(x_shape).astype(np.float32)
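The corrected grad_x above can be sanity-checked against a central finite difference of the forward pass, which is essentially what gradient_checker automates. A minimal sketch, assuming the hypothetical reference_batchnorm_grad helper from the earlier sketch and the scalar loss sum(y * grad_y):

    import numpy as np

    def forward(x, scale, offset, epsilon):
      mean = np.mean(x, axis=(0, 1, 2))
      var = np.var(x, axis=(0, 1, 2))  # biased variance, as in training mode
      return (x - mean) / np.sqrt(var + epsilon) * scale + offset

    np.random.seed(0)
    x = np.random.random_sample([2, 2, 6, 2])
    grad_y = np.random.random_sample(x.shape)
    scale, offset, epsilon = np.ones(2), np.zeros(2), 0.001

    mean = np.mean(x, axis=(0, 1, 2))
    var = np.var(x, axis=(0, 1, 2))
    grad_x, _, _ = reference_batchnorm_grad(x, grad_y, scale, mean, var, epsilon)

    # Central difference of the loss w.r.t. one input element.
    delta = 1e-5
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[0, 0, 0, 0] += delta
    x_minus[0, 0, 0, 0] -= delta
    numeric = np.sum((forward(x_plus, scale, offset, epsilon) -
                      forward(x_minus, scale, offset, epsilon)) * grad_y) / (2 * delta)
    assert abs(numeric - grad_x[0, 0, 0, 0]) < 1e-4

The old formula fails this check because it treats mean and var as constants, dropping their dependence on x.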
@@ -75,6 +88,20 @@
           epsilon=epsilon,
           data_format=data_format,
           is_training=True)
+      # Check gradient.
+      if use_gradient_checker:
+        err = gradient_checker.compute_gradient_error(
+            t_val,
+            x_shape,
+            y,
+            x_shape,
+            extra_feed_dict={
+                t_val: x_val,
+                scale: scale_val,
+                offset: offset_val
+            })
+        self.assertLess(err, 1e-3)
+
       y_val, mean_val, var_val = sess.run(
           [y, mean, var], {t_val: x_val,
                            scale: scale_val,
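For context on the check added above: gradient_checker.compute_gradient_error is the TF 1.x utility that builds one Jacobian by backprop and one by finite differences of the evaluated graph, then returns the largest elementwise discrepancy ("evaluating the graph twice", per the commit message). A minimal standalone usage sketch on a toy op; tanh is our illustration, not part of the patch, and it assumes a TF 1.x graph-mode session:

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import gradient_checker

    with tf.Session():
      x = tf.placeholder(np.float32, shape=[3], name="x")
      y = tf.tanh(x)
      # Compares the symbolic (backprop) Jacobian of y w.r.t. x against a
      # finite-difference Jacobian; returns the max elementwise error.
      err = gradient_checker.compute_gradient_error(x, [3], y, [3])
      assert err < 1e-3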
@@ -85,6 +112,12 @@
       self.assertAllClose(y_val, y_ref, atol=1e-3)
       self.assertAllClose(var_val, var_ref, atol=1e-3)
 
+  def testLearning(self):
+    self._testLearning(False)
+
+  def testLearningWithGradientChecker(self):
+    self._testLearning(True)
+
   def testGradient(self):
     # TODO(b/64270657): Use gradient_checker here in addition to comparing with
     # this reference implementation.
@@ -105,6 +138,7 @@
       scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
       grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
           grad, x, scale, mean, var, data_format="NHWC")
+
       grad_x_val, grad_scale_val, grad_offset_val = sess.run(
           [grad_x, grad_scale, grad_offset], {
               grad: grad_val,
@@ -121,5 +155,6 @@
       self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
       self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
 
+
 if __name__ == "__main__":
   test.main()