aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_pass_pipeline.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index a8c2d51873..682c4b952d 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -47,6 +47,7 @@ class HloPassPipeline : public HloPassInterface {
// Returns a reference to the added pass.
template <typename T, typename... Args>
T& AddPass(Args&&... args) {
+ CHECK(!run_called_) << "AddPass cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
passes_.push_back(std::unique_ptr<T>(pass));
return *pass;
@@ -57,6 +58,7 @@ class HloPassPipeline : public HloPassInterface {
// (it is required to always return "false" from its Run() method).
template <typename T, typename... Args>
T& AddInvariantChecker(Args&&... args) {
+ CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
invariant_checkers_.push_back(std::unique_ptr<T>(pass));
return *pass;
@@ -70,6 +72,7 @@ class HloPassPipeline : public HloPassInterface {
Compiler::HloDumper dumper_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
+ bool run_called_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
};