diff options
Diffstat (limited to 'tensorflow/python/ops/nn_fused_batchnorm_test.py')
-rw-r--r-- | tensorflow/python/ops/nn_fused_batchnorm_test.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py index 48d1d5b25a..1c1554e9f3 100644 --- a/tensorflow/python/ops/nn_fused_batchnorm_test.py +++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py @@ -198,8 +198,7 @@ class BatchNormalizationTest(test.TestCase): epsilon = y.op.get_attr('epsilon') data_format = y.op.get_attr('data_format') grad_vals = sess.run([grad_x, grad_scale, grad_offset]) - grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, epsilon, - data_format) + grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format) grad_internal_vals = sess.run(list(grad_internal)) for grad_val, grad_internal_val in zip(grad_vals, grad_internal_vals): self.assertAllClose(grad_val, grad_internal_val, atol=err_tolerance) |