aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_scheduling.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_scheduling.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h54
1 files changed, 9 insertions, 45 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index d06b8d9a5c..54e32340ba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -32,14 +33,14 @@ namespace xla {
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
// that maps a tensor's logical size to its padded size.
-typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
+typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm;
// List scheduler
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -47,7 +48,7 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
memory_by_computation);
// DFS-order scheduler
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -55,7 +56,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
memory_by_computation);
// Naive Post Order scheduler
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -65,63 +66,26 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
// The default scheduling algorithm. Runs both the list scheduler
// and the DFS scheduler, and chooses whichever returns a lower min-memory,
// not accounting for fragmentation.
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation);
-// Returns an HloModuleSequence which seeks to minimize the memory required for
+// Returns an HloSchedule which seeks to minimize the memory required for
// the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer.
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm = {});
// Computes the schedule for a single computation.
// Currently only used by the GPU backend.
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
-// Transforms the given schedule such that it is (again) a valid schedule for
-// the module. This is used to update a schedule after the HLO module has been
-// transformed in some way. In general, the only transformations to the module
-// for which a schedule can be updated is the addition or removal of
-// instructions to/from the module. Updating the schedule after new dependencies
-// between existing instructions in the module is not supported and may result
-// in an error status returned.
-//
-// Instructions in the module which also exist in the given schedule will remain
-// in the same order in the updated schedule. Instructions which exist in the
-// module but not in the given schedule will be placed as early as possible in
-// the updated schedule.
-//
-// 'id_sequence' is a mirror of the given schedule 'sequence' but with
-// HloInstruction ids rather than HloInstruction pointers. This should be
-// constructed using ComputeIdSchedule below after the schedule is constructed
-// but before the HLO module is transformed.
-Status UpdateSchedule(
- const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
- id_sequence,
- SequentialHloOrdering::HloModuleSequence* sequence);
-
-// Constructs a copy of the given schedule but with HloInstruction unique ids
-// rather than HloInstruction pointers. This is necessary for updating a
-// schedule as HloInstruction points in the schedule may become invalid if
-// instructions are removed from the module. Used by UpdateSchedule above..
-// TODO(b/113175018): Remove this function when HLO schedule is its own class.
-tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence);
-
-// Verifies that the given schedule is valid for the given module. Specifically,
-// the schedule contains exactly the instructions in the module and every
-// dependency in the module is satisfied in the schedule.
-Status VerifySchedule(const HloModule& module,
- const SequentialHloOrdering::HloModuleSequence& sequence);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_