diff options
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.cc | 25 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.cc | 38 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.h | 8 |
3 files changed, 51 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index bbd6a87ca3..50bf185936 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -267,12 +267,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; std::unique_ptr<Literal> converted_expected; Shape layout_shape; - if (expected.shape().element_type() == F32 && use_bfloat16_) { + if (use_bfloat16_) { converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; - layout_shape.set_element_type(BF16); + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); shape_with_layout = &layout_shape; } } @@ -305,13 +310,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; std::unique_ptr<Literal> converted_expected; Shape layout_shape; - if (expected.shape().element_type() == F32 && use_bfloat16_) { + if (use_bfloat16_) { converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); - layout_shape.set_element_type(BF16); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; - layout_shape.set_element_type(BF16); + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); shape_with_layout = &layout_shape; } } @@ -501,7 +510,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( ComputationBuilder* builder, ComputationDataHandle* data_handle) { const Literal* param_literal = &literal; std::unique_ptr<Literal> converted_literal; - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16_) { converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); param_literal = converted_literal.get(); } @@ -515,9 +524,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( const Literal& literal, ComputationBuilder* builder) { return builder->ConstantLiteral( - use_bfloat16_ && literal.shape().element_type() == F32 - ? *LiteralTestUtil::ConvertF32ToBF16(literal) - : literal); + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 6aa27e5470..e1a948c096 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -101,32 +101,52 @@ namespace xla { } /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32( - const Literal& bf16_literal) { - CHECK_EQ(bf16_literal.shape().element_type(), BF16); - Shape converted_shape = bf16_literal.shape(); + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector<std::unique_ptr<Literal>> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertBF16ToF32(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != BF16) { + return MakeUnique<Literal>(literal); + } + Shape converted_shape = literal.shape(); converted_shape.set_element_type(F32); auto converted = Literal::CreateFromShape(converted_shape); if (!ShapeUtil::HasZeroElements(converted_shape)) { std::vector<int64> index(converted_shape.dimensions_size(), 0); do { - converted->Set<float>( - index, static_cast<float>(bf16_literal.Get<bfloat16>(index))); + converted->Set<float>(index, + static_cast<float>(literal.Get<bfloat16>(index))); } while (IndexUtil::BumpIndices(converted_shape, &index)); } return converted; } /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16( - const Literal& f32_literal) { - CHECK_EQ(f32_literal.shape().element_type(), F32); - Shape converted_shape = f32_literal.shape(); + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector<std::unique_ptr<Literal>> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertF32ToBF16(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != F32) { + return MakeUnique<Literal>(literal); + } + Shape converted_shape = literal.shape(); converted_shape.set_element_type(BF16); auto converted = Literal::CreateFromShape(converted_shape); if (!ShapeUtil::HasZeroElements(converted_shape)) { std::vector<int64> index(converted_shape.dimensions_size(), 0); do { converted->Set<bfloat16>( - index, static_cast<bfloat16>(f32_literal.Get<float>(index))); + index, static_cast<bfloat16>(literal.Get<float>(index))); } while (IndexUtil::BumpIndices(converted_shape, &index)); } return converted; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 6e4add2690..bf8c92f16d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -59,10 +59,14 @@ class LiteralTestUtil { static void AssertEqualShapesAndLayouts(const Shape& expected, const Shape& actual); - // Converts a bfloat16 literal to a float literal. + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal); - // Converts a float literal to a bfloat16 literal. + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal); // Asserts that the expected and actual literals are (bitwise) equal for all |