aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h19
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index c1c3e79ebc..49ed65910f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
+#include <functional>
#include <list>
#include <memory>
#include <string>
@@ -254,6 +255,14 @@ class HloComputation {
const ShapeTree<bool>* indices_to_copy = nullptr,
ShapeTree<HloInstruction*>* copies_added = nullptr);
+ // As above, but uses a custom function to copy the leaf nodes, which could
+ // create alternative HLOs other than kCopy, or even pass-throughs.
+ StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
+ HloInstruction* instruction,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
+
// Computes and returns the ProgramShape of this computation (shape of
// parameters and result with layout).
ProgramShape ComputeProgramShape() const;
@@ -356,6 +365,10 @@ class HloComputation {
unique_id_ = id;
}
+ // Returns the instruction in this computation that has name `name`. Returns
+ // null if there is no such computation.
+ HloInstruction* GetInstructionWithName(tensorflow::StringPiece name);
+
int64 unique_id() const { return unique_id_; }
private:
@@ -378,8 +391,10 @@ class HloComputation {
// Internal helper for recursive copying of an instruction. Creates and
// returns a deep copy of the given instruction.
StatusOr<HloInstruction*> DeepCopyHelper(
- HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
- ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);
+ HloInstruction* instruction, ShapeIndex* index,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;