Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 58
1 file changed, 47 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 3d1266355b..ef29237301 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -35,6 +36,32 @@ namespace gpu {
 
 namespace {
 
+HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape,
+                                HloInstruction* lhs, HloInstruction* rhs,
+                                const Window& window,
+                                const ConvolutionDimensionNumbers& dnums,
+                                int64 feature_group_count) {
+  HloComputation* computation = lhs->parent();
+
+  // This call returns a tuple of (conv_result, scratch_memory), where
+  // conv_result is the actual result of the convolution, and scratch_memory is
+  // temporary memory used by cudnn.
+  //
+  // At the moment, we don't know how much scratch memory this conv is going to
+  // use, so we put u8[0] in this place.  Later on another pass will choose
+  // which conv algorithm to use, and at that point we'll modify the shape of
+  // this second tuple element.
+  Shape call_shape =
+      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+  HloInstruction* custom_call = computation->AddInstruction(
+      HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
+  custom_call->set_window(window);
+  custom_call->set_convolution_dimension_numbers(dnums);
+  custom_call->set_feature_group_count(feature_group_count);
+  return custom_call;
+}
+
 bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
   const ConvolutionDimensionNumbers& dnums =
       conv->convolution_dimension_numbers();
@@ -263,7 +290,7 @@ MatchBackwardInput(HloInstruction* conv) {
       !(window_util::HasBaseDilation(conv->window()) &&
         (reverse_filter->IsConstant() || is_1x1_filter))) {
     VLOG(1) << "Can't match to backwards convolution. Either filter is not "
-               "kReverse, or it's not a base-dialted conv with a 1x1 or "
+               "kReverse, or it's not a base-dilated conv with a 1x1 or "
                "constant filter.";
     return no_match_result;
   }
@@ -450,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) {
   return std::make_tuple(true, new_window, dnums, rhs);
 }
 
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+  CudnnConvBackendConfig config;
+  config.set_conv_result_scale(1);
+  return config;
+}
+
 // Tries to rewrite a single convolution into a call to cudnn.
 StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
   CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -462,24 +495,24 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
 
     std::tie(match, window, dnums) = MatchBackwardFilter(conv);
     if (match) {
-      return CreateCudnnConvBackwardFilter(
-          conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
-          window, dnums, conv->feature_group_count());
+      return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
+                             conv->mutable_operand(0), conv->mutable_operand(1),
+                             window, dnums, conv->feature_group_count());
     }
 
     std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
     if (match) {
-      return CreateCudnnConvBackwardInput(conv->shape(),
-                                          conv->mutable_operand(0), rhs, window,
-                                          dnums, conv->feature_group_count());
+      return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
+                             conv->mutable_operand(0), rhs, window, dnums,
+                             conv->feature_group_count());
     }
 
     // If all else fails, try a forward convolution.
     if (CanImplementAsCudnnForwardConv(conv)) {
-      return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
-                                    conv->mutable_operand(1), conv->window(),
-                                    conv->convolution_dimension_numbers(),
-                                    conv->feature_group_count());
+      return CreateCudnnConv(
+          kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0),
+          conv->mutable_operand(1), conv->window(),
+          conv->convolution_dimension_numbers(), conv->feature_group_count());
     }
 
     return nullptr;
@@ -489,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
     return false;
   }
 
+  TF_RETURN_IF_ERROR(
+      custom_call->set_backend_config(GetDefaultBackendConfig()));
+
   // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract out
   // the conv result and replace `conv` with it.
   TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
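
The net effect of the patch on a matched convolution can be sketched as follows. This is an illustration, not part of the change: it assumes the helpers defined in this file (CreateCudnnConv, GetDefaultBackendConfig) and the pre-existing tuple-extraction tail of RunOnInstruction shown in the last hunk; the HLO shapes in the comment are placeholders.

// Rough sketch of what RunOnInstruction leaves behind for a matched
// forward convolution (shapes illustrative):
//
//   before:  %conv = f32[...] convolution(%input, %filter)
//   after:   %cc   = (f32[...], u8[0]) custom-call(%input, %filter),
//                        custom_call_target="__cudnn$convForward"
//            %gte  = f32[...] get-tuple-element(%cc), index=0
//
// In terms of the code above:
HloInstruction* custom_call = CreateCudnnConv(
    kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0),
    conv->mutable_operand(1), conv->window(),
    conv->convolution_dimension_numbers(), conv->feature_group_count());
TF_RETURN_IF_ERROR(
    custom_call->set_backend_config(GetDefaultBackendConfig()));
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
    conv,
    HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));

GetDefaultBackendConfig() sets conv_result_scale to 1, so by default the custom call computes a plain, unscaled convolution; carrying the scale in the backend config presumably lets a later pass change it without altering the shape of the HLO graph.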