aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-08-16 01:32:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 01:36:41 -0700
commit72b829dcca2d1acaeea130e580ce780b1a7d550a (patch)
tree6c7e26f84f8d7eb5eeaf7f802db716b931757df7
parent9d97b34bde77762a7499306ee74a56bcc91a95dc (diff)
Add a feature_group_count parameter to the Convolution HLO op.
This is a first step towards supporting grouped convolutions, which are a generalization of depthwise convolution.
PiperOrigin-RevId: 208950311
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc64
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h45
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto6
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc19
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h9
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc3
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc9
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h3
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md43
13 files changed, 160 insertions, 78 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 31dedd54b0..aa47f992bc 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -882,24 +882,28 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding) {
+ Padding padding, int64 feature_group_count) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
+ CreateDefaultConvDimensionNumbers(window_strides.size()),
+ feature_group_count);
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count) {
return ConvGeneral(lhs, rhs, window_strides, padding,
- CreateDefaultConvDimensionNumbers(window_strides.size()));
+ CreateDefaultConvDimensionNumbers(window_strides.size()),
+ feature_group_count);
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -926,7 +930,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
return ConvGeneral(lhs, rhs, window_strides,
MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding),
- dimension_numbers);
+ dimension_numbers, feature_group_count);
});
}
@@ -934,9 +938,10 @@ XlaOp XlaBuilder::ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
- dimension_numbers);
+ dimension_numbers, feature_group_count);
}
XlaOp XlaBuilder::ConvGeneralDilated(
@@ -945,7 +950,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -964,12 +970,13 @@ XlaOp XlaBuilder::ConvGeneralDilated(
MakeWindow(window_dimensions, window_strides, padding,
lhs_dilation, rhs_dilation));
- TF_ASSIGN_OR_RETURN(
- *instr.mutable_shape(),
- ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(),
- dimension_numbers));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferConvolveShape(
+ lhs_shape, rhs_shape, instr.window(),
+ dimension_numbers, feature_group_count));
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
+ instr.set_feature_group_count(feature_group_count);
return AddInstruction(std::move(instr), HloOpcode::kConvolution,
{lhs, rhs});
@@ -2562,32 +2569,38 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- return lhs.builder()->Conv(lhs, rhs, window_strides, padding);
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ int64 feature_group_count) {
+ return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
+ feature_group_count);
}
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count) {
return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
- padding);
+ padding, feature_group_count);
}
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides,
- padding, dimension_numbers);
+ padding, dimension_numbers,
+ feature_group_count);
}
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
- dimension_numbers);
+ dimension_numbers, feature_group_count);
}
XlaOp ConvGeneralDilated(
@@ -2596,10 +2609,11 @@ XlaOp ConvGeneralDilated(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation,
- dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
+ return lhs.builder()->ConvGeneralDilated(
+ lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count);
}
XlaOp Fft(const XlaOp& operand, FftType fft_type,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 9403d7ca8d..78aec770a6 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -512,22 +512,24 @@ class XlaBuilder {
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
@@ -535,7 +537,8 @@ class XlaBuilder {
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
@@ -545,7 +548,8 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
@@ -1161,27 +1165,31 @@ class XlaBuilder {
const DotDimensionNumbers& dimension_numbers);
friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- Padding padding);
+ Padding padding, int64 feature_group_count);
friend XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count);
friend XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
friend XlaOp ConvGeneral(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
friend XlaOp ConvGeneralDilated(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
@@ -1646,28 +1654,32 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
- tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
- tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
@@ -1677,7 +1689,8 @@ XlaOp ConvGeneralDilated(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index be9098f555..9d24b42401 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,6 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
+// Next ID: 51
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -74,6 +75,11 @@ message HloInstructionProto {
// Describes the dimension numbers used for a convolution.
xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;
+ // The number of feature groups. Used for a convolution. Must be a divisor of
+ // the input feature dimension and output feature dimension. If not specified,
+ // it will use a default value of 1.
+ int64 feature_group_count = 50;
+
// Describes the [begin, end) index range and stride for slices.
message SliceDimensions {
int64 start = 1;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 0c92cd1225..7371fde79b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -322,9 +322,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_window());
TF_RET_CHECK(proto.has_convolution_dimension_numbers());
- instruction =
- CreateConvolve(proto.shape(), operands(0), operands(1),
- proto.window(), proto.convolution_dimension_numbers());
+ instruction = CreateConvolve(
+ proto.shape(), operands(0), operands(1), proto.window(),
+ proto.convolution_dimension_numbers(),
+ std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
break;
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() == 2)
@@ -609,10 +610,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers) {
- return MakeUnique<HloConvolutionInstruction>(shape, lhs, rhs, window,
- dimension_numbers);
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
+ return MakeUnique<HloConvolutionInstruction>(
+ shape, lhs, rhs, window, dimension_numbers, feature_group_count);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
@@ -3181,6 +3182,10 @@ void HloInstruction::set_convolution_dimension_numbers(
}
}
+int64 HloInstruction::feature_group_count() const {
+ return Cast<HloConvolutionInstruction>(this)->feature_group_count();
+}
+
HloComputation* HloInstruction::select() const {
return Cast<HloSelectAndScatterInstruction>(this)->select();
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index efaddfb95a..b3eee90099 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -402,7 +402,8 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
@@ -1455,6 +1456,10 @@ class HloInstruction {
void set_convolution_dimension_numbers(
const ConvolutionDimensionNumbers& dnums);
+ // The number of feature groups. Must be a divisor of the input feature
+ // dimension and output feature dimension.
+ int64 feature_group_count() const;
+
// Delegates to HloSelectAndScatterInstruction::select.
HloComputation* select() const;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 8d3ef57757..233cdda7b0 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1606,10 +1606,12 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
HloConvolutionInstruction::HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
- const Window& window, const ConvolutionDimensionNumbers& dimension_numbers)
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count)
: HloInstruction(HloOpcode::kConvolution, shape),
window_(window),
- convolution_dimension_numbers_(dimension_numbers) {
+ convolution_dimension_numbers_(dimension_numbers),
+ feature_group_count_(feature_group_count) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1647,6 +1649,7 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
}
extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
convolution_dimension_numbers_)));
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
return extra;
}
@@ -1668,9 +1671,9 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloConvolutionInstruction>(shape, new_operands[0],
- new_operands[1], window(),
- convolution_dimension_numbers_);
+ return MakeUnique<HloConvolutionInstruction>(
+ shape, new_operands[0], new_operands[1], window(),
+ convolution_dimension_numbers_, feature_group_count_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index dd20c7c206..546949bc72 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -955,7 +955,8 @@ class HloConvolutionInstruction : public HloInstruction {
explicit HloConvolutionInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -965,6 +966,9 @@ class HloConvolutionInstruction : public HloInstruction {
const ConvolutionDimensionNumbers& dnums) {
convolution_dimension_numbers_ = dnums;
}
+ // The number of feature groups. Must be a divisor of the input feature
+ // dimension and output feature dimension.
+ int64 feature_group_count() const { return feature_group_count_; }
string ToCategory() const override;
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -984,6 +988,9 @@ class HloConvolutionInstruction : public HloInstruction {
Window window_;
// Describes the dimension numbers used for a convolution.
ConvolutionDimensionNumbers convolution_dimension_numbers_;
+ // The number of feature groups. Must be a divisor of the input feature
+ // dimension and output feature dimension.
+ int64 feature_group_count_;
};
class HloReduceWindowInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 4dfb9435cb..eb48337cd7 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -825,9 +825,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kConvolution: {
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
+ optional<int64> feature_group_count;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/true,
AttrTy::kConvolutionDimensionNumbers, &dnums};
+ attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
+ &feature_group_count};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
@@ -835,8 +838,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!window) {
window.emplace();
}
+ if (!feature_group_count) {
+ feature_group_count = 1;
+ }
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
- shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
+ shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums,
+ feature_group_count.value()));
break;
}
case HloOpcode::kFft: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 5990a3d478..6fa3c63d83 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -380,7 +380,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1
}
)"
@@ -393,7 +393,7 @@ R"(HloModule ConvolveR2_module
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
%input = f32[1,2]{1,0} parameter(0)
%filter = f32[1,1]{1,0} parameter(1)
- ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
+ ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1
}
)"
@@ -406,7 +406,7 @@ R"(HloModule ConvolveBackward_module
ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
%input = f32[128,7,7,512]{0,3,2,1} parameter(0)
%filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
- ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
+ ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1
}
)"
@@ -1370,7 +1370,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
- ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+ ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
}
)";
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9ebd5eb7a5..949a4d1110 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -84,7 +84,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
const Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
- convolution->window(), convolution->convolution_dimension_numbers()));
+ convolution->window(), convolution->convolution_dimension_numbers(),
+ convolution->feature_group_count()));
return CheckShape(convolution, expected);
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a4ea2b28f4..ec5743a777 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1530,7 +1530,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
@@ -1640,12 +1640,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const int64 kernel_output_features =
rhs.dimensions(dnums.kernel_output_feature_dimension());
- if (input_features != kernel_input_features) {
+ if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
"Expected LHS feature dimension (value %lld) to match RHS "
- "input feature dimension (value %lld); got <conv>(%s, %s)\n"
+ "input feature dimension * feature_group_count (value %lld); "
+ "got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
- input_features, kernel_input_features,
+ input_features, kernel_input_features * feature_group_count,
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index c185b0a1bd..bfd79a4433 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -112,7 +112,8 @@ class ShapeInference {
// filter (rhs) to lhs in the way specified by the fields on window.
static StatusOr<Shape> InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count = 1);
// Infers the shape produced by the given FFT type on the given operand.
static StatusOr<Shape> InferFftShape(
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index e24a7cda73..8c9d26fcbb 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -505,16 +505,17 @@ Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as a n-dimensional window moving across a n-dimensional base
area and a computation is performed for each possible position of the window.
-| Arguments | Type | Semantics |
-| ---------------- | ----------------------- | ----------------------------- |
-| `lhs` | `XlaOp` | rank n+2 array of inputs |
-| `rhs` | `XlaOp` | rank n+2 array of kernel |
-: : : weights :
-| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
-| `padding` | `ArraySlice<pair<int64, | n-d array of (low, high) |
-: : int64>>` : padding :
-| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
-| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
+| Arguments | Type | Semantics |
+| --------------------- | -------------------- | ----------------------------- |
+| `lhs` | `XlaOp` | rank n+2 array of inputs |
+| `rhs` | `XlaOp` | rank n+2 array of kernel |
+: : : weights :
+| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
+| `padding` | `ArraySlice< | n-d array of (low, high) |
+: : pair<int64, int64>>` : padding :
+| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
+| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
+| `feature_group_count` | `int64` | the number of feature groups |
Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
array describing the base area. This is called the input, even though of course
@@ -532,8 +533,8 @@ The `rhs` argument is a rank n+2 array describing the convolutional
filter/kernel/window. The dimensions are, in this order:
* `output-z`: The `z` dimension of the output.
-* `input-z`: The size of this dimension should equal the size of the `z`
- dimension in lhs.
+* `input-z`: The size of this dimension times `feature_group_count` should
+ equal the size of the `z` dimension in lhs.
* `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
window that moves across the base area.
@@ -566,6 +567,24 @@ Dilation of the rhs is also called atrous convolution. For more details, see
`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
convolution. For more details, see `tf.nn.conv2d_transpose`.
+The `feature_group_count` argument (default value 1) can be used for grouped
+convolutions. `feature_group_count` needs to be a divisor of both the input and
+the output feature dimension. If `feature_group_count` is greater than 1, then
+conceptually the `lhs` input feature dimension and the `rhs` output feature
+dimension (which is also the output feature dimension of the result) are each
+split evenly into `feature_group_count` groups, each group consisting of a
+consecutive subsequence of features. The input feature dimension of `rhs` needs
+to be equal to the `lhs` input feature dimension divided by
+`feature_group_count` (so it already has the size of one group of input
+features). For each i, the i-th group of input features is convolved with the
+i-th group of `rhs` output features, giving `feature_group_count` separate
+convolutions. The results of these convolutions are concatenated together in
+the output feature dimension.
+
+For depthwise convolution the `feature_group_count` argument would be set to the
+input feature dimension, and the filter would be reshaped from
+`[filter_height, filter_width, in_channels, channel_multiplier]` to
+`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
+details, see `tf.nn.depthwise_conv2d`.
+
The output shape has these dimensions, in this order:
* `batch`: Same size as `batch` on the input (`lhs`).