aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-27 21:48:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-27 21:52:22 -0800
commit8781d69b2e619e64555cb00b13783a7eee524b81 (patch)
tree230a7461ae764be472911e7d21682c42f36c9172
parent119e3a18ce480b7f808638a2821de1d935f2df8f (diff)
Allow BF16 to use error spec.
PiperOrigin-RevId: 177114689
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index e8599a5cd3..1d27880fb1 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -387,6 +387,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
@@ -411,6 +412,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
@@ -435,6 +437,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
@@ -459,6 +462,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =