aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 7c60fc79ed..4b38ecd5de 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -306,11 +306,13 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) {
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
instruction->AppendOperand(operand);
instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
+ instruction->slice_strides_.assign(strides.begin(), strides.end());
return instruction;
}
@@ -852,7 +854,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateReshape(shape, new_operands[0]);
case HloOpcode::kSlice:
CHECK_EQ(new_operands.size(), 1);
- return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_);
+ return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
+ slice_strides_);
case HloOpcode::kDynamicSlice:
return CreateDynamicSlice(shape, new_operands[0], new_operands[1],
dynamic_slice_sizes_);
@@ -1672,6 +1675,8 @@ string HloInstruction::ToCategory() const {
case FusionKind::kConvBackwardFilter:
case FusionKind::kConvBackwardInput:
return "convolution fusion";
+ case FusionKind::kCustom:
+ return "custom fusion";
}
}
@@ -2339,6 +2344,8 @@ string ToString(HloInstruction::FusionKind kind) {
return "kConvBackwardFilter";
case HloInstruction::FusionKind::kConvBackwardInput:
return "kConvBackwardInput";
+ case HloInstruction::FusionKind::kCustom:
+ return "kCustom";
}
}