diff options
author | Mark Heffernan <meheff@google.com> | 2017-08-25 14:03:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-25 14:16:11 -0700 |
commit | d304c3b2a00d9d52d758d211286c14e356b5e1ed (patch) | |
tree | a410c5f636fdc397600e8935aae95f1b79dcacf1 /tensorflow/compiler/xla/service/hlo_computation.h | |
parent | f6fea18a630cf80d3e78bad4d98533a599936e25 (diff) |
Add option to HloComputation::DeepCopyInstruction for selectively copying only
certain indices. Also, add mechanism for returning the kCopy instructions
added to create the deep copy.
PiperOrigin-RevId: 166521917
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 0a33d0c1cf..f383a17fb8 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -125,7 +126,8 @@ class HloComputation { // Returns the parameter instruction for the given parameter number. HloInstruction* parameter_instruction(int64 param_no) const { CHECK_GE(param_no, 0); - CHECK_LT(param_no, static_cast<int64>(param_instructions_.size())); + CHECK_LT(param_no, static_cast<int64>(param_instructions_.size())) + << "Computation " << name() << " has no parameter number " << param_no; return param_instructions_[param_no]; } @@ -199,8 +201,16 @@ class HloComputation { // producing the copied result. All instructions performing the copy are added // to the computation. For array-shaped values, this method trivially returns // a kCopy instruction. For tuple-shaped instructions, the copy is performed - // with a series of kGetTupleElement and kTuple instructions. - StatusOr<HloInstruction*> DeepCopyInstruction(HloInstruction* instruction); + // with a series of kGetTupleElement and kTuple instructions. If + // indices_to_copy is non-null then this ShapeTree indicates which elements + // (arrays) of the shape to copy. Non-copied elements are passed through + // transparently. If copies_added is non-null, then the added kCopy + // instructions will be inserted in the respective index in the given + // ShapeTree. + StatusOr<HloInstruction*> DeepCopyInstruction( + HloInstruction* instruction, + const ShapeTree<bool>* indices_to_copy = nullptr, + ShapeTree<HloInstruction*>* copies_added = nullptr); // Computes and returns the ProgramShape of this computation (shape of // parameters and result without layout). @@ -287,9 +297,11 @@ class HloComputation { tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse, HloInstruction* fusion_instruction); - // Internal helper for copying a tuple value. Creates and returns a deep copy - // of the given instruction. The given instruction must be tuple-shaped. - StatusOr<HloInstruction*> DeepCopyTuple(HloInstruction* instruction); + // Internal helper for recursive copying of an instruction. Creates and + // returns a deep copy of the given instruction. + StatusOr<HloInstruction*> DeepCopyHelper( + HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy, + ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index); // Internal helper to collect unreachable roots. std::vector<HloInstruction*> CollectUnreachableRoots() const; |