diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c1c3e79ebc..49ed65910f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ +#include <functional> #include <list> #include <memory> #include <string> @@ -254,6 +255,14 @@ class HloComputation { const ShapeTree<bool>* indices_to_copy = nullptr, ShapeTree<HloInstruction*>* copies_added = nullptr); + // As above, but uses a custom function to copy the leaf nodes, which could + // create alternative HLOs other than kCopy, or even pass-throughs. + StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier( + HloInstruction* instruction, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf); + // Computes and returns the ProgramShape of this computation (shape of // parameters and result with layout). ProgramShape ComputeProgramShape() const; @@ -356,6 +365,10 @@ class HloComputation { unique_id_ = id; } + // Returns the instruction in this computation that has name `name`. Returns + // null if there is no such computation. + HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + int64 unique_id() const { return unique_id_; } private: @@ -378,8 +391,10 @@ class HloComputation { // Internal helper for recursive copying of an instruction. Creates and // returns a deep copy of the given instruction. StatusOr<HloInstruction*> DeepCopyHelper( - HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy, - ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index); + HloInstruction* instruction, ShapeIndex* index, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf); // Internal helper to collect unreachable roots. std::vector<HloInstruction*> CollectUnreachableRoots() const; |