aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/client_library_test_base.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.h')
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index fcc9347db5..f0f7ff1ea0 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -399,12 +399,16 @@ class ClientLibraryTestBase : public ::testing::Test {
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
+ // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true.
+ Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
+ Shape MaybeConvertShapeToBfloat16(const Shape& shape);
+
// Whether to run tests with all float-type input/output converted to
// bfloat16.
bool use_bfloat16_ = false;
// Arguments to be passed to the computation when it runs.
- std::vector<std::unique_ptr<GlobalData>> arguments_;
+ std::vector<Literal> arguments_;
};
template <typename NativeT>