aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 12:29:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 12:36:24 -0700
commit900d115135656229e3667025f925eb92687dce18 (patch)
tree31ee19149765eb8d8d75118fb5d8f3b48c8a3896 /tensorflow/compiler
parent5bdd0f7c2807ed413cfc60319f1e75b1e6a4a5b5 (diff)
[XLA] Move FusionQueue class declaration into separate header
PiperOrigin-RevId: 215783391
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/BUILD9
-rw-r--r--tensorflow/compiler/xla/service/fusion_queue.h53
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h28
4 files changed, 64 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f329a27e14..2f8bab0614 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1324,10 +1324,19 @@ cc_library(
)
cc_library(
+ name = "fusion_queue",
+ hdrs = ["fusion_queue.h"],
+ deps = [
+ ":hlo",
+ ],
+)
+
+cc_library(
name = "instruction_fusion",
srcs = ["instruction_fusion.cc"],
hdrs = ["instruction_fusion.h"],
deps = [
+ ":fusion_queue",
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h
new file mode 100644
index 0000000000..1208a7dda8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/fusion_queue.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+// A queue interface that allows implementations to choose fusion candidates in
+// custom order.
+class FusionQueue {
+ public:
+ FusionQueue() = default;
+ virtual ~FusionQueue() = default;
+
+ // Dequeues the next fusion candidates: a consumer and the list of producers
+ // as operand indices.
+ virtual std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+ // A callback passed to the queue implementation right before the producer is
+ // fused into the consumer.
+ virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+ // A callback passed to the queue implementation right after the fusion is
+ // created. Note that original_producer could have been destroyed.
+ virtual void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {}
+
+ // A callback passed to the queue implementation to notify the removal of an
+ // instruction.
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 5a99c40df4..69a4c160ee 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index da2032f6c7..f14c667520 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -17,6 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
+#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -25,33 +26,6 @@ limitations under the License.
namespace xla {
-// A queue interface that allows implementations to choose fusion candidates in
-// custom order.
-class FusionQueue {
- public:
- FusionQueue() = default;
- virtual ~FusionQueue() = default;
-
- // Dequeues the next fusion candidates: a consumer and the list of producers
- // as operand indices.
- virtual std::pair<HloInstruction*, std::vector<int64>>
- DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
-
- // A callback passed to the queue implementation right before the producer is
- // fused into the consumer.
- virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
-
- // A callback passed to the queue implementation right after the fusion is
- // created. Note that original_producer could have been destroyed.
- virtual void OnFusingInstruction(HloInstruction* fusion,
- HloInstruction* original_producer,
- HloInstruction* original_consumer) {}
-
- // A callback passed to the queue implementation to notify the removal of an
- // instruction.
- virtual void RemoveInstruction(HloInstruction* instruction) = 0;
-};
-
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in