aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h16
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 5885261e8f..bf50542cc0 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -63,6 +63,9 @@ class HloInstruction {
kTransposeDot, // Fused into a dot with transposed operands.
kConvBackwardFilter, // Fused into a backward filter convolution.
kConvBackwardInput, // Fused into a backward input convolution.
+
+ kCustom, // Custom category for backend-specific fusions that
+ // do not match any of the more specific ones.
};
~HloInstruction();
@@ -174,7 +177,8 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices);
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specfied in
@@ -662,6 +666,15 @@ class HloInstruction {
return slice_limits_;
}
+ // Returns the stride in the given dimension for a slice node.
+ //
+ // Precondition: opcode() == HloOpcode::kSlice
+ int64 slice_stride(int64 dimension) const {
+ CHECK_EQ(HloOpcode::kSlice, opcode_);
+ return slice_strides_[dimension];
+ }
+ const std::vector<int64>& slice_strides() const { return slice_strides_; }
+
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@@ -907,6 +920,7 @@ class HloInstruction {
// Describes the [begin, end) index range for a slice.
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
+ std::vector<int64> slice_strides_;
// The bit sizes for a reduce-precision operation.
int32 exponent_bits_;