diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/client_library_test_base.h')
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 64 |
1 files changed, 35 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 37862fa9cb..edc1ba8a57 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -27,7 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -284,7 +285,7 @@ class ClientLibraryTestBase : public ::testing::Test { template <class T> XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) { - return AddParam(*Literal::CreateFromArray(argument), builder); + return AddParam(*LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -299,13 +300,14 @@ class ClientLibraryTestBase : public ::testing::Test { template <typename NativeT> XlaOp CreateConstantFromArray(const Array<NativeT>& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + builder); } // Same as CreateConstantFromArray, but for scalars. template <typename NativeT> XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value), + return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value), builder); } @@ -373,6 +375,13 @@ class ClientLibraryTestBase : public ::testing::Test { // The float type used in this test, BF16 or F32 according to use_bfloat16. PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // Executes the computation and calculates the expected reference value using + // the reference client. Returns two literals in the order of (expected, + // actual). + StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> + ComputeValueAndReference(XlaBuilder* builder, + tensorflow::gtl::ArraySlice<Literal> arguments); + Client* client_; Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; @@ -390,19 +399,16 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the reference client. Returns two literals in the order of (expected, - // actual). - StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice<Literal> arguments); + // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true. + Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + Shape MaybeConvertShapeToBfloat16(const Shape& shape); // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; // Arguments to be passed to the computation when it runs. - std::vector<std::unique_ptr<GlobalData>> arguments_; + std::vector<Literal> arguments_; }; template <typename NativeT> @@ -410,7 +416,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { std::unique_ptr<Literal> expected_literal = - Literal::CreateR0<NativeT>(expected); + LiteralUtil::CreateR0<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -426,7 +432,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = - Literal::CreateR0<NativeT>(expected); + LiteralUtil::CreateR0<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -436,7 +442,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { std::unique_ptr<Literal> expected_literal = - Literal::CreateR1<NativeT>(expected); + LiteralUtil::CreateR1<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -452,7 +458,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = - Literal::CreateR1<NativeT>(expected); + LiteralUtil::CreateR1<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -462,7 +468,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D<NativeT>& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { std::unique_ptr<Literal> expected_literal = - Literal::CreateR2FromArray2D<NativeT>(expected); + LiteralUtil::CreateR2FromArray2D<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -478,7 +484,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = - Literal::CreateR2FromArray2D<NativeT>(expected); + LiteralUtil::CreateR2FromArray2D<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -488,7 +494,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D<NativeT>& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { std::unique_ptr<Literal> expected_literal = - Literal::CreateR3FromArray3D<NativeT>(expected); + LiteralUtil::CreateR3FromArray3D<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -504,7 +510,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = - Literal::CreateR3FromArray3D<NativeT>(expected); + LiteralUtil::CreateR3FromArray3D<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -514,7 +520,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D<NativeT>& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { std::unique_ptr<Literal> expected_literal = - Literal::CreateR4FromArray4D<NativeT>(expected); + LiteralUtil::CreateR4FromArray4D<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -530,7 +536,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = - Literal::CreateR4FromArray4D<NativeT>(expected); + LiteralUtil::CreateR4FromArray4D<NativeT>(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -539,9 +545,9 @@ template <typename NativeT> std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr<Literal> literal = Literal::CreateR0(value); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -553,9 +559,9 @@ template <typename NativeT> std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr<Literal> literal = Literal::CreateR1(values); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -567,9 +573,9 @@ template <typename NativeT> std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter( const Array2D<NativeT>& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -581,9 +587,9 @@ template <typename NativeT> std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter( const Array3D<NativeT>& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); |