author     Adrian Kuegel <akuegel@google.com>              2018-09-03 04:52:24 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-09-03 04:57:16 -0700
commit     0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch)
tree       cc76401cd5057d454fc0dc82178c3747a122cb3f
parent     44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff)
Call cuDNN also for grouped convolutions.

cuDNN supports grouped convolutions, so we don't need the ConvolutionFeatureGroupConverter pass and can instead set the group_count parameter on the cuDNN custom calls.

PiperOrigin-RevId: 211339551
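For context: in a grouped convolution, the input features are split into feature_group_count groups that are convolved independently (depthwise convolution is the case where the group count equals the number of input features). The one-line mechanism this change builds on is StreamExecutor's ConvolutionDescriptor::set_group_count, which the runner below now calls. The sketch that follows is a minimal illustration, not code from this commit; MakeGroupedConvDescriptor is a hypothetical helper, and the includes are assumed.

// Hypothetical sketch: instead of splitting a grouped convolution into
// feature_group_count independent convolutions (the old
// ConvolutionFeatureGroupConverter approach), hand the group count
// directly to cuDNN via StreamExecutor's convolution descriptor.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/dnn.h"

namespace se = ::stream_executor;

se::dnn::ConvolutionDescriptor MakeGroupedConvDescriptor(
    int num_spatial_dimensions, tensorflow::int64 feature_group_count) {
  se::dnn::ConvolutionDescriptor convolution_descriptor(
      num_spatial_dimensions);
  // A group count of 1 is an ordinary convolution, so setting it
  // unconditionally is safe; cuDNN handles group counts > 1 natively.
  convolution_descriptor.set_group_count(feature_group_count);
  return convolution_descriptor;
}

Everything else in the diff is plumbing: feature_group_count travels from the HLO custom-call instruction through the rewriter, algorithm picker, and thunk until it reaches this descriptor in the runner.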
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                                  |  1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc                  |  9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.h                   | 17
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h  |  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc         | 21
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc           | 54
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h            |  4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc                  | 32
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.h                   | 15
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc                | 12
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc                     |  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc                      | 10
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc                        | 13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h                         |  2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc                       | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h                        |  6
17 files changed, 147 insertions(+), 102 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d780b5751c..a68b7a1bef 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -676,7 +676,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
- "//tensorflow/compiler/xla/service:convolution_feature_group_converter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index eea31f3de1..05448d863d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -37,8 +37,8 @@ ConvolutionThunk::ConvolutionThunk(
const BufferAllocation::Slice& tuple_result_buffer,
const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
const Shape& filter_shape, const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo)
+ const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count,
+ int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo)
: Thunk(Kind::kConvolution, hlo),
convolution_kind_(convolution_kind),
input_buffer_(input_buffer),
@@ -51,6 +51,7 @@ ConvolutionThunk::ConvolutionThunk(
output_shape_(output_shape),
window_(window),
dim_nums_(dim_nums),
+ feature_group_count_(feature_group_count),
algorithm_(algorithm),
tensor_ops_enabled_(tensor_ops_enabled) {}
@@ -72,8 +73,8 @@ Status ConvolutionThunk::ExecuteOnStream(
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(
convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
- filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
- stream));
+ filter_data, output_data, scratch, window_, dim_nums_,
+ feature_group_count_, algorithm_config, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index f7952787c1..68d67c40c5 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -59,7 +59,8 @@ class ConvolutionThunk : public Thunk {
const BufferAllocation::Slice& scratch_buffer,
const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
+ const ConvolutionDimensionNumbers& dim_nums,
+ int64 feature_group_count, int64 algorithm,
bool tensor_ops_enabled, const HloInstruction* hlo);
ConvolutionThunk(const ConvolutionThunk&) = delete;
@@ -71,19 +72,6 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- class ScratchAllocator;
-
- Status Convolve(const se::dnn::BatchDescriptor& input_descriptor,
- se::DeviceMemory<float> input_data,
- const se::dnn::FilterDescriptor& filter_descriptor,
- se::DeviceMemory<float> filter_data,
- const se::dnn::BatchDescriptor& output_descriptor,
- se::DeviceMemory<float> output_data,
- const se::dnn::ConvolutionDescriptor& convolution_descriptor,
- const se::dnn::AlgorithmConfig& algorithm_config,
- se::Stream* stream, ScratchAllocator* scratch_allocator,
- se::dnn::ProfileResult* profile_result);
-
const CudnnConvKind convolution_kind_;
const BufferAllocation::Slice input_buffer_;
@@ -98,6 +86,7 @@ class ConvolutionThunk : public Thunk {
const Window window_;
const ConvolutionDimensionNumbers dim_nums_;
+ int64 feature_group_count_;
int64 algorithm_;
bool tensor_ops_enabled_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 2af31a52f9..0baced71c0 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -178,7 +178,8 @@ StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
+ HloInstruction* instr) {
CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
CHECK_EQ(input_shape.element_type(), output_shape.element_type());
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -289,10 +290,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< instr->ToString();
bool launch_ok =
- RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums,
- AlgorithmConfig(alg), &stream, &profile_result)
+ RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape, input_buf,
+ filter_buf, output_buf, &scratch_allocator, window, dnums,
+ feature_group_count, AlgorithmConfig(alg), &stream, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
@@ -378,17 +379,20 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr);
+ instr->convolution_dimension_numbers(),
+ instr->feature_group_count(), instr);
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
/*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
- instr->convolution_dimension_numbers(), instr);
+ instr->convolution_dimension_numbers(), instr->feature_group_count(),
+ instr);
} else if (call_target == kCudnnConvBackwardFilterCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
/*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
- instr->window(), instr->convolution_dimension_numbers(), instr);
+ instr->window(), instr->convolution_dimension_numbers(),
+ instr->feature_group_count(), instr);
} else {
LOG(FATAL) << "Unknown custom call target for cudnn conv: "
<< instr->ToString();
@@ -422,14 +426,9 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
backend_config.set_algorithm(algorithm);
backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
- HloInstruction* new_call =
- computation->AddInstruction(HloInstruction::CreateCustomCall(
- new_call_shape,
- {instr->mutable_operand(0), instr->mutable_operand(1)},
- instr->custom_call_target()));
- new_call->set_window(instr->window());
- new_call->set_convolution_dimension_numbers(
- instr->convolution_dimension_numbers());
+ HloInstruction* new_call = computation->AddInstruction(
+ instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0),
+ instr->mutable_operand(1)}));
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index f76d273e8c..0cb01161b0 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -51,7 +51,8 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
+ HloInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 0b1ee2dc33..9bf721ecd2 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -59,6 +59,11 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
HloInstruction* conv) {
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ // TODO(b/31709653): Figure out if we can use grouped convolutions also on
+ // backward filter.
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
// Step 1: match the instruction pattern without considering the paddings and
// dimension numbers just yet. We may need some generic pattern matcher
// similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
@@ -218,6 +223,12 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
const auto no_match_result =
std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+ // TODO(b/31709653): Figure out if we can use grouped convolutions also on
+ // backward input.
+ if (conv->feature_group_count() > 1) {
+ return no_match_result;
+ }
+
// Match instruction pattern.
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
HloInstruction* reverse_filter = conv->mutable_operand(1);
@@ -425,7 +436,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
if (match) {
return CreateCudnnConvBackwardFilter(
conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums);
+ window, dnums, conv->feature_group_count());
}
std::tie(match, window, dnums) = MatchBackwardInput(conv);
@@ -435,15 +446,17 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
HloInstruction* rhs = reverse->mutable_operand(0);
- return CreateCudnnConvBackwardInput(
- conv->shape(), conv->mutable_operand(0), rhs, window, dnums);
+ return CreateCudnnConvBackwardInput(conv->shape(),
+ conv->mutable_operand(0), rhs, window,
+ dnums, conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers());
+ conv->convolution_dimension_numbers(),
+ conv->feature_group_count());
}
return nullptr;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 07b96fbd3f..05125e9d1f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -77,8 +77,9 @@ Status RunCudnnConvolution(
const Shape& output_shape, DeviceMemory<T> input_buf,
DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
- Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
+ AlgorithmConfig algorithm, Stream* stream,
+ ProfileResult* profile_result /*= nullptr*/) {
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
@@ -144,6 +145,7 @@ Status RunCudnnConvolution(
}
ConvolutionDescriptor convolution_descriptor(effective_num_dimensions);
+ convolution_descriptor.set_group_count(feature_group_count);
for (int dim = 0; dim < num_dimensions; ++dim) {
convolution_descriptor
.set_zero_padding(
@@ -222,14 +224,14 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- input_buf, filter_buf, output_buf,
- &scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
+ output_buf, &scratch_allocator, window, dnums, feature_group_count,
+ algorithm, stream, profile_result);
}
Status RunCudnnConvolution(
@@ -237,32 +239,32 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
PrimitiveType output_primitive_type = output_shape.element_type();
switch (output_primitive_type) {
case F16:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<Eigen::half>(input_buf),
- se::DeviceMemory<Eigen::half>(filter_buf),
- se::DeviceMemory<Eigen::half>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<Eigen::half>(input_buf),
+ se::DeviceMemory<Eigen::half>(filter_buf),
+ se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
+ dnums, feature_group_count, algorithm, stream, profile_result);
case F32:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<float>(input_buf),
- se::DeviceMemory<float>(filter_buf),
- se::DeviceMemory<float>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<float>(input_buf),
+ se::DeviceMemory<float>(filter_buf),
+ se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
+ feature_group_count, algorithm, stream, profile_result);
case F64:
- return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape,
- se::DeviceMemory<double>(input_buf),
- se::DeviceMemory<double>(filter_buf),
- se::DeviceMemory<double>(output_buf),
- scratch_allocator, window, dnums, algorithm,
- stream, profile_result);
+ return RunCudnnConvolution(
+ kind, input_shape, filter_shape, output_shape,
+ se::DeviceMemory<double>(input_buf),
+ se::DeviceMemory<double>(filter_buf),
+ se::DeviceMemory<double>(output_buf), scratch_allocator, window,
+ dnums, feature_group_count, algorithm, stream, profile_result);
default:
LOG(FATAL) << ShapeUtil::HumanString(output_shape);
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index 944e4ac686..a1b4fc71d0 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -75,7 +75,7 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::DeviceMemoryBase scratch_buf, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
@@ -84,7 +84,7 @@ Status RunCudnnConvolution(
const Shape& output_shape, se::DeviceMemoryBase input_buf,
se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
se::ScratchAllocator* scratch_allocator, const Window& window,
- const ConvolutionDimensionNumbers& dnums,
+ const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 9c90f4d46b..20d523abe0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -144,10 +144,12 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
IsCustomCallToDnnConvolution(hlo);
}
-static HloInstruction* CreateCudnnConv(
- const char* call_target, const Shape& shape, HloInstruction* lhs,
- HloInstruction* rhs, const Window& window,
- const ConvolutionDimensionNumbers& dnums) {
+static HloInstruction* CreateCudnnConv(const char* call_target,
+ const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
HloComputation* computation = lhs->parent();
// This call returns a tuple of (conv_result, scratch_memory), where
@@ -165,28 +167,34 @@ static HloInstruction* CreateCudnnConv(
HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
custom_call->set_window(window);
custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
return custom_call;
}
-HloInstruction* CreateCudnnConvForward(
- const Shape& shape, HloInstruction* input, HloInstruction* kernel,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+HloInstruction* CreateCudnnConvForward(const Shape& shape,
+ HloInstruction* input,
+ HloInstruction* kernel,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
- window, dnums);
+ window, dnums, feature_group_count);
}
HloInstruction* CreateCudnnConvBackwardInput(
const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
- reverse_filter, window, dnums);
+ reverse_filter, window, dnums, feature_group_count);
}
HloInstruction* CreateCudnnConvBackwardFilter(
const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums) {
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
- output, window, dnums);
+ output, window, dnums, feature_group_count);
}
bool IsReductionToVector(const HloInstruction& reduce) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index d242897e16..59c65fc268 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -109,15 +109,20 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
//
// The created cudnn call will use the default cudnn algorithm and no scratch
// space.
-HloInstruction* CreateCudnnConvForward(
- const Shape& shape, HloInstruction* input, HloInstruction* kernel,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+HloInstruction* CreateCudnnConvForward(const Shape& shape,
+ HloInstruction* input,
+ HloInstruction* kernel,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
HloInstruction* CreateCudnnConvBackwardInput(
const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
HloInstruction* CreateCudnnConvBackwardFilter(
const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums);
+ const Window& window, const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count);
// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 78f61a4987..389a98facb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -489,8 +489,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
@@ -503,8 +503,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/lhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
thunk = absl::make_unique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
@@ -517,8 +517,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/conv_result_shape,
/*output_shape=*/rhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- backend_config.algorithm(), backend_config.tensor_ops_enabled(),
- custom_call);
+ custom_call->feature_group_count(), backend_config.algorithm(),
+ backend_config.tensor_ops_enabled(), custom_call);
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 8ce67c03b6..f6325b3368 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
-#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@@ -208,8 +207,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
HloPassPipeline pipeline("conv_canonicalization");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
- // TODO(b/31709653): Directly use the grouped convolution support of Cudnn.
- pipeline.AddPass<ConvolutionFeatureGroupConverter>();
pipeline.AddPass<CudnnConvolutionRewriter>();
// CudnnConvolutionRewriter may add instructions of the form
// reverse(constant), which it expects will be simplified by constant
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 98cc21ccac..9d85d746d8 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -166,9 +166,9 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
Shape old_conv_shape = conv->shape().tuple_shapes(0);
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel,
- new_conv_window,
- conv->convolution_dimension_numbers());
+ auto new_conv = CreateCudnnConvForward(
+ old_conv_shape, new_input, new_kernel, new_conv_window,
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -247,7 +247,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -312,7 +312,7 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums);
+ backward_conv_dnums, backward_conv->feature_group_count());
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index bd0b6af10d..6d13f85cbb 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -385,6 +385,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
->set_convolution_dimension_numbers(
proto.convolution_dimension_numbers());
}
+ static_cast<HloCustomCallInstruction*>(instruction.get())
+ ->set_feature_group_count(
+ std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
break;
case HloOpcode::kPad:
TF_RET_CHECK(proto.operand_ids_size() == 2)
@@ -3269,7 +3272,15 @@ void HloInstruction::set_convolution_dimension_numbers(
}
int64 HloInstruction::feature_group_count() const {
- return Cast<HloConvolutionInstruction>(this)->feature_group_count();
+ if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->feature_group_count();
+ }
+ return Cast<HloCustomCallInstruction>(this)->feature_group_count();
+}
+
+void HloInstruction::set_feature_group_count(int64 feature_group_count) {
+ Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
+ feature_group_count);
}
HloComputation* HloInstruction::select() const {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 08f3d5356f..cca134e8b4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1475,6 +1475,8 @@ class HloInstruction {
// dimension and output feature dimension.
int64 feature_group_count() const;
+ void set_feature_group_count(int64 feature_group_count);
+
// Delegates to HloSelectAndScatterInstruction::select.
HloComputation* select() const;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 6871953755..e46afa764f 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1660,6 +1660,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_window() = window_;
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
+ proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1681,6 +1682,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
eq_computations) const {
const auto& casted_other =
static_cast<const HloConvolutionInstruction&>(other);
+ if (feature_group_count_ != other.feature_group_count()) {
+ return false;
+ }
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
@@ -1793,8 +1797,8 @@ HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target)
: HloInstruction(HloOpcode::kCustomCall, shape),
- custom_call_target_(custom_call_target.begin(),
- custom_call_target.end()) {
+ custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1810,6 +1814,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1824,6 +1829,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
"dim_labels=",
ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
}
+ if (feature_group_count_ != 1) {
+ extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+ }
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
@@ -1851,6 +1859,9 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
casted_other.convolution_dimension_numbers()))) {
return false;
}
+ if (feature_group_count_ != casted_other.feature_group_count_) {
+ return false;
+ }
return custom_call_target_ == casted_other.custom_call_target_;
}
@@ -1866,6 +1877,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
if (convolution_dimension_numbers_ != nullptr) {
cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
}
+ cloned->set_feature_group_count(feature_group_count_);
return std::move(cloned);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 45a648bbe4..3230383579 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1079,6 +1079,10 @@ class HloCustomCallInstruction : public HloInstruction {
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
const string& custom_call_target() const { return custom_call_target_; }
+ void set_feature_group_count(int64 feature_group_count) {
+ feature_group_count_ = feature_group_count;
+ }
+ int64 feature_group_count() const { return feature_group_count_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -1099,6 +1103,8 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
+ // The number of feature groups. This is used for grouped convolutions.
+ int64 feature_group_count_;
};
class HloPadInstruction : public HloInstruction {