aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-19 08:12:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 08:16:42 -0700
commitf8655c08cfe3bd99ec1703211e1c9154a14a6150 (patch)
tree90bf5c29d3a1f77764c2f2392c4b0564b490c995 /tensorflow/compiler/xla/service/hlo_pass_pipeline.h
parente1db78697b05be673562fe2b1c9a995d25a71d4c (diff)
Add interface for HLO passes which run on HloModuleGroup.
Derive HloModulePass and HloModuleGroupPass from HloPassInterface which run module-scoped and module-group-scoped respectively. Replace all existing uses of HloPassInterface with HloModulePass because all existing passes are module-scoped. Also rewrite HloPassPipeline to support both module-scoped and module-group-scoped passes. PiperOrigin-RevId: 213629604
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_pass_pipeline.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h38
1 files changed, 37 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 1d41a4dac1..09e7033ea4 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface {
return *pass;
}
- // Run all passes on the given HLO module.
StatusOr<bool> Run(HloModule* module) override;
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
private:
+ // Returns the set of passes which are enabled. DebugOptions can selectively
+ // disable passes via --xla_disable_hlo_passes flag.
+ std::vector<HloPassInterface*> GetEnabledPasses(
+ const DebugOptions& debug_options);
+
+ // Maybe dumps the given module or module group depending on flag values
+ // contained in DebugOptions of module config.
+ void MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+ void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+
+ // Runs the invariant checker on the given HLO. HloT can be either HloModule
+ // or HloModuleGroup.
+ template <typename HloT>
+ Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+ // Helper which runs the given pass on the given HLO. HloT can be either
+ // HloModule or HloModuleGroup.
+ template <typename HloT>
+ StatusOr<bool> RunPassesInternal(HloT* hlo,
+ absl::Span<HloPassInterface* const> passes);
+
+ // Helpers which run the given passes on the given HLO construct. These
+ // helpers enable templating of the core of the pipeline logic by providing
+ // HloModule and HloModuleGroup specific methods with the same name.
+ static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+ return pass->Run(module);
+ }
+ static StatusOr<bool> RunHelper(HloPassInterface* pass,
+ HloModuleGroup* module_group) {
+ return pass->RunOnModuleGroup(module_group);
+ }
+
const string name_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;