aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc25
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc38
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h8
3 files changed, 51 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index bbd6a87ca3..50bf185936 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -267,12 +267,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
const Literal* expected_ptr = &expected;
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
- if (expected.shape().element_type() == F32 && use_bfloat16_) {
+ if (use_bfloat16_) {
converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
- layout_shape.set_element_type(BF16);
+ ShapeUtil::ForEachMutableSubshape(
+ &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
+ if (subshape->element_type() == F32) {
+ subshape->set_element_type(BF16);
+ }
+ });
shape_with_layout = &layout_shape;
}
}
@@ -305,13 +310,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
const Literal* expected_ptr = &expected;
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
- if (expected.shape().element_type() == F32 && use_bfloat16_) {
+ if (use_bfloat16_) {
converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
- layout_shape.set_element_type(BF16);
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
- layout_shape.set_element_type(BF16);
+ ShapeUtil::ForEachMutableSubshape(
+ &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
+ if (subshape->element_type() == F32) {
+ subshape->set_element_type(BF16);
+ }
+ });
shape_with_layout = &layout_shape;
}
}
@@ -501,7 +510,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
ComputationBuilder* builder, ComputationDataHandle* data_handle) {
const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal;
- if (use_bfloat16_ && literal.shape().element_type() == F32) {
+ if (use_bfloat16_) {
converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
param_literal = converted_literal.get();
}
@@ -515,9 +524,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral(
const Literal& literal, ComputationBuilder* builder) {
return builder->ConstantLiteral(
- use_bfloat16_ && literal.shape().element_type() == F32
- ? *LiteralTestUtil::ConvertF32ToBF16(literal)
- : literal);
+ use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 6aa27e5470..e1a948c096 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -101,32 +101,52 @@ namespace xla {
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
- const Literal& bf16_literal) {
- CHECK_EQ(bf16_literal.shape().element_type(), BF16);
- Shape converted_shape = bf16_literal.shape();
+ const Literal& literal) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
+ std::vector<std::unique_ptr<Literal>> converted_elements;
+ for (const auto& element : literal.tuple_literals()) {
+ converted_elements.push_back(ConvertBF16ToF32(element));
+ }
+ return Literal::MakeTupleOwned(std::move(converted_elements));
+ }
+
+ if (literal.shape().element_type() != BF16) {
+ return MakeUnique<Literal>(literal);
+ }
+ Shape converted_shape = literal.shape();
converted_shape.set_element_type(F32);
auto converted = Literal::CreateFromShape(converted_shape);
if (!ShapeUtil::HasZeroElements(converted_shape)) {
std::vector<int64> index(converted_shape.dimensions_size(), 0);
do {
- converted->Set<float>(
- index, static_cast<float>(bf16_literal.Get<bfloat16>(index)));
+ converted->Set<float>(index,
+ static_cast<float>(literal.Get<bfloat16>(index)));
} while (IndexUtil::BumpIndices(converted_shape, &index));
}
return converted;
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
- const Literal& f32_literal) {
- CHECK_EQ(f32_literal.shape().element_type(), F32);
- Shape converted_shape = f32_literal.shape();
+ const Literal& literal) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
+ std::vector<std::unique_ptr<Literal>> converted_elements;
+ for (const auto& element : literal.tuple_literals()) {
+ converted_elements.push_back(ConvertF32ToBF16(element));
+ }
+ return Literal::MakeTupleOwned(std::move(converted_elements));
+ }
+
+ if (literal.shape().element_type() != F32) {
+ return MakeUnique<Literal>(literal);
+ }
+ Shape converted_shape = literal.shape();
converted_shape.set_element_type(BF16);
auto converted = Literal::CreateFromShape(converted_shape);
if (!ShapeUtil::HasZeroElements(converted_shape)) {
std::vector<int64> index(converted_shape.dimensions_size(), 0);
do {
converted->Set<bfloat16>(
- index, static_cast<bfloat16>(f32_literal.Get<float>(index)));
+ index, static_cast<bfloat16>(literal.Get<float>(index)));
} while (IndexUtil::BumpIndices(converted_shape, &index));
}
return converted;
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 6e4add2690..bf8c92f16d 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -59,10 +59,14 @@ class LiteralTestUtil {
static void AssertEqualShapesAndLayouts(const Shape& expected,
const Shape& actual);
- // Converts a bfloat16 literal to a float literal.
+ // If the given literal's data type is bfloat16, converts it to a float
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal);
- // Converts a float literal to a bfloat16 literal.
+ // If the given literal's data type is float, converts it to a bfloat16
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal);
// Asserts that the expected and actual literals are (bitwise) equal for all