diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.h | 22 |
1 files changed, 6 insertions, 16 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 985f3fa64d..b21071c4b2 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -183,17 +184,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - // TODO(dimvar): HloModuleSequence is not a good name because it sounds like - // a sequence of modules, instead of a map of schedules for all computations - // in a module. We should change it at some point. - // - // A sequence of instructions for each computation in the module. - using HloModuleSequence = - tensorflow::gtl::FlatMap<const HloComputation*, - std::vector<const HloInstruction*>>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); + SequentialHloOrdering(const HloSchedule& schedule); + SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. @@ -203,10 +195,12 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: + void Initialize(); + bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloModuleSequence module_sequence_; + const HloSchedule schedule_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -217,10 +211,6 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_; }; -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ |