aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_macros.h
diff options
context:
space:
mode:
authorGravatar Bixia Zheng <bixia@google.com>2018-02-15 10:39:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-15 10:42:55 -0800
commitb91155edb661e074b716d7051c2cb71cbf9ec759 (patch)
tree48cc74ee878c9313e9cb3f8a10e8227dcdd96ede /tensorflow/compiler/xla/tests/test_macros.h
parentc356d2800182ef7430a70baa2b1b75ea854f9adf (diff)
Enable half precision convolution for the CPU and GPU backends.
Enhance the CPU IR emitter to support F16 dot operation and convolution operation. Add a CPU runtime implementation for F16 convolution. Enhance the GPU backend to handle F16 convolution thunk. Convert some F32 xla convolution tests to support both F32 and F16 and disable the tests for the CPU backend due to b/72509305. PiperOrigin-RevId: 185862438
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_macros.h')
-rw-r--r--tensorflow/compiler/xla/tests/test_macros.h27
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h
index cc4eaf62f5..e2d406f66d 100644
--- a/tensorflow/compiler/xla/tests/test_macros.h
+++ b/tensorflow/compiler/xla/tests/test_macros.h
@@ -161,4 +161,31 @@ string PrependDisabledIfIndicated(const string& test_case_name,
#define XLA_TEST_P(test_case_name, test_name) \
XLA_TEST_P_IMPL_(test_case_name, test_name)
+
+// This is identical to the TEST_F macro from "gtest", but it potentially
+// disables the test based on an external manifest file, DISABLED_MANIFEST.
+#define XLA_TYPED_TEST(CaseName, TestName) \
+ template <typename gtest_TypeParam_> \
+ class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
+ : public CaseName<gtest_TypeParam_> { \
+ private: \
+ typedef CaseName<gtest_TypeParam_> TestFixture; \
+ typedef gtest_TypeParam_ TypeParam; \
+ virtual void TestBody(); \
+ }; \
+ bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
+ ::testing::internal::TypeParameterizedTest< \
+ CaseName, \
+ ::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
+ TestName)>, \
+ GTEST_TYPE_PARAMS_(CaseName)>:: \
+ Register( \
+ "", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
+ #CaseName, \
+ ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
+ 0); \
+ template <typename gtest_TypeParam_> \
+ void GTEST_TEST_CLASS_NAME_(CaseName, \
+ TestName)<gtest_TypeParam_>::TestBody()
+
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_