diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-26 08:21:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 08:26:38 -0700 |
commit | 5498f24a3385bdd256b8b1e41329c5841996b26d (patch) | |
tree | 859fc7c8d48b75539d4f35194554cc1bcefe8e4e /tensorflow/compiler/tests | |
parent | e45f7ee4182d5e831026f329cff5da2596d6733a (diff) |
Changed FusedBatchNorm and FusedBatchNormGrad to use allowed_values for data_format attr.
PiperOrigin-RevId: 214608039
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/fused_batchnorm_test.py | 40 |
1 files changed, 10 insertions, 30 deletions
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 8c018cccb8..374942a0b3 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn from tensorflow.python.platform import test +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), +) + class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): @@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] @@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearning(self, data_format): self._testLearning(False, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearningWithGradientChecker(self, data_format): self._testLearning(True, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. @@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. |