diff options
author | 2017-11-27 20:28:58 -0800 | |
---|---|---|
committer | 2017-11-27 20:32:41 -0800 | |
commit | 119e3a18ce480b7f808638a2821de1d935f2df8f (patch) | |
tree | 7cef1532dabf40887dd2368172d78201e6c7fa69 | |
parent | a8a923b3be645bad6cd08c7d80a148ebbaf47445 (diff) |
Make ClientLibraryTestBase automatically choose float precision based on a flag.
PiperOrigin-RevId: 177109696
-rw-r--r-- | tensorflow/compiler/xla/reference_util.cc | 133 | ||||
-rw-r--r-- | tensorflow/compiler/xla/reference_util.h | 146 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.cc | 87 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 49 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.cc | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.h | 6 |
6 files changed, 289 insertions, 164 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 90aa9720a1..5a899d550b 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -703,137 +703,4 @@ ReferenceUtil::ReduceToRowArray2D( return result; } -/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D( - const Array2D<float>& operand, const PaddingConfig& padding, - const float pad) { - int64 in0 = operand.n1(); - int64 high_padding0 = padding.dimensions(0).edge_padding_high(); - int64 low_padding0 = padding.dimensions(0).edge_padding_low(); - int64 interior_padding0 = padding.dimensions(0).interior_padding(); - int64 out0 = - in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; - - int64 in1 = operand.n2(); - int64 high_padding1 = padding.dimensions(1).edge_padding_high(); - int64 low_padding1 = padding.dimensions(1).edge_padding_low(); - int64 interior_padding1 = padding.dimensions(1).interior_padding(); - int64 out1 = - in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - - auto result = MakeUnique<Array2D<float>>(out0, out1); - result->Fill(pad); - int64 o0 = low_padding0; - for (int64 i0 = 0; i0 < in0; ++i0) { - int64 o1 = low_padding1; - for (int64 i1 = 0; i1 < in1; ++i1) { - if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { - (*result)(o0, o1) = operand(i0, i1); - } - o1 += interior_padding1 + 1; - } - o0 += interior_padding0 + 1; - } - return result; -} - -/* static */ Array3D<float> ReferenceUtil::PadArray3D( - const Array3D<float>& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 3); - - const std::vector<int64> input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector<int64> pad_low(3); - std::vector<int64> pad_high(3); - std::vector<int64> pad_interior(3); - std::vector<int64> output_bounds(3); - for (int64 i = 0; i < 3; ++i) { - pad_low[i] = 
padding.dimensions(i).edge_padding_low(); - pad_high[i] = padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, pad_low[i]); - CHECK_LE(0, pad_high[i]); - CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array3D<float> result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector<int> indices = {0, 0, 0}; - for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { - for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { - for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { - float* value = &result(indices[0], indices[1], indices[2]); - bool value_padded = false; - for (int i = 0; i < 3; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - value_padded = true; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - value_padded = true; - } - } - if (value_padded) { - continue; - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); - } - } - } - return result; -} - -/* static */ Array4D<float> ReferenceUtil::PadArray4D( - const Array4D<float>& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 4); - - const std::vector<int64> input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector<int64> pad_low(4); - std::vector<int64> pad_high(4); - std::vector<int64> pad_interior(4); - std::vector<int64> output_bounds(4); - for (int64 i = 0; i < 4; ++i) { - pad_low[i] = padding.dimensions(i).edge_padding_low(); - pad_high[i] = 
padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2], - output_bounds[3]); - result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { - for (int i = 0; i < 4; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - return; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - return; - } - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1), - (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); - }); - return result; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 2da1730781..62d455d71a 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -486,19 +486,147 @@ class ReferenceUtil { } // Returns the result of a 2D pad on an input matrix. 
- static std::unique_ptr<Array2D<float>> PadArray2D( - const Array2D<float>& operand, const PaddingConfig& padding, - const float pad); + template <typename NativeT> + static std::unique_ptr<Array2D<NativeT>> PadArray2D( + const Array2D<NativeT>& operand, const PaddingConfig& padding, + const NativeT pad) { + int64 in0 = operand.n1(); + int64 high_padding0 = padding.dimensions(0).edge_padding_high(); + int64 low_padding0 = padding.dimensions(0).edge_padding_low(); + int64 interior_padding0 = padding.dimensions(0).interior_padding(); + int64 out0 = + in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; + + int64 in1 = operand.n2(); + int64 high_padding1 = padding.dimensions(1).edge_padding_high(); + int64 low_padding1 = padding.dimensions(1).edge_padding_low(); + int64 interior_padding1 = padding.dimensions(1).interior_padding(); + int64 out1 = + in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; + + auto result = MakeUnique<Array2D<NativeT>>(out0, out1); + result->Fill(pad); + int64 o0 = low_padding0; + for (int64 i0 = 0; i0 < in0; ++i0) { + int64 o1 = low_padding1; + for (int64 i1 = 0; i1 < in1; ++i1) { + if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { + (*result)(o0, o1) = operand(i0, i1); + } + o1 += interior_padding1 + 1; + } + o0 += interior_padding0 + 1; + } + return result; + } // Returns the result of a 3D pad on an input matrix. 
- static Array3D<float> PadArray3D(const Array3D<float>& operand, - const PaddingConfig& padding, - const float pad); + template <typename NativeT> + static Array3D<NativeT> PadArray3D(const Array3D<NativeT>& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 3); + + const std::vector<int64> input_bounds = {operand.n1(), operand.n2(), + operand.n3()}; + std::vector<int64> pad_low(3); + std::vector<int64> pad_high(3); + std::vector<int64> pad_interior(3); + std::vector<int64> output_bounds(3); + for (int64 i = 0; i < 3; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, pad_low[i]); + CHECK_LE(0, pad_high[i]); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array3D<NativeT> result(output_bounds[0], output_bounds[1], + output_bounds[2]); + std::vector<int> indices = {0, 0, 0}; + for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { + for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { + for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { + NativeT* value = &result(indices[0], indices[1], indices[2]); + bool value_padded = false; + for (int i = 0; i < 3; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + value_padded = true; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + value_padded = true; + } + } + if (value_padded) { + continue; + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / 
(pad_interior[2] + 1)); + } + } + } + return result; + } // Returns the result of a 4D pad on an input array. - static Array4D<float> PadArray4D(const Array4D<float>& operand, - const PaddingConfig& padding, - const float pad); + template <typename NativeT> + static Array4D<NativeT> PadArray4D(const Array4D<NativeT>& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 4); + + const std::vector<int64> input_bounds = {operand.n1(), operand.n2(), + operand.n3(), operand.n4()}; + std::vector<int64> pad_low(4); + std::vector<int64> pad_high(4); + std::vector<int64> pad_interior(4); + std::vector<int64> output_bounds(4); + for (int64 i = 0; i < 4; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array4D<NativeT> result(output_bounds[0], output_bounds[1], + output_bounds[2], output_bounds[3]); + result.Each( + [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT* value) { + for (int i = 0; i < 4; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + return; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + return; + } + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1), + (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); + }); + return result; + } // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running // f(x[i], y[i], ...) 
for each array element in the Array2Ds x, y, .... diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index ef54714e46..15bd273e9b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -262,20 +262,34 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( expected.shape().element_type() == PRED) << ShapeUtil::HumanString(expected.shape()); } + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. + const Literal* expected_ptr = &expected; + std::unique_ptr<Literal> converted_expected; + Shape layout_shape; + if (expected.shape().element_type() == F32 && use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + layout_shape.set_element_type(BF16); + shape_with_layout = &layout_shape; + } + } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(expected, actual, error_message); + LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( - computation, expected, arguments, expect_equal); + computation, *expected_ptr, arguments, expect_equal); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_equal, shape_with_layout); + computation, *expected_ptr, arguments, expect_equal, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(expected, *actual); + 
LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); return tensorflow::Status::OK(); } @@ -286,20 +300,35 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. + const Literal* expected_ptr = &expected; + std::unique_ptr<Literal> converted_expected; + Shape layout_shape; + if (expected.shape().element_type() == F32 && use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + layout_shape.set_element_type(BF16); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + layout_shape.set_element_type(BF16); + shape_with_layout = &layout_shape; + } + } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(expected, actual, error, error_message); + LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { - return ComputeAndCompareLiteralWithAllOutputLayouts(computation, expected, - arguments, expect_near); + return ComputeAndCompareLiteralWithAllOutputLayouts( + computation, *expected_ptr, arguments, expect_near); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_near, shape_with_layout); + computation, *expected_ptr, arguments, expect_near, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(expected, *actual, error); + LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); return 
tensorflow::Status::OK(); } @@ -402,8 +431,11 @@ ClientLibraryTestBase::ComputeValueAndReference( Computation ClientLibraryTestBase::CreateScalarRelu() { ComputationBuilder builder(client_, "relu"); - auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); - auto zero = builder.ConstantR0<float>(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto z_value = builder.Parameter(0, shape, "z_value"); + auto zero = use_bfloat16_ + ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f)) + : builder.ConstantR0<float>(0.0f); builder.Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -412,8 +444,9 @@ Computation ClientLibraryTestBase::CreateScalarRelu() { Computation ClientLibraryTestBase::CreateScalarMax() { ComputationBuilder builder(client_, "max"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto x = builder.Parameter(0, shape, "x"); + auto y = builder.Parameter(1, shape, "y"); builder.Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -422,11 +455,12 @@ Computation ClientLibraryTestBase::CreateScalarMax() { Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { ComputationBuilder builder(client_, "relu_sensitivity"); - auto activation = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation"); - auto backprop = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop"); - auto zero = builder.ConstantR0<float>(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto activation = builder.Parameter(0, shape, "activation"); + auto backprop = builder.Parameter(1, shape, "backprop"); + auto zero = use_bfloat16_ + ? 
builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f)) + : builder.ConstantR0<float>(0.0f); auto activation_gtz = builder.Gt(activation, zero); builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); @@ -461,4 +495,21 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } +std::unique_ptr<GlobalData> +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr<Literal> converted_literal; + if (use_bfloat16_ && literal.shape().element_type() == F32) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr<GlobalData> data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index af22c12684..e8599a5cd3 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -245,51 +245,76 @@ class ClientLibraryTestBase : public ::testing::Test { const int rows, const int cols, const int rows_padded, const int cols_padded); - // Create a parameter instruction that wraps a given value and then stores + // Creates a parameter instruction, transfers the literal for the parameter to + // server, then stores into "data_handle" the global handle for that + // parameter. When the use_bfloat16 flag is set but the literal has F32 + // elements, the literal will be converted to BF16 before being transferred. 
+ std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle); + + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template <typename NativeT> std::unique_ptr<GlobalData> CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given values and then stores + // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template <typename NativeT> std::unique_ptr<GlobalData> CreateR1Parameter( tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. 
template <typename NativeT> std::unique_ptr<GlobalData> CreateR2Parameter( const Array2D<NativeT>& array_2d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template <typename NativeT> std::unique_ptr<GlobalData> CreateR3Parameter( const Array3D<NativeT>& array_3d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); + // Getter and setter for the use_bfloat16 flag, which indicates whether to run + // tests with all float-type input/output converted to bfloat16. + bool use_bfloat16() const { return use_bfloat16_; } + void set_use_bfloat16(bool value) { use_bfloat16_ = value; } + Client* client_; ExecutionOptions execution_options_; @@ -315,6 +340,10 @@ class ClientLibraryTestBase : public ::testing::Test { ComputeValueAndReference(ComputationBuilder* builder, const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice<Literal> arguments); + + // Whether to run tests with all float-type input/output converted to + // bfloat16. 
+ bool use_bfloat16_ = false; }; template <typename NativeT> @@ -443,6 +472,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr<Literal> literal = Literal::CreateR0(value); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -455,6 +487,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr<Literal> literal = Literal::CreateR1(values); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -467,6 +502,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -479,6 +517,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && 
literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 9ae5c7b6f0..6aa27e5470 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -100,6 +100,38 @@ namespace xla { ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); } +/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32( + const Literal& bf16_literal) { + CHECK_EQ(bf16_literal.shape().element_type(), BF16); + Shape converted_shape = bf16_literal.shape(); + converted_shape.set_element_type(F32); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector<int64> index(converted_shape.dimensions_size(), 0); + do { + converted->Set<float>( + index, static_cast<float>(bf16_literal.Get<bfloat16>(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + +/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16( + const Literal& f32_literal) { + CHECK_EQ(f32_literal.shape().element_type(), F32); + Shape converted_shape = f32_literal.shape(); + converted_shape.set_element_type(BF16); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector<int64> index(converted_shape.dimensions_size(), 0); + do { + converted->Set<bfloat16>( + index, static_cast<bfloat16>(f32_literal.Get<float>(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + namespace { string Hostname() { diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h 
b/tensorflow/compiler/xla/tests/literal_test_util.h index 467d44b857..6e4add2690 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -59,6 +59,12 @@ class LiteralTestUtil { static void AssertEqualShapesAndLayouts(const Shape& expected, const Shape& actual); + // Converts a bfloat16 literal to a float literal. + static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal); + + // Converts a float literal to a bfloat16 literal. + static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal); + // Asserts that the expected and actual literals are (bitwise) equal for all // elements in the literal. Also, asserts that the rank, dimensions sizes, and // primitive type are equal. |