diff options
author | Adrian Kuegel <akuegel@google.com> | 2018-08-16 01:32:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 01:36:41 -0700 |
commit | 72b829dcca2d1acaeea130e580ce780b1a7d550a (patch) | |
tree | 6c7e26f84f8d7eb5eeaf7f802db716b931757df7 | |
parent | 9d97b34bde77762a7499306ee74a56bcc91a95dc (diff) |
Add a feature_group_size parameter to the Convolution HLO op.
This is a first step towards supporting grouped convolutions, which are a
generalization of depthwise convolution.
PiperOrigin-RevId: 208950311
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 64 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.h | 45 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 19 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.h | 3 | ||||
-rw-r--r-- | tensorflow/docs_src/performance/xla/operation_semantics.md | 43 |
13 files changed, 160 insertions, 78 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 31dedd54b0..aa47f992bc 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -882,24 +882,28 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - Padding padding) { + Padding padding, int64 feature_group_count) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + int64 feature_group_count) { return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -926,7 +930,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( return ConvGeneral(lhs, rhs, window_strides, MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), - dimension_numbers); + dimension_numbers, feature_group_count); }); } @@ -934,9 +938,10 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); + dimension_numbers, feature_group_count); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -945,7 +950,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -964,12 +970,13 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(), - dimension_numbers)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, instr.window(), + dimension_numbers, feature_group_count)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + instr.set_feature_group_count(feature_group_count); return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); @@ -2562,32 +2569,38 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { - return lhs.builder()->Conv(lhs, rhs, window_strides, padding); + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + int64 feature_group_count) { + return lhs.builder()->Conv(lhs, rhs, window_strides, padding, + feature_group_count); } XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + int64 feature_group_count) { return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding); + padding, feature_group_count); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, - padding, dimension_numbers); + padding, dimension_numbers, + feature_group_count); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, - dimension_numbers); + dimension_numbers, feature_group_count); } XlaOp ConvGeneralDilated( @@ -2596,10 +2609,11 @@ XlaOp ConvGeneralDilated( tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { + return lhs.builder()->ConvGeneralDilated( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count); } XlaOp Fft(const XlaOp& operand, FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 9403d7ca8d..78aec770a6 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -512,22 +512,24 @@ class XlaBuilder { // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice<int64> window_strides, - Padding padding); + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. @@ -535,7 +537,8 @@ class XlaBuilder { const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. @@ -545,7 +548,8 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -1161,27 +1165,31 @@ class XlaBuilder { const DotDimensionNumbers& dimension_numbers); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - Padding padding); + Padding padding, int64 feature_group_count); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + int64 feature_group_count); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); friend XlaOp ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1646,28 +1654,32 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding); + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, - tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. @@ -1677,7 +1689,8 @@ XlaOp ConvGeneralDilated( tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index be9098f555..9d24b42401 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,6 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. +// Next ID: 51 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -74,6 +75,11 @@ message HloInstructionProto { // Describes the dimension numbers used for a convolution. xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + // The number of feature groups. Used for a convolution. Must be a divisor of + // the input feature dimension and output feature dimension. If not specified, + // it will use a default value of 1. + int64 feature_group_count = 50; + // Describes the [begin, end) index range and stride for slices. message SliceDimensions { int64 start = 1; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 0c92cd1225..7371fde79b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -322,9 +322,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); - instruction = - CreateConvolve(proto.shape(), operands(0), operands(1), - proto.window(), proto.convolution_dimension_numbers()); + instruction = CreateConvolve( + proto.shape(), operands(0), operands(1), proto.window(), + proto.convolution_dimension_numbers(), + std::max(static_cast<int64>(proto.feature_group_count()), 1LL)); break; case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -609,10 +610,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { - return MakeUnique<HloConvolutionInstruction>(shape, lhs, rhs, window, - dimension_numbers); + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { + return MakeUnique<HloConvolutionInstruction>( + shape, lhs, rhs, window, dimension_numbers, feature_group_count); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft( @@ -3181,6 +3182,10 @@ void HloInstruction::set_convolution_dimension_numbers( } } +int64 HloInstruction::feature_group_count() const { + return Cast<HloConvolutionInstruction>(this)->feature_group_count(); +} + HloComputation* HloInstruction::select() const { return Cast<HloSelectAndScatterInstruction>(this)->select(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index efaddfb95a..b3eee90099 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -402,7 +402,8 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr<HloInstruction> CreateFft( @@ -1455,6 +1456,10 @@ class HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums); + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const; + // Delegates to HloSelectAndScatterInstruction::select. HloComputation* select() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 8d3ef57757..233cdda7b0 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1606,10 +1606,12 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) : HloInstruction(HloOpcode::kConvolution, shape), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + feature_group_count_(feature_group_count) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1647,6 +1649,7 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); return extra; } @@ -1668,9 +1671,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloConvolutionInstruction>(shape, new_operands[0], - new_operands[1], window(), - convolution_dimension_numbers_); + return MakeUnique<HloConvolutionInstruction>( + shape, new_operands[0], new_operands[1], window(), + convolution_dimension_numbers_, feature_group_count_); } HloReduceWindowInstruction::HloReduceWindowInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index dd20c7c206..546949bc72 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -955,7 +955,8 @@ class HloConvolutionInstruction : public HloInstruction { explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -965,6 +966,9 @@ class HloConvolutionInstruction : public HloInstruction { const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = dnums; } + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const { return feature_group_count_; } string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -984,6 +988,9 @@ class HloConvolutionInstruction : public HloInstruction { Window window_; // Describes the dimension numbers used for a convolution. ConvolutionDimensionNumbers convolution_dimension_numbers_; + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count_; }; class HloReduceWindowInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 4dfb9435cb..eb48337cd7 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -825,9 +825,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kConvolution: { optional<Window> window; optional<ConvolutionDimensionNumbers> dnums; + optional<int64> feature_group_count; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -835,8 +838,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!window) { window.emplace(); } + if (!feature_group_count) { + feature_group_count = 1; + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums, + feature_group_count.value())); break; } case HloOpcode::kFft: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 5990a3d478..6fa3c63d83 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -380,7 +380,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1 } )" @@ -393,7 +393,7 @@ R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1 } )" @@ -406,7 +406,7 @@ R"(HloModule ConvolveBackward_module ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { %input = f32[128,7,7,512]{0,3,2,1} parameter(0) %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) - ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1 } )" @@ -1370,7 +1370,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 9ebd5eb7a5..949a4d1110 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -84,7 +84,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers())); + convolution->window(), convolution->convolution_dimension_numbers(), + convolution->feature_group_count())); return CheckShape(convolution, expected); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a4ea2b28f4..ec5743a777 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1530,7 +1530,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1640,12 +1640,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (input_features != kernel_input_features) { + if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension (value %lld); got <conv>(%s, %s)\n" + "input feature dimension * feature_group_count (value %lld); " + "got <conv>(%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features, + input_features, kernel_input_features * feature_group_count, ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index c185b0a1bd..bfd79a4433 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -112,7 +112,8 @@ class ShapeInference { // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr<Shape> InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Infers the shape produced by the given FFT type on the given operand. static StatusOr<Shape> InferFftShape( diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index e24a7cda73..8c9d26fcbb 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -505,16 +505,17 @@ Computes a convolution of the kind used in neural networks. Here, a convolution can be thought of as a n-dimensional window moving across a n-dimensional base area and a computation is performed for each possible position of the window. -| Arguments | Type | Semantics | -| ---------------- | ----------------------- | ----------------------------- | -| `lhs` | `XlaOp` | rank n+2 array of inputs | -| `rhs` | `XlaOp` | rank n+2 array of kernel | -: : : weights : -| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides | -| `padding` | `ArraySlice<pair<int64, | n-d array of (low, high) | -: : int64>>` : padding : -| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array | -| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array | +| Arguments | Type | Semantics | +| --------------------- | -------------------- | ----------------------------- | +| `lhs` | `XlaOp` | rank n+2 array of inputs | +| `rhs` | `XlaOp` | rank n+2 array of kernel | +: : : weights : +| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides | +| `padding` | `ArraySlice< | n-d array of (low, high) | +: : pair<int64, int64>>` : padding : +| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array | +| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array | +| `feature_group_count` | int64 | the number of feature groups | Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2 array describing the base area. This is called the input, even though of course @@ -532,8 +533,8 @@ The `rhs` argument is a rank n+2 array describing the convolutional filter/kernel/window. The dimensions are, in this order: * `output-z`: The `z` dimension of the output. -* `input-z`: The size of this dimension should equal the size of the `z` - dimension in lhs. +* `input-z`: The size of this dimension times `feature_group_count` should + equal the size of the `z` dimension in lhs. * `spatial_dims`: Describes the `n` spatial dimensions that define the n-d window that moves across the base area. @@ -566,6 +567,24 @@ Dilation of the rhs is also called atrous convolution. For more details, see `tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed convolution. For more details, see `tf.nn.conv2d_transpose`. +The `feature_group_count` argument (default value 1) can be used for grouped +convolutions. `feature_group_count` needs to be a divisor of both the input and +the output feature dimension. If `feature_group_count` is greater than 1, it +means that conceptually the input and output feature dimension and the `rhs` +output feature dimension are split evenly into `feature_group_count` many +groups, each group consisting of a consecutive subsequence of features. The +input feature dimension of `rhs` needs to be equal to the `lhs` input feature +dimension divided by `feature_group_count` (so it already has the size of a +group of input features). The i-th groups are used together to compute +`feature_group_count` many separate convolutions. The results of these +convolutions are concatenated together in the output feature dimension. + +For depthwise convolution the `feature_group_count` argument would be set to the +input feature dimension, and the filter would be reshaped from +`[filter_height, filter_width, in_channels, channel_multiplier]` to +`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more +details, see `tf.nn.depthwise_conv2d`. + The output shape has these dimensions, in this order: * `batch`: Same size as `batch` on the input (`lhs`). |