diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/despecializer.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/despecializer.cc | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index d938f3a2c4..48e4471499 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -21,8 +21,33 @@ limitations under the License. namespace xla { +namespace { + +// Pass which strips control dependencies from all instructions in the module. +class ControlDepRemover : public HloPassInterface { + public: + ControlDepRemover() = default; + tensorflow::StringPiece name() const override { + return "control-dep-remover"; + } + + StatusOr<bool> Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed = changed || !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace + Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass<ControlDepRemover>(); pipeline_.AddPass<Defuser>(); pipeline_.AddPass<ImplicitBroadcastRemover>(); pipeline_.AddPass<BFloat16MixedPrecisionRemoval>(); |