Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc  58
1 file changed, 47 insertions(+), 11 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 3d1266355b..ef29237301 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -35,6 +36,32 @@ namespace gpu {
namespace {
+HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape,
+ HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
+ HloComputation* computation = lhs->parent();
+
+ // This call returns a tuple of (conv_result, scratch_memory), where
+ // conv_result is the actual result of the convolution, and scratch_memory is
+ // temporary memory used by cudnn.
+ //
+ // At the moment, we don't know how much scratch memory this conv is going to
+ // use, so we put u8[0] in this place. Later on another pass will choose
+ // which conv algorithm to use, and at that point we'll modify the shape of
+ // this second tuple element.
+ Shape call_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+ HloInstruction* custom_call = computation->AddInstruction(
+ HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
+ custom_call->set_window(window);
+ custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
+ return custom_call;
+}
+
bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
const ConvolutionDimensionNumbers& dnums =
conv->convolution_dimension_numbers();
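
The comment in CreateCudnnConv above notes that the custom call returns a (conv_result, scratch_memory) tuple, with u8[0] as a placeholder scratch shape. As a minimal sketch of how such a tuple-shaped call is typically consumed (not part of this change; `computation` and `custom_call` are assumed to be the variables from CreateCudnnConv), the convolution result is pulled out with a get-tuple-element instruction:

  // Sketch only: extract tuple element 0 (the conv result) from the
  // tuple-shaped custom call built by CreateCudnnConv.
  HloInstruction* conv_result = computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetTupleElementShape(custom_call->shape(), 0),
          custom_call, /*index=*/0));
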
@@ -263,7 +290,7 @@ MatchBackwardInput(HloInstruction* conv) {
!(window_util::HasBaseDilation(conv->window()) &&
(reverse_filter->IsConstant() || is_1x1_filter))) {
VLOG(1) << "Can't match to backwards convolution. Either filter is not "
- "kReverse, or it's not a base-dialted conv with a 1x1 or "
+ "kReverse, or it's not a base-dilated conv with a 1x1 or "
"constant filter.";
return no_match_result;
}
@@ -450,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) {
return std::make_tuple(true, new_window, dnums, rhs);
}
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+ CudnnConvBackendConfig config;
+ config.set_conv_result_scale(1);
+ return config;
+}
+
// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
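
GetDefaultBackendConfig above attaches a CudnnConvBackendConfig proto with conv_result_scale = 1 to each rewritten custom call (see the set_backend_config call later in this diff). A hedged sketch, not part of this diff, of how a later pass (the one the scratch-memory comment says will pick the conv algorithm) might read that config back and re-attach it, assuming HloInstruction's templated backend_config<T>() accessor:

  // Sketch only: read back the config attached by this pass, then modify and
  // re-attach it in a later pass.
  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
                      custom_call->backend_config<CudnnConvBackendConfig>());
  CHECK_EQ(config.conv_result_scale(), 1);  // default set by GetDefaultBackendConfig
  TF_RETURN_IF_ERROR(custom_call->set_backend_config(config));
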
@@ -462,24 +495,24 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
- return CreateCudnnConvBackwardFilter(
- conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
+ conv->mutable_operand(0), conv->mutable_operand(1),
+ window, dnums, conv->feature_group_count());
}
std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- return CreateCudnnConvBackwardInput(conv->shape(),
- conv->mutable_operand(0), rhs, window,
- dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
+ conv->mutable_operand(0), rhs, window, dnums,
+ conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
- return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
- conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers(),
- conv->feature_group_count());
+ return CreateCudnnConv(
+ kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0),
+ conv->mutable_operand(1), conv->window(),
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
}
return nullptr;
@@ -489,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
return false;
}
+ TF_RETURN_IF_ERROR(
+ custom_call->set_backend_config(GetDefaultBackendConfig()));
+
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
// the conv result and replace `conv` with it.
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(