author    | Adrian Kuegel <akuegel@google.com>              | 2018-09-03 04:52:24 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-03 04:57:16 -0700
commit    | 0f4ad7ff0e5ce38bc09ddd008e4f32d2af321495 (patch)
tree      | cc76401cd5057d454fc0dc82178c3747a122cb3f
parent    | 44c884dc5d02abc7c50abea24c8caee6dcadda9a (diff)
Call Cudnn also for grouped convolutions.
Cudnn supports grouped convolutions, so we no longer need the
ConvolutionFeatureGroupConverter pass and can instead set the group_count
parameter on the Cudnn custom calls.
PiperOrigin-RevId: 211339551
17 files changed, 147 insertions, 102 deletions
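
Background on the semantics involved: with feature_group_count = G, a grouped convolution splits the input and output feature dimensions into G groups and convolves each group independently, which is mathematically the same as running G ordinary convolutions on the per-group slices and concatenating their outputs along the output feature dimension. The removed ConvolutionFeatureGroupConverter pass rewrote grouped convolutions into equivalent ungrouped ones so the GPU backend never saw a group count; after this change the grouped form is kept and the count is handed straight to cuDNN via `convolution_descriptor.set_group_count(feature_group_count)` (see cudnn_convolution_runner.cc below). The snippet below is a minimal, self-contained C++ sketch of that per-group equivalence; it is illustrative only, and the names GroupedConv1D, Tensor2D, and Filter do not come from the tree.

```cpp
// Standalone illustration (not XLA code): a grouped 1-D convolution computed
// directly, and the equivalent "one ordinary convolution per group, results
// concatenated along the output channels" form. With this commit the GPU
// backend keeps the grouped form and lets cuDNN handle the groups itself
// (ConvolutionDescriptor::set_group_count), instead of rewriting it away.
#include <cassert>
#include <iostream>
#include <vector>

using Tensor2D = std::vector<std::vector<float>>;             // [channel][x]
using Filter = std::vector<std::vector<std::vector<float>>>;  // [out_ch][in_ch_per_group][k]

// Valid padding, stride 1. Output channel `o` belongs to group
// o / (out_channels / group_count) and only sees that group's input channels.
Tensor2D GroupedConv1D(const Tensor2D& input, const Filter& filter,
                       int group_count) {
  const int in_channels = input.size();
  const int out_channels = filter.size();
  const int in_per_group = filter[0].size();
  const int k = filter[0][0].size();
  assert(in_channels % group_count == 0 && out_channels % group_count == 0);
  assert(in_per_group == in_channels / group_count);
  const int out_width = static_cast<int>(input[0].size()) - k + 1;

  Tensor2D output(out_channels, std::vector<float>(out_width, 0.0f));
  for (int o = 0; o < out_channels; ++o) {
    const int group = o / (out_channels / group_count);
    for (int x = 0; x < out_width; ++x) {
      for (int c = 0; c < in_per_group; ++c) {
        for (int i = 0; i < k; ++i) {
          output[o][x] +=
              input[group * in_per_group + c][x + i] * filter[o][c][i];
        }
      }
    }
  }
  return output;
}

int main() {
  // 4 input channels, 2 groups, 1 output channel per group, kernel width 2.
  Tensor2D input = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
  Filter filter = {{{1, 0}, {0, 1}},   // output channel 0 (group 0)
                   {{1, 1}, {2, 2}}};  // output channel 1 (group 1)

  // One grouped convolution, as cuDNN now runs it in a single call.
  Tensor2D grouped = GroupedConv1D(input, filter, /*group_count=*/2);

  // The same result computed group by group and then concatenated.
  Tensor2D group0 = GroupedConv1D({input[0], input[1]}, {filter[0]}, 1);
  Tensor2D group1 = GroupedConv1D({input[2], input[3]}, {filter[1]}, 1);

  assert(grouped[0] == group0[0] && grouped[1] == group1[0]);
  std::cout << "grouped convolution matches the per-group convolutions\n";
  return 0;
}
```

The final assert checks exactly the property this change relies on: a single grouped call produces the same values as running each group separately, so nothing is lost by skipping the rewrite and passing feature_group_count through to cuDNN.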
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d780b5751c..a68b7a1bef 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -676,7 +676,6 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:convolution_feature_group_converter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index eea31f3de1..05448d863d 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -37,8 +37,8 @@ ConvolutionThunk::ConvolutionThunk( const BufferAllocation::Slice& tuple_result_buffer, const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo) + const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count, + int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo) : Thunk(Kind::kConvolution, hlo), convolution_kind_(convolution_kind), input_buffer_(input_buffer), @@ -51,6 +51,7 @@ ConvolutionThunk::ConvolutionThunk( output_shape_(output_shape), window_(window), dim_nums_(dim_nums), + feature_group_count_(feature_group_count), algorithm_(algorithm), tensor_ops_enabled_(tensor_ops_enabled) {} @@ -72,8 +73,8 @@ Status ConvolutionThunk::ExecuteOnStream( auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution( convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, - stream)); + filter_data, output_data, scratch, window_, dim_nums_, + feature_group_count_, algorithm_config, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. 
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index f7952787c1..68d67c40c5 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -59,7 +59,8 @@ class ConvolutionThunk : public Thunk { const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + const ConvolutionDimensionNumbers& dim_nums, + int64 feature_group_count, int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo); ConvolutionThunk(const ConvolutionThunk&) = delete; @@ -71,19 +72,6 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - class ScratchAllocator; - - Status Convolve(const se::dnn::BatchDescriptor& input_descriptor, - se::DeviceMemory<float> input_data, - const se::dnn::FilterDescriptor& filter_descriptor, - se::DeviceMemory<float> filter_data, - const se::dnn::BatchDescriptor& output_descriptor, - se::DeviceMemory<float> output_data, - const se::dnn::ConvolutionDescriptor& convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, - se::Stream* stream, ScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result); - const CudnnConvKind convolution_kind_; const BufferAllocation::Slice input_buffer_; @@ -98,6 +86,7 @@ class ConvolutionThunk : public Thunk { const Window window_; const ConvolutionDimensionNumbers dim_nums_; + int64 feature_group_count_; int64 algorithm_; bool tensor_ops_enabled_; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 2af31a52f9..0baced71c0 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -178,7 +178,8 @@ StatusOr<std::tuple<int64, bool, int64>> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, + HloInstruction* instr) { CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); CHECK_EQ(input_shape.element_type(), output_shape.element_type()); // TODO(timshen): for now only check fp16. 
It can be expanded to other types, @@ -289,10 +290,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << instr->ToString(); bool launch_ok = - RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, - AlgorithmConfig(alg), &stream, &profile_result) + RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, input_buf, + filter_buf, output_buf, &scratch_allocator, window, dnums, + feature_group_count, AlgorithmConfig(alg), &stream, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { @@ -378,17 +379,20 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, instr->window(), - instr->convolution_dimension_numbers(), instr); + instr->convolution_dimension_numbers(), + instr->feature_group_count(), instr); } else if (call_target == kCudnnConvBackwardInputCallTarget) { alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr); + instr->convolution_dimension_numbers(), instr->feature_group_count(), + instr); } else if (call_target == kCudnnConvBackwardFilterCallTarget) { alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); + instr->window(), instr->convolution_dimension_numbers(), + instr->feature_group_count(), instr); } else { LOG(FATAL) << "Unknown custom call target for cudnn conv: " << instr->ToString(); @@ -422,14 +426,9 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( backend_config.set_algorithm(algorithm); backend_config.set_tensor_ops_enabled(tensor_ops_enabled); - HloInstruction* new_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - instr->custom_call_target())); - new_call->set_window(instr->window()); - new_call->set_convolution_dimension_numbers( - instr->convolution_dimension_numbers()); + HloInstruction* new_call = computation->AddInstruction( + instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0), + instr->mutable_operand(1)})); TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index f76d273e8c..0cb01161b0 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -51,7 +51,8 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, + HloInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff 
--git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 0b1ee2dc33..9bf721ecd2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -59,6 +59,11 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + // TODO(b/31709653): Figure out if we can use grouped convolutions also on + // backward filter. + if (conv->feature_group_count() > 1) { + return no_match_result; + } // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -218,6 +223,12 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput( const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + // TODO(b/31709653): Figure out if we can use grouped convolutions also on + // backward input. + if (conv->feature_group_count() > 1) { + return no_match_result; + } + // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); @@ -425,7 +436,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) { if (match) { return CreateCudnnConvBackwardFilter( conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums); + window, dnums, conv->feature_group_count()); } std::tie(match, window, dnums) = MatchBackwardInput(conv); @@ -435,15 +446,17 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) { CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); HloInstruction* rhs = reverse->mutable_operand(0); - return CreateCudnnConvBackwardInput( - conv->shape(), conv->mutable_operand(0), rhs, window, dnums); + return CreateCudnnConvBackwardInput(conv->shape(), + conv->mutable_operand(0), rhs, window, + dnums, conv->feature_group_count()); } // If all else fails, try a forward convolution. 
if (CanImplementAsCudnnForwardConv(conv)) { return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers()); + conv->convolution_dimension_numbers(), + conv->feature_group_count()); } return nullptr; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 07b96fbd3f..05125e9d1f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -77,8 +77,9 @@ Status RunCudnnConvolution( const Shape& output_shape, DeviceMemory<T> input_buf, DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, - Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, + AlgorithmConfig algorithm, Stream* stream, + ProfileResult* profile_result /*= nullptr*/) { VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -144,6 +145,7 @@ Status RunCudnnConvolution( } ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + convolution_descriptor.set_group_count(feature_group_count); for (int dim = 0; dim < num_dimensions; ++dim) { convolution_descriptor .set_zero_padding( @@ -222,14 +224,14 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, + output_buf, &scratch_allocator, window, dnums, feature_group_count, + algorithm, stream, profile_result); } Status RunCudnnConvolution( @@ -237,32 +239,32 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<Eigen::half>(input_buf), - se::DeviceMemory<Eigen::half>(filter_buf), - se::DeviceMemory<Eigen::half>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<Eigen::half>(input_buf), + se::DeviceMemory<Eigen::half>(filter_buf), + se::DeviceMemory<Eigen::half>(output_buf), 
scratch_allocator, window, + dnums, feature_group_count, algorithm, stream, profile_result); case F32: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<float>(input_buf), - se::DeviceMemory<float>(filter_buf), - se::DeviceMemory<float>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<float>(input_buf), + se::DeviceMemory<float>(filter_buf), + se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums, + feature_group_count, algorithm, stream, profile_result); case F64: - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory<double>(input_buf), - se::DeviceMemory<double>(filter_buf), - se::DeviceMemory<double>(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory<double>(input_buf), + se::DeviceMemory<double>(filter_buf), + se::DeviceMemory<double>(output_buf), scratch_allocator, window, + dnums, feature_group_count, algorithm, stream, profile_result); default: LOG(FATAL) << ShapeUtil::HumanString(output_shape); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index 944e4ac686..a1b4fc71d0 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -75,7 +75,7 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); @@ -84,7 +84,7 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 9c90f4d46b..20d523abe0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -144,10 +144,12 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) { IsCustomCallToDnnConvolution(hlo); } -static HloInstruction* CreateCudnnConv( - const char* call_target, const Shape& shape, HloInstruction* lhs, - HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { +static HloInstruction* CreateCudnnConv(const char* call_target, + const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { HloComputation* computation = lhs->parent(); // This call returns a tuple of (conv_result, scratch_memory), where @@ -165,28 +167,34 @@ static HloInstruction* CreateCudnnConv( 
HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); + custom_call->set_feature_group_count(feature_group_count); return custom_call; } -HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums) { +HloInstruction* CreateCudnnConvForward(const Shape& shape, + HloInstruction* input, + HloInstruction* kernel, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, - window, dnums); + window, dnums, feature_group_count); } HloInstruction* CreateCudnnConvBackwardInput( const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums) { + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, - reverse_filter, window, dnums); + reverse_filter, window, dnums, feature_group_count); } HloInstruction* CreateCudnnConvBackwardFilter( const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums) { + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, - output, window, dnums); + output, window, dnums, feature_group_count); } bool IsReductionToVector(const HloInstruction& reduce) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index d242897e16..59c65fc268 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -109,15 +109,20 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); // // The created cudnn call will use the default cudnn algorithm and no scratch // space. -HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvForward(const Shape& shape, + HloInstruction* input, + HloInstruction* kernel, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count); HloInstruction* CreateCudnnConvBackwardInput( const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums); + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count); HloInstruction* CreateCudnnConvBackwardFilter( const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums); + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count); // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 78f61a4987..389a98facb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -489,8 +489,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); + custom_call->feature_group_count(), backend_config.algorithm(), + backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kBackwardInput, @@ -503,8 +503,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); + custom_call->feature_group_count(), backend_config.algorithm(), + backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kBackwardFilter, @@ -517,8 +517,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); + custom_call->feature_group_count(), backend_config.algorithm(), + backend_config.tensor_ops_enabled(), custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 8ce67c03b6..f6325b3368 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" @@ -208,8 +207,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // TODO(b/31709653): Directly use the grouped convolution support of Cudnn. 
- pipeline.AddPass<ConvolutionFeatureGroupConverter>(); pipeline.AddPass<CudnnConvolutionRewriter>(); // CudnnConvolutionRewriter may add instructions of the form // reverse(constant), which it expects will be simplified by constant diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 98cc21ccac..9d85d746d8 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -166,9 +166,9 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { Shape old_conv_shape = conv->shape().tuple_shapes(0); VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, - new_conv_window, - conv->convolution_dimension_numbers()); + auto new_conv = CreateCudnnConvForward( + old_conv_shape, new_input, new_kernel, new_conv_window, + conv->convolution_dimension_numbers(), conv->feature_group_count()); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); @@ -247,7 +247,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( backward_conv_shape, padded_input, output, new_backward_conv_window, - backward_conv_dnums); + backward_conv_dnums, backward_conv->feature_group_count()); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -312,7 +312,7 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( new_backward_conv_shape, output, filter, new_backward_conv_window, - backward_conv_dnums); + backward_conv_dnums, backward_conv->feature_group_count()); // The CustomCall created above returns a tuple (conv_result, scratch_memory). // Extract out the two elements. 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index bd0b6af10d..6d13f85cbb 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -385,6 +385,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( ->set_convolution_dimension_numbers( proto.convolution_dimension_numbers()); } + static_cast<HloCustomCallInstruction*>(instruction.get()) + ->set_feature_group_count( + std::max(static_cast<int64>(proto.feature_group_count()), 1LL)); break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -3269,7 +3272,15 @@ void HloInstruction::set_convolution_dimension_numbers( } int64 HloInstruction::feature_group_count() const { - return Cast<HloConvolutionInstruction>(this)->feature_group_count(); + if (auto convolution = DynCast<HloConvolutionInstruction>(this)) { + return convolution->feature_group_count(); + } + return Cast<HloCustomCallInstruction>(this)->feature_group_count(); +} + +void HloInstruction::set_feature_group_count(int64 feature_group_count) { + Cast<HloCustomCallInstruction>(this)->set_feature_group_count( + feature_group_count); } HloComputation* HloInstruction::select() const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 08f3d5356f..cca134e8b4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1475,6 +1475,8 @@ class HloInstruction { // dimension and output feature dimension. int64 feature_group_count() const; + void set_feature_group_count(int64 feature_group_count); + // Delegates to HloSelectAndScatterInstruction::select. 
HloComputation* select() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 6871953755..e46afa764f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1660,6 +1660,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_window() = window_; *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; + proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1681,6 +1682,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast<const HloConvolutionInstruction&>(other); + if (feature_group_count_ != other.feature_group_count()) { + return false; + } return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), @@ -1793,8 +1797,8 @@ HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span<HloInstruction* const> operands, absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), - custom_call_target_(custom_call_target.begin(), - custom_call_target.end()) { + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + feature_group_count_(1) { for (auto operand : operands) { AppendOperand(operand); } @@ -1810,6 +1814,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1824,6 +1829,9 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl( "dim_labels=", ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -1851,6 +1859,9 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.convolution_dimension_numbers()))) { return false; } + if (feature_group_count_ != casted_other.feature_group_count_) { + return false; + } return custom_call_target_ == casted_other.custom_call_target_; } @@ -1866,6 +1877,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( if (convolution_dimension_numbers_ != nullptr) { cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } + cloned->set_feature_group_count(feature_group_count_); return std::move(cloned); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 45a648bbe4..3230383579 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1079,6 +1079,10 @@ class HloCustomCallInstruction : public HloInstruction { absl::make_unique<ConvolutionDimensionNumbers>(dnums); } const string& custom_call_target() const { return custom_call_target_; } + void set_feature_group_count(int64 feature_group_count) { + feature_group_count_ = feature_group_count; + } + int64 feature_group_count() const { return feature_group_count_; } // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; @@ -1099,6 +1103,8 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr<Window> window_; // Describes the dimension numbers used for a convolution. std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; + // The number of feature groups. This is used for grouped convolutions. + int64 feature_group_count_; }; class HloPadInstruction : public HloInstruction { |