diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-08-16 01:32:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 01:36:41 -0700 |
commit | 72b829dcca2d1acaeea130e580ce780b1a7d550a (patch) | |
tree | 6c7e26f84f8d7eb5eeaf7f802db716b931757df7 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 9d97b34bde77762a7499306ee74a56bcc91a95dc (diff) |
Add a feature_group_size parameter to the Convolution HLO op.
This is a first step towards supporting grouped convolutions, which are a
generalization of depthwise convolution.
PiperOrigin-RevId: 208950311
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 8d3ef57757..233cdda7b0 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1606,10 +1606,12 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) : HloInstruction(HloOpcode::kConvolution, shape), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + feature_group_count_(feature_group_count) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1647,6 +1649,7 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); return extra; } @@ -1668,9 +1671,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloConvolutionInstruction>(shape, new_operands[0], - new_operands[1], window(), - convolution_dimension_numbers_); + return MakeUnique<HloConvolutionInstruction>( + shape, new_operands[0], new_operands[1], window(), + convolution_dimension_numbers_, feature_group_count_); } HloReduceWindowInstruction::HloReduceWindowInstruction( |