aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-08-16 01:32:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 01:36:41 -0700
commit72b829dcca2d1acaeea130e580ce780b1a7d550a (patch)
tree6c7e26f84f8d7eb5eeaf7f802db716b931757df7 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent9d97b34bde77762a7499306ee74a56bcc91a95dc (diff)
Add a feature_group_size parameter to the Convolution HLO op.
This is a first step towards supporting grouped convolutions, which are a generalization of depthwise convolution. PiperOrigin-RevId: 208950311
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc13
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 8d3ef57757..233cdda7b0 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1606,10 +1606,12 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers)
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count)
: HloInstruction(HloOpcode::kConvolution, shape),
window_(window),
- convolution_dimension_numbers_(dimension_numbers) {
+ convolution_dimension_numbers_(dimension_numbers),
+ feature_group_count_(feature_group_count) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1647,6 +1649,7 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
}
extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
convolution_dimension_numbers_)));
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
return extra;
}
@@ -1668,9 +1671,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloConvolutionInstruction>(shape, new_operands[0],
- new_operands[1], window(),
- convolution_dimension_numbers_);
+ return MakeUnique<HloConvolutionInstruction>(
+ shape, new_operands[0], new_operands[1], window(),
+ convolution_dimension_numbers_, feature_group_count_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(