diff options
Diffstat (limited to 'tensorflow/compiler/tests/fused_batchnorm_test.py')
-rw-r--r-- | tensorflow/compiler/tests/fused_batchnorm_test.py | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 936fcf8b6b..a773b5a947 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -36,7 +36,7 @@ class FusedBatchNormTest(XLATestCase): x_square = x * x x_square_sum = np.sum(x_square, (0, 1, 2)) x_sum = np.sum(x, axis=(0, 1, 2)) - element_count = np.size(x) / int(np.shape(x)[0]) + element_count = np.size(x) / int(np.shape(x)[-1]) mean = x_sum / element_count var = x_square_sum / element_count - mean * mean normalized = (x - mean) / np.sqrt(var + epsilon) @@ -64,8 +64,9 @@ class FusedBatchNormTest(XLATestCase): return grad_x, grad_scale, grad_offset def testInference(self): - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -74,8 +75,8 @@ class FusedBatchNormTest(XLATestCase): with self.test_session() as sess, self.test_scope(): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") - scale = array_ops.placeholder(np.float32, shape=[2], name="scale") - offset = array_ops.placeholder(np.float32, shape=[2], name="offset") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format) @@ -97,8 +98,9 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(y_val, y_ref, atol=1e-3) def _testLearning(self, use_gradient_checker): - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -109,8 +111,8 @@ class FusedBatchNormTest(XLATestCase): with self.test_session() as sess, self.test_scope(): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") - scale = array_ops.placeholder(np.float32, shape=[2], name="scale") - offset = array_ops.placeholder(np.float32, shape=[2], name="offset") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, @@ -154,8 +156,9 @@ class FusedBatchNormTest(XLATestCase): def testGradient(self): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] grad_val = np.random.random_sample(x_shape).astype(np.float32) x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) |