diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_schedule.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_schedule.h | 151 |
1 files changed, 151 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h new file mode 100644 index 0000000000..21c6988638 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ + +#include <vector> + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +// Class representing a sequence of HLO instructions such as the sequential +// execution order of an HLO computation. +class HloInstructionSequence { + public: + HloInstructionSequence() = default; + HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) { + for (const HloInstruction* instruction : instructions) { + push_back(instruction); + } + } + + // Adds the instruction to the end of the sequence. + void push_back(const HloInstruction* instruction) { + instruction_sequence_.push_back(instruction); + id_sequence_.push_back(instruction->unique_id()); + } + + // Clears the sequence of all instructions. + void clear() { + instruction_sequence_.clear(); + id_sequence_.clear(); + } + + int64 size() const { return instruction_sequence_.size(); } + + // Returns the sequence of HLO instructions. + const std::vector<const HloInstruction*>& instructions() const { + return instruction_sequence_; + } + + // Returns the unique IDs of the instructions in the sequence (in order). + const std::vector<int>& ids() const { return id_sequence_; } + + private: + // The sequence as HloInstructions. + std::vector<const HloInstruction*> instruction_sequence_; + + // The sequence of HLO instructions, represented by their unique IDs. The + // sequence is stored as both HloInstructions and unique IDs because the + // sequence may be referenced after transformations to the HLO graph and HLO + // pointers can be invalidated or recycled in this process (see + // HloSchedule::Update). + std::vector<int> id_sequence_; +}; + +// A class representing a sequential schedule of instructions for an HLO +// module. A complete HLO schedule contains an instruction sequence for every +// non-fusion computation in the HLO module. +class HloSchedule { + public: + HloSchedule(const HloModule* module) : module_(module) {} + + // Returns a reference to the sequence for the given computation. + const HloInstructionSequence& sequence( + const HloComputation* computation) const; + + // Returns the sequence for the given computation. An empty sequence is + // created if none exists for the computation. + HloInstructionSequence& GetOrCreateSequence( + const HloComputation* computation); + + // Sets the sequence for the given computation to the given sequence. + void set_sequence(const HloComputation* computation, + absl::Span<const HloInstruction* const> sequence); + void set_sequence(const HloComputation* computation, + HloInstructionSequence sequence); + + // Returns a map from HloComputation unique ID to instruction sequence. The + // map contains all sequences in the schedule. + const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences() + const { + return sequences_; + } + + // Returns true if the schedule has a sequence for the given computation. + bool is_computation_scheduled(const HloComputation* computation) const { + return sequences_.count(computation->unique_id()) == 1; + } + + // Updates the schedule such that it is (again) a valid schedule for the + // module. This is used to update a schedule after the HLO module has been + // transformed in some way. In general, the only transformations to the module + // for which a schedule can be updated is the addition or removal of + // instructions and removal of computations. Updating the schedule after new + // dependencies between existing instructions in the module is not supported + // and may result in an error status returned. + // + // Instructions in the module which also exist in the given schedule will + // remain in the same order in the updated schedule. Instructions which exist + // in the module but not in the given schedule will be placed as early as + // possible in the updated schedule. + Status Update(); + + // Verifies that the given schedule is valid for the given module. + // Specifically, the schedule contains exactly the instructions in the + // non-fusion computations in the module and every dependency in the module is + // satisfied in the schedule. + Status Verify() const; + + string ToString() const; + + bool empty() const { return sequences_.empty(); } + + const HloModule* module() const { return module_; } + + private: + // Updates the instruction sequence for the given computation. + Status UpdateComputationSchedule(const HloComputation* computation); + + const HloModule* module_; + + // A map from computation unique ID to instruction sequence. Unique IDs are + // used rather than HloComputation pointers because HLO pointers are not + // unique across HLO transformations because pointers may be recycled. + tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_; +}; + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ |