aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/despecializer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/despecializer.cc')
-rw-r--r--tensorflow/compiler/xla/service/despecializer.cc25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index d938f3a2c4..48e4471499 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -21,8 +21,33 @@ limitations under the License.
namespace xla {
+namespace {
+
+// Pass which strips control dependencies from all instructions in the module.
+class ControlDepRemover : public HloPassInterface {
+ public:
+ ControlDepRemover() = default;
+ tensorflow::StringPiece name() const override {
+ return "control-dep-remover";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ changed = changed || !instruction->control_predecessors().empty();
+ TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+ }
+ }
+ return changed;
+ }
+};
+
+} // namespace
+
Despecializer::Despecializer() : pipeline_("despecializer") {
// TODO(b/70588125): Also deal with window reversal in a fast way.
+ pipeline_.AddPass<ControlDepRemover>();
pipeline_.AddPass<Defuser>();
pipeline_.AddPass<ImplicitBroadcastRemover>();
pipeline_.AddPass<BFloat16MixedPrecisionRemoval>();