diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-04-05 21:04:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-05 22:27:06 -0700 |
commit | ffd4e6223d9cf388594256be80621d071661b307 (patch) | |
tree | c8103d449641e041013a930cc653fa9c765c1a27 | |
parent | 4b2d3e7a3f7a103a381ebc840d536fe1e094908c (diff) |
[XLA] Change HloPassPipeline to disallow Add* calls after Run.
Change: 152345578
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_pass_pipeline.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_pass_pipeline.h | 3 |
2 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 6e3c983071..eb7fe467b3 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -40,6 +40,8 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module, } // namespace StatusOr<bool> HloPassPipeline::Run(HloModule* module) { + run_called_ = true; + legacy_flags::HloPassPipelineFlags* flags = legacy_flags::GetHloPassPipelineFlags(); std::vector<string> tmp = diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a8c2d51873..682c4b952d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -47,6 +47,7 @@ class HloPassPipeline : public HloPassInterface { // Returns a reference to the added pass. template <typename T, typename... Args> T& AddPass(Args&&... args) { + CHECK(!run_called_) << "AddPass cannot be called after Run"; auto pass = new T(std::forward<Args>(args)...); passes_.push_back(std::unique_ptr<T>(pass)); return *pass; @@ -57,6 +58,7 @@ class HloPassPipeline : public HloPassInterface { // (it is required to always return "false" from its Run() method). template <typename T, typename... Args> T& AddInvariantChecker(Args&&... args) { + CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; auto pass = new T(std::forward<Args>(args)...); invariant_checkers_.push_back(std::unique_ptr<T>(pass)); return *pass; @@ -70,6 +72,7 @@ class HloPassPipeline : public HloPassInterface { Compiler::HloDumper dumper_; std::vector<std::unique_ptr<HloPassInterface>> passes_; std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_; + bool run_called_ = false; TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline); }; |