diff options
author | 2017-11-27 21:48:38 -0800 | |
---|---|---|
committer | 2017-11-27 21:52:22 -0800 | |
commit | 8781d69b2e619e64555cb00b13783a7eee524b81 (patch) | |
tree | 230a7461ae764be472911e7d21682c42f36c9172 | |
parent | 119e3a18ce480b7f808638a2821de1d935f2df8f (diff) |
Allow BF16 to use error spec.
PiperOrigin-RevId: 177114689
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index e8599a5cd3..1d27880fb1 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -387,6 +387,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { static_assert(std::is_same<NativeT, float>::value || std::is_same<NativeT, double>::value || + std::is_same<NativeT, bfloat16>::value || std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = @@ -411,6 +412,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { static_assert(std::is_same<NativeT, float>::value || std::is_same<NativeT, double>::value || + std::is_same<NativeT, bfloat16>::value || std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = @@ -435,6 +437,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { static_assert(std::is_same<NativeT, float>::value || std::is_same<NativeT, double>::value || + std::is_same<NativeT, bfloat16>::value || std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = @@ -459,6 +462,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { static_assert(std::is_same<NativeT, float>::value || std::is_same<NativeT, double>::value || + std::is_same<NativeT, bfloat16>::value || std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = |