diff options
author | Mark Heffernan <meheff@google.com> | 2018-08-15 12:24:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 12:28:56 -0700 |
commit | d1b9cc82d8b5c4541ebf9f01def504dd20d1da00 (patch) | |
tree | 6fcf0bbc96ca19a038f6c1a1da463aaaf9df9161 /tensorflow/compiler/xla/service/despecializer.cc | |
parent | 982be2b71d08cda624c3d95dfee31271e9829170 (diff) |
Strip control dependencies in despecializer.
Control dependencies are an artifact left over from lowering. These should be removed when raising to a despecialized form.
PiperOrigin-RevId: 208862729
Diffstat (limited to 'tensorflow/compiler/xla/service/despecializer.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/despecializer.cc | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index d938f3a2c4..48e4471499 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -21,8 +21,33 @@ limitations under the License. namespace xla { +namespace { + +// Pass which strips control dependencies from all instructions in the module. +class ControlDepRemover : public HloPassInterface { + public: + ControlDepRemover() = default; + tensorflow::StringPiece name() const override { + return "control-dep-remover"; + } + + StatusOr<bool> Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed = changed || !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace + Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass<ControlDepRemover>(); pipeline_.AddPass<Defuser>(); pipeline_.AddPass<ImplicitBroadcastRemover>(); pipeline_.AddPass<BFloat16MixedPrecisionRemoval>(); |