aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/conv2d_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/conv2d_test.py')
-rw-r--r--tensorflow/compiler/tests/conv2d_test.py11
1 files changed, 4 insertions, 7 deletions
diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py
index d12e1ff1e8..f9db103f6d 100644
--- a/tensorflow/compiler/tests/conv2d_test.py
+++ b/tensorflow/compiler/tests/conv2d_test.py
@@ -26,23 +26,20 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import test_utils
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
-
DATA_FORMATS = (
("_data_format_NHWC", "NHWC"),
("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
)
-class Conv2DTest(XLATestCase, parameterized.TestCase):
+class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
@@ -236,7 +233,7 @@ class Conv2DTest(XLATestCase, parameterized.TestCase):
expected=np.reshape([108, 128], [1, 1, 1, 2]))
-class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
+class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
@@ -534,7 +531,7 @@ class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
expected=[5, 0, 11, 0, 0, 0, 17, 0, 23])
-class Conv2DBackpropFilterTest(XLATestCase, parameterized.TestCase):
+class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,