aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-19 16:25:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 16:31:25 -0700
commit1d78936a3989f6ee5a9945746cd329c37e82287c (patch)
treea61d16ca252de1b21f6fdbbc86909e3b70a6eedc
parent9a7f252910bb2cc14092adc6e8163bd6e696c1f0 (diff)
Add VerifiedHloModule class.
VerifiedHloModule is derived from HloModule and verifies itself on destruction. This is designed to be used in HloVerifiedTestBase. This replaces the current mechanism which verifies HloModules in the TearDown method. The VerifiedHloModule approach is cleaner (less state on the test object) and more capable because these verified HLO modules can be passed to methods which require taking ownership of the module (eg, HlotestBase::Execute). This change required some changes to the parser which enables constructing the parsed HloModule into an already allocated HloModule. Some trivial changes to HloModule are required as well. PiperOrigin-RevId: 213718126
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc83
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc1
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc4
-rw-r--r--tensorflow/compiler/xla/tests/BUILD22
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc78
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h63
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc158
11 files changed, 338 insertions, 114 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4b183b4350..2bc50c70cf 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2605,7 +2605,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3bc2d13781..735804e827 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -63,6 +63,7 @@ class HloModule {
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
explicit HloModule(const string& name, const HloModuleConfig& config);
+ virtual ~HloModule() {}
// Adds an entry computation to the module. A module can only have one entry
// computation. Returns a pointer to the newly added computation.
@@ -87,6 +88,7 @@ class HloModule {
const std::unordered_map<HloComputation*, HloComputation*>& replacements);
const string& name() const { return name_; }
+ void set_name(string name) { name_ = std::move(name); }
// Returns a deep copy of this module including all computations.
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
@@ -255,7 +257,7 @@ class HloModule {
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_identifiers);
- const string name_;
+ string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;
std::vector<std::unique_ptr<HloComputation>> computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11caa89c54..37197b273b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -64,14 +64,11 @@ class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(absl::string_view str, const HloModuleConfig& config)
- : lexer_(str), config_(config) {}
+ explicit HloParser(absl::string_view str) : lexer_(str) {}
- // Runs the parser. Returns false if an error occurred.
- bool Run();
-
- // Returns the parsed HloModule.
- std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
+ // Runs the parser and constructs the resulting HLO in the given (empty)
+ // HloModule. Returns false if an error occurred.
+ bool Run(HloModule* module);
// Returns the error information.
string GetError() const { return StrJoin(error_, "\n"); }
@@ -98,8 +95,8 @@ class HloParser {
const string& name, const optional<Shape>& shape = nullopt);
// ParseXXX returns false if an error occurred.
- bool ParseHloModule();
- bool ParseComputations();
+ bool ParseHloModule(HloModule* module);
+ bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
@@ -293,9 +290,7 @@ class HloParser {
computation_pool_;
HloLexer lexer_;
- std::unique_ptr<HloModule> module_;
std::vector<std::unique_ptr<HloComputation>> computations_;
- const HloModuleConfig config_;
std::vector<string> error_;
// Function that gets invoked when we try to resolve an instruction
@@ -349,9 +344,9 @@ bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
-bool HloParser::Run() {
+bool HloParser::Run(HloModule* module) {
lexer_.Lex();
- return ParseHloModule();
+ return ParseHloModule(module);
}
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
@@ -366,7 +361,7 @@ std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
}
// ::= 'HloModule' name computations
-bool HloParser::ParseHloModule() {
+bool HloParser::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
@@ -385,22 +380,20 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = absl::make_unique<HloModule>(name, config_);
-
- if (!ParseComputations()) {
+ module->set_name(name);
+ if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
- TF_CHECK_OK(
- module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
return true;
}
// computations ::= (computation)+
-bool HloParser::ParseComputations() {
+bool HloParser::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
@@ -416,21 +409,20 @@ bool HloParser::ParseComputations() {
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
- module_->AddEmbeddedComputation(std::move(computations_[i]));
+ module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
- auto computation =
- module_->AddEntryComputation(std::move(computations_[i]));
+ auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
@@ -3247,53 +3239,62 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config) {
- HloParser parser(str, config);
- if (!parser.Run()) {
+ auto module = absl::make_unique<HloModule>(/*name=*/"", config);
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
return InvalidArgument("Syntax error:\n%s", parser.GetError());
}
- return parser.ConsumeHloModule();
+ return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
- HloModuleConfig config;
- return ParseHloString(str, config);
+ auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig());
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return std::move(module);
+}
+
+Status ParseHloString(absl::string_view str, HloModule* module) {
+ TF_RET_CHECK(module->computation_count() == 0);
+ HloParser parser(str);
+ if (!parser.Run(module)) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return Status::OK();
}
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
auto builder = absl::make_unique<HloComputation::Builder>(string(name));
string root_name;
TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
std::unique_ptr<HloComputation> computation = builder->Build();
- auto module = absl::make_unique<HloModule>(string(name), config);
+ auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
module->AddEntryComputation(std::move(computation));
return std::move(module);
}
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseShardingOnly();
}
StatusOr<Window> ParseWindow(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseConvolutionDimensionNumbersOnly();
}
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParsePaddingConfigOnly();
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 1882a184da..3696035514 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -30,18 +30,23 @@ namespace xla {
// For details about the syntax accepted by this parser, see
// g3doc/hlo_parser.md.
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with the given config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config);
+// Given a string in the HloModule::ToString() format, parses the string and
+// builds the HloModule in place at the given module pointer. 'module' must
+// point to an empty module (no computations).
+Status ParseHloString(absl::string_view str, HloModule* module);
+
// Parses the text for a single HLO operation into an HLO module with a function
// that runs that operation (with the same parameters) as its entry computation.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name = "single_op");
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with default config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
index e16b4d4c0a..ee8cb12b23 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -19,21 +19,21 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
-class HloPassPipelineTest : public HloTestBase {
+class HloPassPipelineTest : public HloVerifiedTestBase {
protected:
StatusOr<HloModuleGroup> ParseModuleGroup(
absl::Span<const string> hlo_strings) {
HloModuleGroup group(TestName());
for (const string& hlo_string : hlo_strings) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- ParseHloString(hlo_string));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
group.push_back(std::move(module));
}
return std::move(group);
@@ -106,8 +106,8 @@ ENTRY main {
ROOT foo = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<FooToBarModulePass>();
@@ -129,8 +129,8 @@ ENTRY main {
ROOT blahblah = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<FooToBarModulePass>();
@@ -191,8 +191,8 @@ ENTRY main {
ROOT foo = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
{
// Run a pipeline with just the invariant checker. It should not fail
// because there is no 'bar' instruction in the module.
@@ -243,8 +243,8 @@ ENTRY main {
ROOT foo = f32[] multiply(a, b)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(module_str));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
HloPassPipeline pipeline(TestName());
pipeline.AddPass<BazToQuxModuleGroupPass>();
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 50f39cbcb5..6eb6658904 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1057,6 +1057,7 @@ Status VerifySendsAndRecvs(const HloModule& module) {
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
+ TF_RET_CHECK(!module->name().empty());
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index bd8fb17a23..ac2f79674f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) {
}
/*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+ if (name.empty()) {
+ return "";
+ }
string result = name;
- CHECK(!result.empty()) << "name should not be empty";
char c = static_cast<unsigned char>(result[0]);
if (!isalpha(c) && c != '_') {
result[0] = '_';
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index b49db029e2..fd3e3bfa94 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -154,11 +154,31 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
+tf_cc_test(
+ name = "hlo_verified_test_base_test",
+ srcs = ["hlo_verified_test_base_test.cc"],
+ deps = [
+ ":hlo_test_base",
+ ":hlo_verified_test_base",
+ ":test_macros_cpu",
+ ":test_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_binary(
name = "local_client_aot_test_helper",
srcs = ["local_client_aot_test_helper.cc"],
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 8f86c528d0..8bd0a729b7 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -21,64 +21,68 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
- bool allow_mixed_precision)
- : HloTestBase(
- /*verifier_layout_sensitive=*/layout_sensitive,
- /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
-
-HloVerifiedTestBase::~HloVerifiedTestBase() {
- // We can't call the ASSERT or EXPECT test macros in destructors, so we
- // perform HLO verification in TearDown, and use the CHECK here to ensure
- // users don't accidentally override the verification.
- CHECK(tear_down_called_)
- << "TearDown was never called; subclasses of HloVerifiedTestBase that "
- << "override TearDown must call the superclass TearDown.";
-}
-
-void HloVerifiedTestBase::TearDown() {
- EXPECT_FALSE(tear_down_called_)
- << "TearDown called more than once; it should be called exactly once.";
- tear_down_called_ = true;
- if (module_) {
- VerifyModule(module_.get());
+Status VerifiedHloModule::Verify() {
+ if (computation_count() == 0) {
+ // The computation was never built. Nothing to verify.
+ return Status::OK();
}
- for (int i = 0; i < modules_.size(); ++i) {
- VerifyModule(modules_.at(i).get());
- }
- HloTestBase::TearDown();
+ return verifier_.Run(this).status();
}
-void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- xla::StatusOr<bool> mutated = verifier().Run(module);
- if (!mutated.ok()) {
- ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
- } else {
- EXPECT_FALSE(mutated.ValueOrDie())
- << "HloVerifier should never mutate the HloModule";
+void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
+ Status status = Verify();
+ if (!status.ok()) {
+ ADD_FAILURE() << "HloVerifier failed on module " << name()
+ << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
+ << ": " << status;
}
}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
+ verifier_layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
+
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = HloTestBase::CreateNewModule();
+ module_ = CreateNewVerifiedModule(TestName());
}
return *module_;
}
HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
- modules_.emplace_back(HloTestBase::CreateNewModule());
+ modules_.emplace_back(CreateNewVerifiedModule(name));
return modules_.back().get();
}
void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
- VerifyModule(module_.get());
+ module_ = CreateNewVerifiedModule(TestName());
+ TF_CHECK_OK(ParseHloString(hlo_text, module_.get()));
+ module_->VerifyOrAddFailure("after parsing");
}
+
+StatusOr<std::unique_ptr<VerifiedHloModule>>
+HloVerifiedTestBase::ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text, const HloModuleConfig& config) {
+ auto module = CreateNewVerifiedModule(TestName());
+ TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
+ TF_RETURN_IF_ERROR(module->Verify());
+ return std::move(module);
+}
+
+std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
+ const string& name) {
+ return absl::make_unique<VerifiedHloModule>(
+ name, GetModuleConfigForTest(), verifier_layout_sensitive_,
+ allow_mixed_precision_in_hlo_verifier_);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 8fbc4fa753..388a99bb36 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -20,53 +20,84 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
-// A base class for HLO tests that stores a default HloModule, and automatically
-// performs verification on that module on tear-down.
+// An HLO module derived class which verifies itself on destruction. This class
+// is intended to be used in unit tests. Any verification errors are raised via
+// ADD_FAILURE.
+class VerifiedHloModule : public HloModule {
+ public:
+ VerifiedHloModule(const string& name, const HloModuleConfig& config,
+ bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
+ : HloModule(name, config),
+ verifier_(verifier_layout_sensitive,
+ allow_mixed_precision_in_hlo_verifier) {}
+
+ ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
+
+ // Verifies the module using HloVerifier and returns the status.
+ Status Verify();
+
+ // Verifies the module and flags any error with ADD_FAILURE. 'message' is
+ // included in the failure message.
+ void VerifyOrAddFailure(const string& message);
+
+ private:
+ HloVerifier verifier_;
+};
+
+// A base class for HLO tests that stores a default VerifiedHloModule.
class HloVerifiedTestBase : public HloTestBase {
protected:
- explicit HloVerifiedTestBase(bool layout_sensitive = false,
- bool allow_mixed_precision = false);
- ~HloVerifiedTestBase() override;
+ HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
// Constructs a default shape verifier.
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
- // Performs verification on the default HloModule returned by module().
- // Automatically called by the testing framework for each test.
- //
- // REQUIRED: subclasses that override TearDown() must call this explicitly.
- void TearDown() override;
-
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule& module();
+
+ ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.")
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
+ // Parses the given string and returns module as a VerifiedHloModule.
+ StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text,
+ const HloModuleConfig& config = HloModuleConfig());
+
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule* CreateNewModule(const string& name = TestName());
- private:
- void VerifyModule(HloModule* module);
+ // Creates and returns a verified HLO module with the given name.
+ std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
+ const string& name = TestName());
+ private:
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
//
// Lazily populated. Access via module().
- std::unique_ptr<HloModule> module_;
+ std::unique_ptr<VerifiedHloModule> module_;
+
// Populated by calls to CreateNewModule.
- std::vector<std::unique_ptr<HloModule>> modules_;
+ std::vector<std::unique_ptr<VerifiedHloModule>> modules_;
- bool tear_down_called_ = false;
+ bool verifier_layout_sensitive_;
+ bool allow_mixed_precision_in_hlo_verifier_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
new file mode 100644
index 0000000000..5c0263e811
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// This class includes unit tests which are expected to fail because invalid HLO
+// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to
+// include the necessary gunit parts to test this test machinery (needs the
+// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the
+// disabled tests enabled and failures can be manually compared against
+// expectations.
+class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
+
+XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
+ // Test shouldn't fail if no module is created at all.
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) {
+ // Use module() to lazily create an empty module, build it up, and verify no
+ // failures.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) {
+ // Use module() to lazily create an empty module and build up an invalid
+ // module.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+
+ *hlo_module.entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) {
+ // Call CreateNewModule and build up a valid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) {
+ // Call CreateNewModule and build up a invalid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+
+ *module->entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndVerifyModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ ParseAndVerifyModule(hlo_string);
+ EXPECT_EQ(module().entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+
+RANDOM GARBAGE
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleBad
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[1234] add(x,y)
+}
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+} // namespace
+} // namespace xla