diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.h')
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index fcc9347db5..f0f7ff1ea0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -399,12 +399,16 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); + // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true. + Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + Shape MaybeConvertShapeToBfloat16(const Shape& shape); + // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; // Arguments to be passed to the computation when it runs. - std::vector<std::unique_ptr<GlobalData>> arguments_; + std::vector<Literal> arguments_; }; template <typename NativeT> |