aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-27 20:28:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-27 20:32:41 -0800
commit119e3a18ce480b7f808638a2821de1d935f2df8f (patch)
tree7cef1532dabf40887dd2368172d78201e6c7fa69
parenta8a923b3be645bad6cd08c7d80a148ebbaf47445 (diff)
Make ClientLibraryTestBase automatically choose float precision based on a flag.
PiperOrigin-RevId: 177109696
-rw-r--r--tensorflow/compiler/xla/reference_util.cc133
-rw-r--r--tensorflow/compiler/xla/reference_util.h146
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc87
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h49
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc32
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h6
6 files changed, 289 insertions, 164 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 90aa9720a1..5a899d550b 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -703,137 +703,4 @@ ReferenceUtil::ReduceToRowArray2D(
return result;
}
-/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D(
- const Array2D<float>& operand, const PaddingConfig& padding,
- const float pad) {
- int64 in0 = operand.n1();
- int64 high_padding0 = padding.dimensions(0).edge_padding_high();
- int64 low_padding0 = padding.dimensions(0).edge_padding_low();
- int64 interior_padding0 = padding.dimensions(0).interior_padding();
- int64 out0 =
- in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
-
- int64 in1 = operand.n2();
- int64 high_padding1 = padding.dimensions(1).edge_padding_high();
- int64 low_padding1 = padding.dimensions(1).edge_padding_low();
- int64 interior_padding1 = padding.dimensions(1).interior_padding();
- int64 out1 =
- in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
-
- auto result = MakeUnique<Array2D<float>>(out0, out1);
- result->Fill(pad);
- int64 o0 = low_padding0;
- for (int64 i0 = 0; i0 < in0; ++i0) {
- int64 o1 = low_padding1;
- for (int64 i1 = 0; i1 < in1; ++i1) {
- if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
- (*result)(o0, o1) = operand(i0, i1);
- }
- o1 += interior_padding1 + 1;
- }
- o0 += interior_padding0 + 1;
- }
- return result;
-}
-
-/* static */ Array3D<float> ReferenceUtil::PadArray3D(
- const Array3D<float>& operand, const PaddingConfig& padding,
- const float pad) {
- CHECK_EQ(padding.dimensions_size(), 3);
-
- const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
- operand.n3()};
- std::vector<int64> pad_low(3);
- std::vector<int64> pad_high(3);
- std::vector<int64> pad_interior(3);
- std::vector<int64> output_bounds(3);
- for (int64 i = 0; i < 3; ++i) {
- pad_low[i] = padding.dimensions(i).edge_padding_low();
- pad_high[i] = padding.dimensions(i).edge_padding_high();
- CHECK_LE(0, pad_low[i]);
- CHECK_LE(0, pad_high[i]);
- CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented";
- pad_interior[i] = padding.dimensions(i).interior_padding();
-
- output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
- (input_bounds[i] - 1) * pad_interior[i];
- }
-
- Array3D<float> result(output_bounds[0], output_bounds[1], output_bounds[2]);
- std::vector<int> indices = {0, 0, 0};
- for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
- for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
- for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
- float* value = &result(indices[0], indices[1], indices[2]);
- bool value_padded = false;
- for (int i = 0; i < 3; ++i) {
- bool in_low_padding = indices[i] < pad_low[i];
- bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
- if (in_low_padding || in_high_padding) {
- *value = pad;
- value_padded = true;
- }
- if (pad_interior[i] &&
- (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
- *value = pad;
- value_padded = true;
- }
- }
- if (value_padded) {
- continue;
- }
- *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
- (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
- (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
- }
- }
- }
- return result;
-}
-
-/* static */ Array4D<float> ReferenceUtil::PadArray4D(
- const Array4D<float>& operand, const PaddingConfig& padding,
- const float pad) {
- CHECK_EQ(padding.dimensions_size(), 4);
-
- const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
- operand.n3(), operand.n4()};
- std::vector<int64> pad_low(4);
- std::vector<int64> pad_high(4);
- std::vector<int64> pad_interior(4);
- std::vector<int64> output_bounds(4);
- for (int64 i = 0; i < 4; ++i) {
- pad_low[i] = padding.dimensions(i).edge_padding_low();
- pad_high[i] = padding.dimensions(i).edge_padding_high();
- CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented";
- pad_interior[i] = padding.dimensions(i).interior_padding();
-
- output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
- (input_bounds[i] - 1) * pad_interior[i];
- }
-
- Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
- output_bounds[3]);
- result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
- for (int i = 0; i < 4; ++i) {
- bool in_low_padding = indices[i] < pad_low[i];
- bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
- if (in_low_padding || in_high_padding) {
- *value = pad;
- return;
- }
- if (pad_interior[i] &&
- (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
- *value = pad;
- return;
- }
- }
- *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
- (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
- (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
- (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
- });
- return result;
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 2da1730781..62d455d71a 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -486,19 +486,147 @@ class ReferenceUtil {
}
// Returns the result of a 2D pad on an input matrix.
- static std::unique_ptr<Array2D<float>> PadArray2D(
- const Array2D<float>& operand, const PaddingConfig& padding,
- const float pad);
+ template <typename NativeT>
+ static std::unique_ptr<Array2D<NativeT>> PadArray2D(
+ const Array2D<NativeT>& operand, const PaddingConfig& padding,
+ const NativeT pad) {
+ int64 in0 = operand.n1();
+ int64 high_padding0 = padding.dimensions(0).edge_padding_high();
+ int64 low_padding0 = padding.dimensions(0).edge_padding_low();
+ int64 interior_padding0 = padding.dimensions(0).interior_padding();
+ int64 out0 =
+ in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
+
+ int64 in1 = operand.n2();
+ int64 high_padding1 = padding.dimensions(1).edge_padding_high();
+ int64 low_padding1 = padding.dimensions(1).edge_padding_low();
+ int64 interior_padding1 = padding.dimensions(1).interior_padding();
+ int64 out1 =
+ in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
+
+ auto result = MakeUnique<Array2D<NativeT>>(out0, out1);
+ result->Fill(pad);
+ int64 o0 = low_padding0;
+ for (int64 i0 = 0; i0 < in0; ++i0) {
+ int64 o1 = low_padding1;
+ for (int64 i1 = 0; i1 < in1; ++i1) {
+ if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
+ (*result)(o0, o1) = operand(i0, i1);
+ }
+ o1 += interior_padding1 + 1;
+ }
+ o0 += interior_padding0 + 1;
+ }
+ return result;
+ }
// Returns the result of a 3D pad on an input matrix.
- static Array3D<float> PadArray3D(const Array3D<float>& operand,
- const PaddingConfig& padding,
- const float pad);
+ template <typename NativeT>
+ static Array3D<NativeT> PadArray3D(const Array3D<NativeT>& operand,
+ const PaddingConfig& padding,
+ const NativeT pad) {
+ CHECK_EQ(padding.dimensions_size(), 3);
+
+ const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
+ operand.n3()};
+ std::vector<int64> pad_low(3);
+ std::vector<int64> pad_high(3);
+ std::vector<int64> pad_interior(3);
+ std::vector<int64> output_bounds(3);
+ for (int64 i = 0; i < 3; ++i) {
+ pad_low[i] = padding.dimensions(i).edge_padding_low();
+ pad_high[i] = padding.dimensions(i).edge_padding_high();
+ CHECK_LE(0, pad_low[i]);
+ CHECK_LE(0, pad_high[i]);
+ CHECK_LE(0, padding.dimensions(i).interior_padding())
+ << "not implemented";
+ pad_interior[i] = padding.dimensions(i).interior_padding();
+
+ output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
+ (input_bounds[i] - 1) * pad_interior[i];
+ }
+
+ Array3D<NativeT> result(output_bounds[0], output_bounds[1],
+ output_bounds[2]);
+ std::vector<int> indices = {0, 0, 0};
+ for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
+ for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
+ for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
+ NativeT* value = &result(indices[0], indices[1], indices[2]);
+ bool value_padded = false;
+ for (int i = 0; i < 3; ++i) {
+ bool in_low_padding = indices[i] < pad_low[i];
+ bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
+ if (in_low_padding || in_high_padding) {
+ *value = pad;
+ value_padded = true;
+ }
+ if (pad_interior[i] &&
+ (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
+ *value = pad;
+ value_padded = true;
+ }
+ }
+ if (value_padded) {
+ continue;
+ }
+ *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
+ (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
+ (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
+ }
+ }
+ }
+ return result;
+ }
// Returns the result of a 4D pad on an input array.
- static Array4D<float> PadArray4D(const Array4D<float>& operand,
- const PaddingConfig& padding,
- const float pad);
+ template <typename NativeT>
+ static Array4D<NativeT> PadArray4D(const Array4D<NativeT>& operand,
+ const PaddingConfig& padding,
+ const NativeT pad) {
+ CHECK_EQ(padding.dimensions_size(), 4);
+
+ const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
+ operand.n3(), operand.n4()};
+ std::vector<int64> pad_low(4);
+ std::vector<int64> pad_high(4);
+ std::vector<int64> pad_interior(4);
+ std::vector<int64> output_bounds(4);
+ for (int64 i = 0; i < 4; ++i) {
+ pad_low[i] = padding.dimensions(i).edge_padding_low();
+ pad_high[i] = padding.dimensions(i).edge_padding_high();
+ CHECK_LE(0, padding.dimensions(i).interior_padding())
+ << "not implemented";
+ pad_interior[i] = padding.dimensions(i).interior_padding();
+
+ output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
+ (input_bounds[i] - 1) * pad_interior[i];
+ }
+
+ Array4D<NativeT> result(output_bounds[0], output_bounds[1],
+ output_bounds[2], output_bounds[3]);
+ result.Each(
+ [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT* value) {
+ for (int i = 0; i < 4; ++i) {
+ bool in_low_padding = indices[i] < pad_low[i];
+ bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
+ if (in_low_padding || in_high_padding) {
+ *value = pad;
+ return;
+ }
+ if (pad_interior[i] &&
+ (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
+ *value = pad;
+ return;
+ }
+ }
+ *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
+ (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
+ (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
+ (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
+ });
+ return result;
+ }
// ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running
// f(x[i], y[i], ...) for each array element in the Array2Ds x, y, ....
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index ef54714e46..15bd273e9b 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -262,20 +262,34 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
expected.shape().element_type() == PRED)
<< ShapeUtil::HumanString(expected.shape());
}
+ // We allow using a float expected literal for a bfloat16 output. In this
+ // case, we need to convert the expected literal to bfloat16.
+ const Literal* expected_ptr = &expected;
+ std::unique_ptr<Literal> converted_expected;
+ Shape layout_shape;
+ if (expected.shape().element_type() == F32 && use_bfloat16_) {
+ converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+ expected_ptr = converted_expected.get();
+ if (shape_with_layout != nullptr) {
+ layout_shape = *shape_with_layout;
+ layout_shape.set_element_type(BF16);
+ shape_with_layout = &layout_shape;
+ }
+ }
auto expect_equal = [&](const Literal& actual, const string& error_message) {
- LiteralTestUtil::ExpectEqual(expected, actual, error_message);
+ LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message);
};
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
return ComputeAndCompareLiteralWithAllOutputLayouts(
- computation, expected, arguments, expect_equal);
+ computation, *expected_ptr, arguments, expect_equal);
}
if (execution_options_.debug_options().xla_test_all_input_layouts()) {
return ComputeAndCompareLiteralWithAllInputLayouts(
- computation, expected, arguments, expect_equal, shape_with_layout);
+ computation, *expected_ptr, arguments, expect_equal, shape_with_layout);
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- LiteralTestUtil::ExpectEqual(expected, *actual);
+ LiteralTestUtil::ExpectEqual(*expected_ptr, *actual);
return tensorflow::Status::OK();
}
@@ -286,20 +300,35 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) ||
ShapeUtil::ElementIsComplex(expected.shape()));
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
+ // We allow using a float expected literal for a bfloat16 output. In this
+ // case, we need to convert the expected literal to bfloat16.
+ const Literal* expected_ptr = &expected;
+ std::unique_ptr<Literal> converted_expected;
+ Shape layout_shape;
+ if (expected.shape().element_type() == F32 && use_bfloat16_) {
+ converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+ expected_ptr = converted_expected.get();
+ layout_shape.set_element_type(BF16);
+ if (shape_with_layout != nullptr) {
+ layout_shape = *shape_with_layout;
+ layout_shape.set_element_type(BF16);
+ shape_with_layout = &layout_shape;
+ }
+ }
auto expect_near = [&](const Literal& actual, const string& error_message) {
- LiteralTestUtil::ExpectNear(expected, actual, error, error_message);
+ LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message);
};
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
- return ComputeAndCompareLiteralWithAllOutputLayouts(computation, expected,
- arguments, expect_near);
+ return ComputeAndCompareLiteralWithAllOutputLayouts(
+ computation, *expected_ptr, arguments, expect_near);
}
if (execution_options_.debug_options().xla_test_all_input_layouts()) {
return ComputeAndCompareLiteralWithAllInputLayouts(
- computation, expected, arguments, expect_near, shape_with_layout);
+ computation, *expected_ptr, arguments, expect_near, shape_with_layout);
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- LiteralTestUtil::ExpectNear(expected, *actual, error);
+ LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error);
return tensorflow::Status::OK();
}
@@ -402,8 +431,11 @@ ClientLibraryTestBase::ComputeValueAndReference(
Computation ClientLibraryTestBase::CreateScalarRelu() {
ComputationBuilder builder(client_, "relu");
- auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
- auto zero = builder.ConstantR0<float>(0.0);
+ auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
+ auto z_value = builder.Parameter(0, shape, "z_value");
+ auto zero = use_bfloat16_
+ ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f))
+ : builder.ConstantR0<float>(0.0f);
builder.Max(z_value, zero);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
@@ -412,8 +444,9 @@ Computation ClientLibraryTestBase::CreateScalarRelu() {
Computation ClientLibraryTestBase::CreateScalarMax() {
ComputationBuilder builder(client_, "max");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
+ auto x = builder.Parameter(0, shape, "x");
+ auto y = builder.Parameter(1, shape, "y");
builder.Max(x, y);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
@@ -422,11 +455,12 @@ Computation ClientLibraryTestBase::CreateScalarMax() {
Computation ClientLibraryTestBase::CreateScalarReluSensitivity() {
ComputationBuilder builder(client_, "relu_sensitivity");
- auto activation =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation");
- auto backprop =
- builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop");
- auto zero = builder.ConstantR0<float>(0.0);
+ auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
+ auto activation = builder.Parameter(0, shape, "activation");
+ auto backprop = builder.Parameter(1, shape, "backprop");
+ auto zero = use_bfloat16_
+ ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f))
+ : builder.ConstantR0<float>(0.0f);
auto activation_gtz = builder.Gt(activation, zero);
builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
@@ -461,4 +495,21 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
return array;
}
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(
+ int64 parameter_number, const Literal& literal, const string& name,
+ ComputationBuilder* builder, ComputationDataHandle* data_handle) {
+ const Literal* param_literal = &literal;
+ std::unique_ptr<Literal> converted_literal;
+ if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+ param_literal = converted_literal.get();
+ }
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ *data_handle =
+ builder->Parameter(parameter_number, param_literal->shape(), name);
+ return data;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index af22c12684..e8599a5cd3 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -245,51 +245,76 @@ class ClientLibraryTestBase : public ::testing::Test {
const int rows, const int cols, const int rows_padded,
const int cols_padded);
- // Create a parameter instruction that wraps a given value and then stores
+ // Creates a parameter instruction, transfers the literal for the parameter to
+ // server, then stores into "data_handle" the global handle for that
+ // parameter. When the use_bfloat16 flag is set but the literal has F32
+ // elements, the literal will be converted to BF16 before being transferred.
+ std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
+ int64 parameter_number, const Literal& literal, const string& name,
+ ComputationBuilder* builder, ComputationDataHandle* data_handle);
+
+ // Creates a parameter instruction that wraps a given value and then stores
// into "data_handle" the global handle for that parameter.
//
// "parameter_number" is the parameter number.
// "name" is the name of the parameter instruction.
+ //
+ // When the use_bfloat16 flag is set but NativeT is float, the data will be
+ // converted to bfloat16.
template <typename NativeT>
std::unique_ptr<GlobalData> CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
ComputationBuilder* builder, ComputationDataHandle* data_handle);
- // Create a parameter instruction that wraps the given values and then stores
+ // Creates a parameter instruction that wraps the given values and then stores
// into "data_handle" the global handle for that parameter.
//
// "parameter_number" is the parameter number.
// "name" is the name of the parameter instruction.
+ //
+ // When the use_bfloat16 flag is set but NativeT is float, the data will be
+ // converted to bfloat16.
template <typename NativeT>
std::unique_ptr<GlobalData> CreateR1Parameter(
tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
const string& name, ComputationBuilder* builder,
ComputationDataHandle* data_handle);
- // Create a parameter instruction that wraps the given constant array
+ // Creates a parameter instruction that wraps the given constant array
// "array_2d" and then stores to "data_handle" the global handle for that
// parameter.
//
// "parameter_number" is the parameter number.
// "name" is the name of the parameter instruction.
+ //
+ // When the use_bfloat16 flag is set but NativeT is float, the data will be
+ // converted to bfloat16.
template <typename NativeT>
std::unique_ptr<GlobalData> CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, ComputationBuilder* builder,
ComputationDataHandle* data_handle);
- // Create a parameter instruction that wraps the given constant array
+ // Creates a parameter instruction that wraps the given constant array
// "array_3d" and then stores to "data_handle" the global handle for that
// parameter.
//
// "parameter_number" is the parameter number.
// "name" is the name of the parameter instruction.
+ //
+ // When the use_bfloat16 flag is set but NativeT is float, the data will be
+ // converted to bfloat16.
template <typename NativeT>
std::unique_ptr<GlobalData> CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, ComputationBuilder* builder,
ComputationDataHandle* data_handle);
+ // Getter and setter for the use_bfloat16 flag, which indicates whether to run
+ // tests with all float-type input/output converted to bfloat16.
+ bool use_bfloat16() const { return use_bfloat16_; }
+ void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
+
Client* client_;
ExecutionOptions execution_options_;
@@ -315,6 +340,10 @@ class ClientLibraryTestBase : public ::testing::Test {
ComputeValueAndReference(ComputationBuilder* builder,
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<Literal> arguments);
+
+ // Whether to run tests with all float-type input/output converted to
+ // bfloat16.
+ bool use_bfloat16_ = false;
};
template <typename NativeT>
@@ -443,6 +472,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
ComputationBuilder* builder, ComputationDataHandle* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR0(value);
+ if (use_bfloat16_ && literal->shape().element_type() == F32) {
+ literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ }
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = builder->Parameter(parameter_number, literal->shape(), name);
@@ -455,6 +487,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
const string& name, ComputationBuilder* builder,
ComputationDataHandle* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR1(values);
+ if (use_bfloat16_ && literal->shape().element_type() == F32) {
+ literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ }
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = builder->Parameter(parameter_number, literal->shape(), name);
@@ -467,6 +502,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const string& name, ComputationBuilder* builder,
ComputationDataHandle* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
+ if (use_bfloat16_ && literal->shape().element_type() == F32) {
+ literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ }
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = builder->Parameter(parameter_number, literal->shape(), name);
@@ -479,6 +517,9 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const string& name, ComputationBuilder* builder,
ComputationDataHandle* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
+ if (use_bfloat16_ && literal->shape().element_type() == F32) {
+ literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ }
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = builder->Parameter(parameter_number, literal->shape(), name);
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 9ae5c7b6f0..6aa27e5470 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -100,6 +100,38 @@ namespace xla {
ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
}
+/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
+ const Literal& bf16_literal) {
+ CHECK_EQ(bf16_literal.shape().element_type(), BF16);
+ Shape converted_shape = bf16_literal.shape();
+ converted_shape.set_element_type(F32);
+ auto converted = Literal::CreateFromShape(converted_shape);
+ if (!ShapeUtil::HasZeroElements(converted_shape)) {
+ std::vector<int64> index(converted_shape.dimensions_size(), 0);
+ do {
+ converted->Set<float>(
+ index, static_cast<float>(bf16_literal.Get<bfloat16>(index)));
+ } while (IndexUtil::BumpIndices(converted_shape, &index));
+ }
+ return converted;
+}
+
+/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
+ const Literal& f32_literal) {
+ CHECK_EQ(f32_literal.shape().element_type(), F32);
+ Shape converted_shape = f32_literal.shape();
+ converted_shape.set_element_type(BF16);
+ auto converted = Literal::CreateFromShape(converted_shape);
+ if (!ShapeUtil::HasZeroElements(converted_shape)) {
+ std::vector<int64> index(converted_shape.dimensions_size(), 0);
+ do {
+ converted->Set<bfloat16>(
+ index, static_cast<bfloat16>(f32_literal.Get<float>(index)));
+ } while (IndexUtil::BumpIndices(converted_shape, &index));
+ }
+ return converted;
+}
+
namespace {
string Hostname() {
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 467d44b857..6e4add2690 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -59,6 +59,12 @@ class LiteralTestUtil {
static void AssertEqualShapesAndLayouts(const Shape& expected,
const Shape& actual);
+ // Converts a bfloat16 literal to a float literal.
+ static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal);
+
+ // Converts a float literal to a bfloat16 literal.
+ static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal);
+
// Asserts that the expected and actual literals are (bitwise) equal for all
// elements in the literal. Also, asserts that the rank, dimensions sizes, and
// primitive type are equal.