aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_fused_batchnorm_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_fused_batchnorm_test.py')
-rw-r--r--tensorflow/python/ops/nn_fused_batchnorm_test.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index 48d1d5b25a..1c1554e9f3 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -198,8 +198,7 @@ class BatchNormalizationTest(test.TestCase):
epsilon = y.op.get_attr('epsilon')
data_format = y.op.get_attr('data_format')
grad_vals = sess.run([grad_x, grad_scale, grad_offset])
- grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, epsilon,
- data_format)
+ grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format)
grad_internal_vals = sess.run(list(grad_internal))
for grad_val, grad_internal_val in zip(grad_vals, grad_internal_vals):
self.assertAllClose(grad_val, grad_internal_val, atol=err_tolerance)