aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-05 21:04:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-05 22:27:06 -0700
commitffd4e6223d9cf388594256be80621d071661b307 (patch)
treec8103d449641e041013a930cc653fa9c765c1a27
parent4b2d3e7a3f7a103a381ebc840d536fe1e094908c (diff)
[XLA] Change HloPassPipeline to disallow Add* calls after Run.
Change: 152345578
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h3
2 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e3c983071..eb7fe467b3 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -40,6 +40,8 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module,
} // namespace
StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
+ run_called_ = true;
+
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
std::vector<string> tmp =
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index a8c2d51873..682c4b952d 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -47,6 +47,7 @@ class HloPassPipeline : public HloPassInterface {
// Returns a reference to the added pass.
template <typename T, typename... Args>
T& AddPass(Args&&... args) {
+ CHECK(!run_called_) << "AddPass cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
passes_.push_back(std::unique_ptr<T>(pass));
return *pass;
@@ -57,6 +58,7 @@ class HloPassPipeline : public HloPassInterface {
// (it is required to always return "false" from its Run() method).
template <typename T, typename... Args>
T& AddInvariantChecker(Args&&... args) {
+ CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run";
auto pass = new T(std::forward<Args>(args)...);
invariant_checkers_.push_back(std::unique_ptr<T>(pass));
return *pass;
@@ -70,6 +72,7 @@ class HloPassPipeline : public HloPassInterface {
Compiler::HloDumper dumper_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
+ bool run_called_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
};