diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5885261e8f..bf50542cc0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -63,6 +63,9 @@ class HloInstruction { kTransposeDot, // Fused into a dot with transposed operands. kConvBackwardFilter, // Fused into a backward filter convolution. kConvBackwardInput, // Fused into a backward input convolution. + + kCustom, // Custom category for backend-specific fusions that + // do not match any of the more specific ones. }; ~HloInstruction(); @@ -174,7 +177,8 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateSlice( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> start_indices, - tensorflow::gtl::ArraySlice<int64> limit_indices); + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides); // Creates a slice instruction, where the first operand is sliced by // start indices specified in the second operand, and by size specfied in @@ -662,6 +666,15 @@ class HloInstruction { return slice_limits_; } + // Returns the stride in the given dimension for a slice node. + // + // Precondition: opcode() == HloOpcode::kSlice + int64 slice_stride(int64 dimension) const { + CHECK_EQ(HloOpcode::kSlice, opcode_); + return slice_strides_[dimension]; + } + const std::vector<int64>& slice_strides() const { return slice_strides_; } + // Returns the size of the slice in the given dimension for a dynamic // slice node. // @@ -907,6 +920,7 @@ class HloInstruction { // Describes the [begin, end) index range for a slice. std::vector<int64> slice_starts_; std::vector<int64> slice_limits_; + std::vector<int64> slice_strides_; // The bit sizes for a reduce-precision operation. int32 exponent_bits_; |