From 900d115135656229e3667025f925eb92687dce18 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 Oct 2018 12:29:50 -0700 Subject: [XLA] Move FusionQueue class declaration into separate header PiperOrigin-RevId: 215783391 --- tensorflow/compiler/xla/service/BUILD | 9 ++++ tensorflow/compiler/xla/service/fusion_queue.h | 53 ++++++++++++++++++++++ .../compiler/xla/service/instruction_fusion.cc | 1 + .../compiler/xla/service/instruction_fusion.h | 28 +----------- 4 files changed, 64 insertions(+), 27 deletions(-) create mode 100644 tensorflow/compiler/xla/service/fusion_queue.h (limited to 'tensorflow/compiler') diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f329a27e14..2f8bab0614 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1323,11 +1323,20 @@ cc_library( ], ) +cc_library( + name = "fusion_queue", + hdrs = ["fusion_queue.h"], + deps = [ + ":hlo", + ], +) + cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":fusion_queue", ":hlo", ":hlo_pass", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h new file mode 100644 index 0000000000..1208a7dda8 --- /dev/null +++ b/tensorflow/compiler/xla/service/fusion_queue.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 5a99c40df4..69a4c160ee 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index da2032f6c7..f14c667520 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -17,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -25,33 +26,6 @@ limitations under the License. namespace xla { -// A queue interface that allows implementations to choose fusion candidates in -// custom order. -class FusionQueue { - public: - FusionQueue() = default; - virtual ~FusionQueue() = default; - - // Dequeues the next fusion candidates: a consumer and the list of producers - // as operand indices. - virtual std::pair> - DequeueNextInstructionAndOperandsToFuseInOrder() = 0; - - // A callback passed to the queue implementation right before the producer is - // fused into the consumer. - virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} - - // A callback passed to the queue implementation right after the fusion is - // created. Note that original_producer could have been destroyed. - virtual void OnFusingInstruction(HloInstruction* fusion, - HloInstruction* original_producer, - HloInstruction* original_consumer) {} - - // A callback passed to the queue implementation to notify the removal of an - // instruction. - virtual void RemoveInstruction(HloInstruction* instruction) = 0; -}; - // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in -- cgit v1.2.3