aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/despecializer.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-08-15 12:24:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 12:28:56 -0700
commitd1b9cc82d8b5c4541ebf9f01def504dd20d1da00 (patch)
tree6fcf0bbc96ca19a038f6c1a1da463aaaf9df9161 /tensorflow/compiler/xla/service/despecializer.cc
parent982be2b71d08cda624c3d95dfee31271e9829170 (diff)
Strip control dependencies in despecializer.
Control dependencies are an artifact left over from lowering. These should be removed when raising to a despecialized form. PiperOrigin-RevId: 208862729
Diffstat (limited to 'tensorflow/compiler/xla/service/despecializer.cc')
-rw-r--r--tensorflow/compiler/xla/service/despecializer.cc25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index d938f3a2c4..48e4471499 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -21,8 +21,33 @@ limitations under the License.
namespace xla {
+namespace {
+
+// Pass which strips control dependencies from all instructions in the module.
+class ControlDepRemover : public HloPassInterface {
+ public:
+ ControlDepRemover() = default;
+ tensorflow::StringPiece name() const override {
+ return "control-dep-remover";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ changed = changed || !instruction->control_predecessors().empty();
+ TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+ }
+ }
+ return changed;
+ }
+};
+
+} // namespace
+
Despecializer::Despecializer() : pipeline_("despecializer") {
// TODO(b/70588125): Also deal with window reversal in a fast way.
+ pipeline_.AddPass<ControlDepRemover>();
pipeline_.AddPass<Defuser>();
pipeline_.AddPass<ImplicitBroadcastRemover>();
pipeline_.AddPass<BFloat16MixedPrecisionRemoval>();