aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.h
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-08-25 14:03:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-25 14:16:11 -0700
commitd304c3b2a00d9d52d758d211286c14e356b5e1ed (patch)
treea410c5f636fdc397600e8935aae95f1b79dcacf1 /tensorflow/compiler/xla/service/hlo_computation.h
parentf6fea18a630cf80d3e78bad4d98533a599936e25 (diff)
Add option to HloComputation::DeepCopyInstruction for selectively copying only
certain indices. Also, add mechanism for returning the kCopy instructions added to create the deep copy. PiperOrigin-RevId: 166521917
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h24
1 files changed, 18 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 0a33d0c1cf..f383a17fb8 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -125,7 +126,8 @@ class HloComputation {
// Returns the parameter instruction for the given parameter number.
HloInstruction* parameter_instruction(int64 param_no) const {
CHECK_GE(param_no, 0);
- CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()));
+ CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
+ << "Computation " << name() << " has no parameter number " << param_no;
return param_instructions_[param_no];
}
@@ -199,8 +201,16 @@ class HloComputation {
// producing the copied result. All instructions performing the copy are added
// to the computation. For array-shaped values, this method trivially returns
// a kCopy instruction. For tuple-shaped instructions, the copy is performed
- // with a series of kGetTupleElement and kTuple instructions.
- StatusOr<HloInstruction*> DeepCopyInstruction(HloInstruction* instruction);
+ // with a series of kGetTupleElement and kTuple instructions. If
+ // indices_to_copy is non-null then this ShapeTree indicates which elements
+ // (arrays) of the shape to copy. Non-copied elements are passed through
+ // transparently. If copies_added is non-null, then the added kCopy
+ // instructions will be inserted in the respective index in the given
+ // ShapeTree.
+ StatusOr<HloInstruction*> DeepCopyInstruction(
+ HloInstruction* instruction,
+ const ShapeTree<bool>* indices_to_copy = nullptr,
+ ShapeTree<HloInstruction*>* copies_added = nullptr);
// Computes and returns the ProgramShape of this computation (shape of
// parameters and result without layout).
@@ -287,9 +297,11 @@ class HloComputation {
tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
HloInstruction* fusion_instruction);
- // Internal helper for copying a tuple value. Creates and returns a deep copy
- // of the given instruction. The given instruction must be tuple-shaped.
- StatusOr<HloInstruction*> DeepCopyTuple(HloInstruction* instruction);
+ // Internal helper for recursive copying of an instruction. Creates and
+ // returns a deep copy of the given instruction.
+ StatusOr<HloInstruction*> DeepCopyHelper(
+ HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
+ ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;