Diffstat (limited to 'tensorflow/compiler/tests/fused_batchnorm_test.py')
-rw-r--r--  tensorflow/compiler/tests/fused_batchnorm_test.py | 25
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 936fcf8b6b..a773b5a947 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -36,7 +36,7 @@ class FusedBatchNormTest(XLATestCase):
x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2))
x_sum = np.sum(x, axis=(0, 1, 2))
- element_count = np.size(x) / int(np.shape(x)[0])
+ element_count = np.size(x) / int(np.shape(x)[-1])
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon)
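For context on the fix above: the test data is NHWC, so the per-channel moments reduce over axes (0, 1, 2) and each channel averages N * H * W values, i.e. np.size(x) divided by the last dimension, not by the batch dimension. A minimal NumPy sketch (not part of the diff; the shape is illustrative):

import numpy as np

x = np.random.random_sample([2, 2, 6, 3]).astype(np.float32)  # NHWC
element_count = np.size(x) / int(np.shape(x)[-1])             # 2 * 2 * 6 = 24
mean = np.sum(x, axis=(0, 1, 2)) / element_count              # shape (3,), one value per channel
var = np.sum(x * x, axis=(0, 1, 2)) / element_count - mean * mean
assert np.allclose(mean, x.mean(axis=(0, 1, 2)), atol=1e-4)
assert np.allclose(var, x.var(axis=(0, 1, 2)), atol=1e-4)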
@@ -64,8 +64,9 @@ class FusedBatchNormTest(XLATestCase):
return grad_x, grad_scale, grad_offset
def testInference(self):
- x_shape = [2, 2, 6, 2]
- scale_shape = [2]
+ channel = 3
+ x_shape = [2, 2, 6, channel]
+ scale_shape = [channel]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
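A plausible reason for bumping the channel count from 2 to 3 (and presumably why the shape[0] bug above went unnoticed): when the batch and channel dimensions are both 2, statistics reduced over the wrong axis still have a broadcast-compatible shape and the error passes silently; with distinct sizes it surfaces as a shape mismatch. Illustrative NumPy sketch, not part of the diff:

import numpy as np

x_same = np.zeros([2, 2, 6, 2], np.float32)  # N == C == 2
x_diff = np.zeros([2, 2, 6, 3], np.float32)  # N != C

bad_mean = x_same.mean(axis=(1, 2, 3))       # per-batch mean, shape (2,)
_ = x_same - bad_mean                        # wrong axis, yet broadcasts silently

bad_mean = x_diff.mean(axis=(1, 2, 3))       # shape (2,)
try:
  _ = x_diff - bad_mean                      # shape mismatch: (..., 3) vs (2,)
except ValueError as e:
  print("wrong-axis reduction caught:", e)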
@@ -74,8 +75,8 @@ class FusedBatchNormTest(XLATestCase):
with self.test_session() as sess, self.test_scope():
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
- scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
- offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
+ scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
+ offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
epsilon = 0.001
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format)
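The hunk ends before the op under test; a hedged sketch of how the corrected placeholders might feed the inference-mode call (sess, t_val, scale, offset, epsilon and data_format come from the context lines above; passing mean_ref/var_ref directly is an assumption, not the test's verbatim code):

y, _, _ = nn.fused_batch_norm(
    t_val,
    scale,
    offset,
    mean=mean_ref,
    variance=var_ref,
    epsilon=epsilon,
    data_format=data_format,
    is_training=False)
y_val = sess.run(y, {t_val: x_val, scale: scale_val, offset: offset_val})
self.assertAllClose(y_val, y_ref, atol=1e-3)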
@@ -97,8 +98,9 @@ class FusedBatchNormTest(XLATestCase):
self.assertAllClose(y_val, y_ref, atol=1e-3)
def _testLearning(self, use_gradient_checker):
- x_shape = [2, 2, 6, 2]
- scale_shape = [2]
+ channel = 3
+ x_shape = [2, 2, 6, channel]
+ scale_shape = [channel]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
@@ -109,8 +111,8 @@ class FusedBatchNormTest(XLATestCase):
with self.test_session() as sess, self.test_scope():
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
- scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
- offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
+ scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
+ offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
epsilon = 0.001
y, mean, var = nn.fused_batch_norm(
t_val,
@@ -154,8 +156,9 @@ class FusedBatchNormTest(XLATestCase):
def testGradient(self):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
- x_shape = [2, 2, 6, 2]
- scale_shape = [2]
+ channel = 3
+ x_shape = [2, 2, 6, channel]
+ scale_shape = [channel]
grad_val = np.random.random_sample(x_shape).astype(np.float32)
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
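Regarding the TODO above, a hedged sketch of what a gradient_checker comparison inside testGradient might look like, assuming the TF1-era tensorflow.python.ops.gradient_checker module; the tolerance and feed handling are illustrative assumptions, not the test's actual code:

from tensorflow.python.ops import gradient_checker

with self.test_session(), self.test_scope():
  t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
  scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
  offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
  y, _, _ = nn.fused_batch_norm(
      t_val, scale, offset, data_format="NHWC", is_training=True)
  # Numeric-vs-analytic gradient check on the input; y has the same shape as x.
  err = gradient_checker.compute_gradient_error(
      t_val, x_shape, y, x_shape,
      x_init_value=x_val,
      extra_feed_dict={scale: scale_val, offset: offset_val})
  self.assertLess(err, 1e-3)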