author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-26 08:21:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-26 08:26:38 -0700
commit    5498f24a3385bdd256b8b1e41329c5841996b26d (patch)
tree      859fc7c8d48b75539d4f35194554cc1bcefe8e4e /tensorflow/compiler/tests
parent    e45f7ee4182d5e831026f329cff5da2596d6733a (diff)
Changed FusedBatchNorm and FusedBatchNormGrad to use allowed_values for data_format attr.
PiperOrigin-RevId: 214608039
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--  tensorflow/compiler/tests/fused_batchnorm_test.py | 40
1 file changed, 10 insertions, 30 deletions
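
Note: the allowed_values change itself lives in the FusedBatchNorm/FusedBatchNormGrad op registrations, not in this test file; the patch below only trims the test parameterization to the layouts that remain allowed. A minimal sketch of the intended effect, assuming the restriction behaves like other allowed-values attrs in graph mode (the error type and exact failure point are assumptions, not taken from this commit):

# Illustrative sketch only: once data_format carries allowed_values, building a
# FusedBatchNorm op with an unsupported layout should fail at graph-construction
# time instead of reaching the kernel.
import tensorflow as tf  # assumes a TF 1.x-style graph API, as in the test below

x = tf.placeholder(tf.float32, shape=[2, 2, 6, 3])
scale = tf.constant([1.0, 1.0, 1.0])
offset = tf.constant([0.0, 0.0, 0.0])

try:
  # "HWNC" is no longer an allowed value for the data_format attr.
  tf.nn.fused_batch_norm(x, scale, offset, data_format="HWNC")
except ValueError as e:  # assumed error type for an attr validation failure
  print("rejected:", e)

# "NHWC" and "NCHW" remain valid layouts.
y, mean, var = tf.nn.fused_batch_norm(x, scale, offset, data_format="NHWC")
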
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 8c018cccb8..374942a0b3 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
+DATA_FORMATS = (
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+)
+
class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
@@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testInference(self, data_format):
channel = 3
x_shape = [2, 2, 6, channel]
@@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
self.assertAllClose(var_val, var_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearning(self, data_format):
self._testLearning(False, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearningWithGradientChecker(self, data_format):
self._testLearning(True, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientTraining(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
@@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientInference(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
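
For reference, a short sketch of the decorator pattern introduced above: one module-level tuple of (name-suffix, value) pairs is unpacked into each @parameterized.named_parameters call instead of repeating the same list before every test method. The class and method names here are placeholders, not part of the actual test file:

# Sketch of how a shared DATA_FORMATS tuple drives parameterized test methods.
from absl.testing import parameterized

DATA_FORMATS = (
    ("_data_format_NHWC", "NHWC"),
    ("_data_format_NCHW", "NCHW"),
)


class ExampleTest(parameterized.TestCase):

  @parameterized.named_parameters(*DATA_FORMATS)
  def testFormats(self, data_format):
    # Expands into testFormats_data_format_NHWC and testFormats_data_format_NCHW.
    self.assertIn(data_format, ("NHWC", "NCHW"))
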