aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_schedule.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_schedule.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.h151
1 file changed, 151 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
new file mode 100644
index 0000000000..21c6988638
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+
#include <ostream>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flat_map.h"
+
+namespace xla {
+
+// Class representing a sequence of HLO instructions such as the sequential
+// execution order of an HLO computation.
+class HloInstructionSequence {
+ public:
+ HloInstructionSequence() = default;
+ HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) {
+ for (const HloInstruction* instruction : instructions) {
+ push_back(instruction);
+ }
+ }
+
+ // Adds the instruction to the end of the sequence.
+ void push_back(const HloInstruction* instruction) {
+ instruction_sequence_.push_back(instruction);
+ id_sequence_.push_back(instruction->unique_id());
+ }
+
+ // Clears the sequence of all instructions.
+ void clear() {
+ instruction_sequence_.clear();
+ id_sequence_.clear();
+ }
+
+ int64 size() const { return instruction_sequence_.size(); }
+
+ // Returns the sequence of HLO instructions.
+ const std::vector<const HloInstruction*>& instructions() const {
+ return instruction_sequence_;
+ }
+
+ // Returns the unique IDs of the instructions in the sequence (in order).
+ const std::vector<int>& ids() const { return id_sequence_; }
+
+ private:
+ // The sequence as HloInstructions.
+ std::vector<const HloInstruction*> instruction_sequence_;
+
+ // The sequence of HLO instructions, represented by their unique IDs. The
+ // sequence is stored as both HloInstructions and unique IDs because the
+ // sequence may be referenced after transformations to the HLO graph and HLO
+ // pointers can be invalidated or recycled in this process (see
+ // HloSchedule::Update).
+ std::vector<int> id_sequence_;
+};
+
+// A class representing a sequential schedule of instructions for an HLO
+// module. A complete HLO schedule contains an instruction sequence for every
+// non-fusion computation in the HLO module.
+class HloSchedule {
+ public:
+ HloSchedule(const HloModule* module) : module_(module) {}
+
+ // Returns a reference to the sequence for the given computation.
+ const HloInstructionSequence& sequence(
+ const HloComputation* computation) const;
+
+ // Returns the sequence for the given computation. An empty sequence is
+ // created if none exists for the computation.
+ HloInstructionSequence& GetOrCreateSequence(
+ const HloComputation* computation);
+
+ // Sets the sequence for the given computation to the given sequence.
+ void set_sequence(const HloComputation* computation,
+ absl::Span<const HloInstruction* const> sequence);
+ void set_sequence(const HloComputation* computation,
+ HloInstructionSequence sequence);
+
+ // Returns a map from HloComputation unique ID to instruction sequence. The
+ // map contains all sequences in the schedule.
+ const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
+ const {
+ return sequences_;
+ }
+
+ // Returns true if the schedule has a sequence for the given computation.
+ bool is_computation_scheduled(const HloComputation* computation) const {
+ return sequences_.count(computation->unique_id()) == 1;
+ }
+
+ // Updates the schedule such that it is (again) a valid schedule for the
+ // module. This is used to update a schedule after the HLO module has been
+ // transformed in some way. In general, the only transformations to the module
+ // for which a schedule can be updated is the addition or removal of
+ // instructions and removal of computations. Updating the schedule after new
+ // dependencies between existing instructions in the module is not supported
+ // and may result in an error status returned.
+ //
+ // Instructions in the module which also exist in the given schedule will
+ // remain in the same order in the updated schedule. Instructions which exist
+ // in the module but not in the given schedule will be placed as early as
+ // possible in the updated schedule.
+ Status Update();
+
+ // Verifies that the given schedule is valid for the given module.
+ // Specifically, the schedule contains exactly the instructions in the
+ // non-fusion computations in the module and every dependency in the module is
+ // satisfied in the schedule.
+ Status Verify() const;
+
+ string ToString() const;
+
+ bool empty() const { return sequences_.empty(); }
+
+ const HloModule* module() const { return module_; }
+
+ private:
+ // Updates the instruction sequence for the given computation.
+ Status UpdateComputationSchedule(const HloComputation* computation);
+
+ const HloModule* module_;
+
+ // A map from computation unique ID to instruction sequence. Unique IDs are
+ // used rather than HloComputation pointers because HLO pointers are not
+ // unique across HLO transformations because pointers may be recycled.
+ tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_