aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-10-02 13:08:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 13:13:22 -0700
commit78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (patch)
tree096cb5b28777e227053a50e622313cfe66e56744 /tensorflow/compiler
parent0a201955b47d484c6bfa149364c264a5b5f91be7 (diff)
Add proto serialization/deserialization testing to the HLO parser tests.
Many of the HLO parser tests verify that an text form of an HLO module preserves all information when running through ToString then parsing. It makes sense to also use these tests to exercise proto serialization/deserialization. This is done by adding additional instantiations of the parameterized parsing tests. This caught several bugs which are fixed in this CL: (1) Domain instructions were not being serialized properly. (2) Host send/recv instructions did not preserve the is_host_transfer bit. (3) Sparse literals could not be serialized or deserialized. PiperOrigin-RevId: 215445200
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/literal.cc18
-rw-r--r--tensorflow/compiler/xla/literal_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/BUILD20
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto6
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc85
8 files changed, 141 insertions, 55 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 5035f41988..d1dad0d45f 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -1850,6 +1850,24 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+ if (LayoutUtil::IsSparseArray(subshape())) {
+ // Compute the number of elements (indices) in the sparse shape and reserve
+ // the necessary space in spare_indices.
+ TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0)
+ << "Scalar shapes cannot be sparse";
+ TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0)
+ << "Unexpected number of indices in proto ("
+ << proto.sparse_indices_size() << ") for shape of rank "
+ << ShapeUtil::Rank(subshape());
+ const int64 index_count =
+ proto.sparse_indices_size() / ShapeUtil::Rank(subshape());
+ sparse_indices()->Resize(index_count);
+
+ // Copy the indices from the proto into the SparseIndexArray object.
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(),
+ proto.sparse_indices()));
+ }
+
switch (subshape().element_type()) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 7ad287c897..dd5b54e4c9 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -224,6 +224,16 @@ TEST_F(LiteralUtilTest, CreateSparse) {
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
+
+ // Serialize then deserialize and verify the resulting literal.
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto,
+ Literal::CreateFromProto(literal.ToProto()));
+
+ EXPECT_EQ(literal_from_proto.sparse_indices()->data(),
+ absl::Span<const int64>(expected_indices.data(),
+ expected_indices.num_elements()));
+ EXPECT_EQ(literal_from_proto.data<int64>(),
+ absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 3f8b734afb..f329a27e14 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -300,6 +300,7 @@ cc_library(
"hlo_opcode.cc",
"hlo_schedule.cc",
"hlo_sharding.cc",
+ "hlo_sharding_metadata.cc",
],
hdrs = [
"dfs_hlo_visitor.h",
@@ -313,6 +314,7 @@ cc_library(
"hlo_opcode.h",
"hlo_schedule.h",
"hlo_sharding.h",
+ "hlo_sharding_metadata.h",
],
deps = [
":hlo_casting_utils",
@@ -2760,22 +2762,6 @@ cc_library(
)
cc_library(
- name = "hlo_sharding_metadata",
- srcs = ["hlo_sharding_metadata.cc"],
- hdrs = [
- "hlo_sharding_metadata.h",
- ],
- deps = [
- ":hlo",
- "//tensorflow/compiler/xla:shape_tree",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/core:lib",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
name = "hlo_domain_verifier",
srcs = ["hlo_domain_verifier.cc"],
hdrs = ["hlo_domain_verifier.h"],
@@ -2825,7 +2811,6 @@ tf_cc_test(
":hlo_domain_isolator",
":hlo_domain_remover",
":hlo_parser",
- ":hlo_sharding_metadata",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -3441,7 +3426,6 @@ cc_library(
deps = [
":hlo",
":hlo_lexer",
- ":hlo_sharding_metadata",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index caaca16f71..1ea26ddd5b 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 54
+// Next ID: 56
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -180,6 +180,10 @@ message HloInstructionProto {
// Collective permute field.
repeated SourceTarget source_target_pairs = 52;
+
+ // Sharding for kDomain instructions.
+ xla.OpSharding domain_entry_sharding = 54;
+ xla.OpSharding domain_exit_sharding = 55;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 0207f9ae3f..de22b2d3a5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -467,14 +468,27 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.dot_dimension_numbers(), precision_config);
break;
}
- case HloOpcode::kDomain:
+ case HloOpcode::kDomain: {
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Domain instruction should have 1 operands but sees "
<< proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_domain_entry_sharding())
+ << "Domain instruction must domain_entry_sharding";
+ TF_RET_CHECK(proto.has_domain_exit_sharding())
+ << "Domain instruction must domain_exit_sharding";
+ TF_ASSIGN_OR_RETURN(
+ HloSharding entry_hlo_sharding,
+ HloSharding::FromProto(proto.domain_entry_sharding()));
+ TF_ASSIGN_OR_RETURN(HloSharding exit_hlo_sharding,
+ HloSharding::FromProto(proto.domain_exit_sharding()));
instruction = absl::make_unique<HloDomainInstruction>(
- proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
- /*user_side_metadata=*/nullptr);
+ proto.shape(), operands(0),
+ absl::make_unique<ShardingMetadata>(
+ std::make_shared<const HloSharding>(entry_hlo_sharding)),
+ absl::make_unique<ShardingMetadata>(
+ std::make_shared<const HloSharding>(exit_hlo_sharding)));
break;
+ }
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -482,12 +496,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "No instruction with id " << operand_id;
instruction->AppendOperand(instruction_map.at(operand_id));
}
- for (const int64 predecessor_id : proto.control_predecessor_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
- << "No instruction with id " << predecessor_id;
- TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
- ->AddControlDependencyTo(instruction.get()));
- }
if (instruction->opcode() != HloOpcode::kFusion) {
for (const int64 computation_id : proto.called_computation_ids()) {
TF_RET_CHECK(ContainsKey(computation_map, computation_id))
@@ -503,6 +511,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
}
+ for (const int64 predecessor_id : proto.control_predecessor_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
+ << "No instruction with id " << predecessor_id;
+ TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
+ ->AddControlDependencyTo(instruction.get()));
+ }
+
TF_RET_CHECK(!proto.name().empty());
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 1bc168c8b7..68d0979f5c 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/window_util.h"
namespace xla {
@@ -213,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
HloInstructionProto HloSendRecvInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_channel_id(channel_id_);
+ proto.set_is_host_transfer(is_host_transfer_);
return proto;
}
@@ -2310,4 +2312,23 @@ std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
}
+
+HloInstructionProto HloDomainInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ auto operand_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
+ if (operand_side_sharding) {
+ *proto.mutable_domain_entry_sharding() =
+ operand_side_sharding->sharding()->ToProto();
+ }
+
+ auto user_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
+ if (user_side_sharding) {
+ *proto.mutable_domain_exit_sharding() =
+ user_side_sharding->sharding()->ToProto();
+ }
+
+ return proto;
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 9c22f5db7e..c929867bb9 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1341,6 +1341,9 @@ class HloDomainInstruction : public HloInstruction {
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata);
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
// Retrieves the operand side metadata of a kDomain instruction.
const DomainMetadata& operand_side_metadata() const {
return *operand_side_metadata_;
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 96db96bdb9..dd4ee780f0 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1163,49 +1163,80 @@ ENTRY Sort {
// clang-format on
}
-class HloParserTest : public ::testing::Test,
- public ::testing::WithParamInterface<TestData> {
+// The test class for those tests defined above which round-trip through the
+// parser and ToString is templatized on two bool parameters:
+//
+// short_form : used for the "short" test cases which use the ShortParsable
+// output form.
+// proto_round_trip : whether the module should also be round-tripped through
+// HloProto form. This provides much better coverage for the proto
+// serialization/deserialization.
+//
+// The proto_round_trip=true case also technically covers the Parser->ToString
+// roundtrip as well, but separating out the Parser->ToString roundtrip as its
+// own test provides better isolation and could conceivably catch weirdo bugs
+// which are hidden by interaction between the textual and proto roundtripping.
+template <bool short_form, bool proto_round_trip>
+class HloParameterizedParserTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<TestData> {
protected:
- static void ExpectHasSubstr(string_view s, string_view expected) {
- EXPECT_TRUE(absl::StrContains(s, expected))
- << "'" << s << "' does not contain '" << expected << "'";
- }
-
// Expects "ToString(ParseHloString(string)) == string", that is, parses the
// string, asserts that it succeeded, stringifies the parsed module, and
// checks that the it equals the original string.
void ExpectEqual() {
const string& original = GetParam().module_string;
- auto result = ParseHloString(original);
- TF_ASSERT_OK(result.status());
- EXPECT_EQ(original, result.ValueOrDie()->ToString(
- HloPrintOptions().set_print_large_constants(true)));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(original));
+ if (proto_round_trip) {
+ TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
+ module->ToProto(), module->config()));
+ }
+ if (short_form) {
+ EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
+ } else {
+ EXPECT_EQ(
+ original,
+ module->ToString(HloPrintOptions().set_print_large_constants(true)));
+ }
}
};
-class HloParserShortTest : public HloParserTest {
- protected:
- void ExpectEqualShort() {
- const string& original = GetParam().module_string;
- auto result = ParseHloString(original);
- TF_ASSERT_OK(result.status());
- EXPECT_EQ(original,
- result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
- }
-};
+// These using shenanigans are required because the TEST_P macro doesn't like
+// template instantiations which contain commas.
+using HloParserTestLong = HloParameterizedParserTest<false, false>;
+using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
+using HloParserTestShort = HloParameterizedParserTest<true, false>;
+using HloParserTestShortProto = HloParameterizedParserTest<true, true>;
-TEST_P(HloParserTest, Run) { ExpectEqual(); }
+TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
+TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
+TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
+TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }
-TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); }
-
-INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
-
-INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation,
+ HloParserTestLongProto,
+ ::testing::ValuesIn(CreateTestCases()),
+ TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
+ ::testing::ValuesIn(CreateShortTestCases()),
+ TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation,
+ HloParserTestShortProto,
::testing::ValuesIn(CreateShortTestCases()),
TestDataToString);
+class HloParserTest : public ::testing::Test {
+ protected:
+ static void ExpectHasSubstr(string_view s, string_view expected) {
+ EXPECT_TRUE(absl::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+ }
+};
+
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = ParseHloString(original);