author    Tim Shen <timshen@google.com>  2018-09-24 17:44:58 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 17:48:32 -0700
commit    391cdd80952e9cc546d82a8bf2fe7dd04f46cb2f (patch)
tree      a253b5bd01d2088d07e1cae505028a600d834384
parent    9ab01c6732dae1143e22713375a9cc7758216787 (diff)
Add cuDNN fused convolution forward support.
The tests are in the next patch.

PiperOrigin-RevId: 214362688
 tensorflow/compiler/xla/service/gpu/BUILD                                 |  6
 tensorflow/compiler/xla/service/gpu/backend_configs.proto                 | 14
 tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc |  5
 tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc         | 10
 tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc           | 96
 tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc              | 15
 tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc                  | 10
 tensorflow/compiler/xla/service/gpu/ir_emission_utils.h                   |  9
 tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc               |  7
 tensorflow/compiler/xla/service/gpu/pad_insertion.cc                      |  6
 tensorflow/stream_executor/dnn.h                                          |  4
 11 files changed, 169 insertions(+), 13 deletions(-)
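For orientation before the per-file diffs: the semantics the new kForwardActivation kind targets (spelled out in the comments added to CudnnConvParams below) can be written as a plain elementwise reference computation. The sketch here is outside of XLA and purely illustrative; ReferenceFusedConvOutput, the NHWC flattening, and the hard-coded relu are assumptions of the sketch, not part of this patch, which instead delegates the fused op to cuDNN via StreamExecutor.

#include <algorithm>
#include <cstddef>
#include <vector>

// Applies the fused epilogue element-wise, assuming conv(x, w) has already
// been computed and all tensors are flattened in NHWC order (so the feature
// index is the fastest-varying one). relu stands in for the requested
// activation.
std::vector<float> ReferenceFusedConvOutput(
    const std::vector<float>& conv,        // raw conv(x, w) results
    const std::vector<float>& side_input,  // same shape as the output
    const std::vector<float>& bias,        // length == # of output features
    double conv_result_scale, double side_input_scale) {
  std::vector<float> out(conv.size());
  const std::size_t features = bias.size();
  for (std::size_t i = 0; i < conv.size(); ++i) {
    const double v = conv_result_scale * conv[i] +
                     side_input_scale * side_input[i] + bias[i % features];
    out[i] = std::max(static_cast<float>(v), 0.0f);  // relu
  }
  return out;
}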
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 7231fd844e..2775527e0c 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -433,6 +433,7 @@ cc_library(
srcs = ["cudnn_convolution_rewriter.cc"],
hdrs = ["cudnn_convolution_rewriter.h"],
deps = [
+ ":backend_configs",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
@@ -597,14 +598,11 @@ cc_library(
hdrs = ["pad_for_tensor_cores.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo_creation_utils",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/compiler/xla/service:shape_inference",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
index 640c6392b8..78e14d860e 100644
--- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto
+++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
@@ -24,4 +24,18 @@ message CudnnConvBackendConfig {
// true, cudnn may choose not to use tensor cores, e.g. because the GPU or
// selected algorithm doesn't support it.
bool tensor_ops_enabled = 2;
+
+ // The scaling factor multiplied with the convolution result.
+ double conv_result_scale = 4;
+
+ // Below are the fields related to cuDNN's fused convolution. Refer to
+ // CudnnConvParams for their meanings.
+
+ // The requested activation (e.g. relu) after the convolution. Its type is
+ // stream_executor::dnn::ActivationMode.
+ int64 activation_mode = 3;
+
+ // The scaling factor multiplied with the side input. If no side input buffer
+ // is provided, this field must be 0.
+ double side_input_scale = 5;
}
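As a hedged illustration of how these three fields work together -- the rewriter in this patch only sets conv_result_scale to 1, and a later fusion pass would be expected to fill in the rest -- the generated proto setters could be driven as below. MakeFusedConvBackendConfig is an invented name, and the snippet assumes the XLA GPU namespace, where int64 and the se:: alias are available.

CudnnConvBackendConfig MakeFusedConvBackendConfig(double side_input_scale) {
  CudnnConvBackendConfig config;
  // Use the convolution result as-is.
  config.set_conv_result_scale(1.0);
  // Request a relu epilogue.
  config.set_activation_mode(
      static_cast<int64>(se::dnn::ActivationMode::kRelu));
  // Per the field comment above, this must be 0 when the custom call carries
  // no side input operand.
  config.set_side_input_scale(side_input_scale);
  return config;
}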
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 391456576f..7125673887 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -89,6 +89,7 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
succ = stream_exec->GetConvolveAlgorithms(true, &algorithms);
break;
}
@@ -363,8 +364,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
HloInstruction* new_call = computation->AddInstruction(
- instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0),
- instr->mutable_operand(1)}));
+ instr->CloneWithNewOperands(new_call_shape, instr->operands()));
+
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 2834d47412..ef29237301 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -476,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) {
return std::make_tuple(true, new_window, dnums, rhs);
}
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+ CudnnConvBackendConfig config;
+ config.set_conv_result_scale(1);
+ return config;
+}
+
// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -515,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
return false;
}
+ TF_RETURN_IF_ERROR(
+ custom_call->set_backend_config(GetDefaultBackendConfig()));
+
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
// the conv result and replace `conv` with it.
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 32d67084b3..89dd1bb272 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -40,6 +40,25 @@ using se::dnn::FilterLayout;
using se::dnn::ProfileResult;
struct CudnnConvParams {
+ // Here are the fields related to cuDNN's fused convolution. The result thus
+ // is defined as:
+ // activation(conv_result_scale * conv(x, w) +
+ // side_input_scale * side_input + broadcast(bias))
+ //
+ // The most common fused conv is conv forward + relu/identity, for example.
+ //
+ // bias_buf is a single-dimensional array, with the length equal to the number
+ // of output features. It'll be broadcasted to the output shape in order to be
+ // added to the final results.
+ //
+ // side_input_buf, if valid, must have the same shape as the output buffer.
+ struct FusionParams {
+ se::dnn::ActivationMode mode;
+ double side_input_scale;
+ se::DeviceMemoryBase bias_buf;
+ se::DeviceMemoryBase side_input_buf; // nullable
+ };
+
CudnnConvKind kind;
const Shape* input_shape;
const Shape* filter_shape;
@@ -51,6 +70,9 @@ struct CudnnConvParams {
const ConvolutionDimensionNumbers* dnums;
int64 feature_group_count;
se::dnn::AlgorithmConfig algorithm;
+ double conv_result_scale;
+
+ absl::optional<FusionParams> fusion;
};
// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
@@ -202,23 +224,73 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
switch (kind) {
case CudnnConvKind::kForward:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveWithAlgorithm(
input_descriptor, input_buf, filter_descriptor, filter_buf,
convolution_descriptor, output_descriptor, &output_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardInput:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardDataWithAlgorithm(
filter_descriptor, filter_buf, output_descriptor, output_buf,
convolution_descriptor, input_descriptor, &input_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardFilter:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardFilterWithAlgorithm(
input_descriptor, input_buf, output_descriptor, output_buf,
convolution_descriptor, filter_descriptor, &filter_buf,
scratch_allocator, algorithm, profile_result);
break;
+ case CudnnConvKind::kForwardActivation: {
+ BatchDescriptor bias_desc;
+ bias_desc.set_count(1)
+ .set_height(1)
+ .set_width(1)
+ .set_feature_map_count(
+ output_shape.dimensions(dnums.output_feature_dimension()))
+ .set_layout(output_dl);
+
+ se::DeviceMemory<T> side_input(params.fusion->side_input_buf);
+ // If there is no side input, use output as the side input.
+ if (side_input.is_null()) {
+ if (params.fusion->side_input_scale != 0) {
+ return InternalError(
+ "Side input scale is not 0, yet no side input buffer is "
+ "provided");
+ }
+ // Since side-input scale is 0, the values in the side input don't
+ // matter. The simplest thing to do would be to pass in a null buffer
+ // for the side input, but cudnn doesn't allow this. cudnn does promise
+ // that if side-input-scale is 0 the side input won't be read, so we
+ // just pass in the output buffer, since it's handy and has the correct
+ // size.
+ side_input = output_buf;
+ }
+
+ stream->ThenFusedConvolveWithAlgorithm(
+ input_descriptor, input_buf, params.conv_result_scale,
+ filter_descriptor, filter_buf, convolution_descriptor, side_input,
+ params.fusion->side_input_scale, bias_desc,
+ DeviceMemory<T>(params.fusion->bias_buf), params.fusion->mode,
+ output_descriptor, &output_buf, scratch_allocator, algorithm,
+ profile_result);
+ break;
+ }
}
if (!stream->ok()) {
@@ -250,6 +322,7 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
params.feature_group_count = conv->feature_group_count();
params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+ params.conv_result_scale = backend_config.conv_result_scale();
if (target == kCudnnConvForwardCallTarget) {
params.kind = CudnnConvKind::kForward;
@@ -275,6 +348,29 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
params.input_buf = operand_buffers[0];
params.filter_buf = result_buffer;
params.output_buf = operand_buffers[1];
+ } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ params.kind = CudnnConvKind::kForwardActivation;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.fusion.emplace();
+ auto& fusion = *params.fusion;
+ if (backend_config.activation_mode() <
+ static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+ fusion.mode = static_cast<se::dnn::ActivationMode>(
+ backend_config.activation_mode());
+ } else {
+ return InternalError("Bad activation mode: %s",
+ backend_config.ShortDebugString());
+ }
+ fusion.side_input_scale = backend_config.side_input_scale();
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ params.fusion->bias_buf = operand_buffers[2];
+ if (operand_buffers.size() >= 4) {
+ params.fusion->side_input_buf = operand_buffers[3];
+ }
} else {
return InternalError("Unexpected custom call target: %s", target);
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 06314e413e..74352f26aa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -104,6 +104,7 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
switch (kind) {
case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
input_shape = &lhs_shape;
filter_shape = &rhs_shape;
output_shape = &result_shape;
@@ -153,6 +154,20 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1));
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(result_shape.layout(), *call_result_buf));
+ // instr->operand(2), if it exists, is the bias buffer. There is no need to
+ // assign a layout to it, as it has only one dimension.
+
+ // instr->operand(3), if it exists, is the side input buffer.
+ if (instr->operand_count() == 4) {
+ if (kind != CudnnConvKind::kForwardActivation) {
+ return InternalError(
+ "Invalid convolution. Conv has a side input, but kind is not fused "
+ "conv forward: %s",
+ instr->ToString());
+ }
+ // The side input layout must match the output layout.
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3));
+ }
return Status::OK();
}
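For readers tracing the operand indices, this is the convention the layout constraints above encode; the named constants are purely illustrative and do not appear in the patch, which indexes the operands directly.

constexpr int kFusedConvInputOperand = 0;      // layout-constrained like lhs
constexpr int kFusedConvFilterOperand = 1;     // layout-constrained like rhs
constexpr int kFusedConvBiasOperand = 2;       // rank-1, no layout constraint
constexpr int kFusedConvSideInputOperand = 3;  // optional; must match the
                                               // output layout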
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 76757faf60..ec3d8f9405 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -129,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget =
"__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
"__cudnn$convBackwardFilter";
+const char* const kCudnnConvBiasActivationForwardCallTarget =
+ "__cudnn$convBiasActivationForward";
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
@@ -137,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
const auto& target = hlo.custom_call_target();
return target == kCudnnConvForwardCallTarget ||
target == kCudnnConvBackwardInputCallTarget ||
- target == kCudnnConvBackwardFilterCallTarget;
+ target == kCudnnConvBackwardFilterCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget;
}
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
@@ -247,6 +250,9 @@ StatusOr<CudnnConvKind> GetCudnnConvKind(
if (target == kCudnnConvBackwardFilterCallTarget) {
return CudnnConvKind::kBackwardFilter;
}
+ if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ return CudnnConvKind::kForwardActivation;
+ }
return InternalError("Unexpected call target: %s", target);
}
@@ -258,6 +264,8 @@ string CudnnConvKindToString(CudnnConvKind kind) {
return "backward_filter";
case CudnnConvKind::kBackwardInput:
return "backward_input";
+ case CudnnConvKind::kForwardActivation:
+ return "forward with activation";
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 744346abf3..a64a616ab1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -44,9 +44,11 @@ namespace gpu {
// "connectivity" (i.e. which elements of the input affect which elements of
// the output) are concerned.
enum class CudnnConvKind {
- kForward, // input + filter => output
- kBackwardInput, // filter + output => input
- kBackwardFilter, // input + output => filter
+ kForward, // input + filter => output
+ kBackwardInput, // filter + output => input
+ kBackwardFilter, // input + output => filter
+ kForwardActivation, // activation(conv(input, filter) + broadcast(bias) +
+ // (optionally) side_input) => output
};
StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);
@@ -119,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;
+extern const char* const kCudnnConvBiasActivationForwardCallTarget;
// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
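A small hypothetical convenience wrapper, shown only to make the new custom call target concrete; it mirrors IsCustomCallToDnnConvolution from ir_emission_utils.cc above and is not part of the patch.

bool IsCudnnFusedConvForward(const HloInstruction& hlo) {
  // Equivalent to checking for kForwardActivation via GetCudnnConvKind.
  return hlo.opcode() == HloOpcode::kCustomCall &&
         hlo.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget;
}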
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index b0061fa655..2d270f630b 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -209,7 +210,11 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
std::vector<HloInstruction*> convs;
for (HloInstruction* instr : comp->instructions()) {
if (IsCustomCallToDnnConvolution(*instr) &&
- instr->operand(0)->shape().element_type() == F16) {
+ instr->operand(0)->shape().element_type() == F16 &&
+ // TODO(timshen): Disable for fused conv for now. Implement it if it's
+ // needed.
+ Cast<HloCustomCallInstruction>(instr)->custom_call_target() !=
+ kCudnnConvBiasActivationForwardCallTarget) {
convs.push_back(instr);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index eead408f10..7e77dc9ac6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -162,8 +162,12 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
// The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
// out the shape of conv_result.
VLOG(1) << "Canonicalizing forward conv";
+ std::vector<HloInstruction*> operands(conv->operands().begin(),
+ conv->operands().end());
+ operands[0] = new_input;
+ operands[1] = new_kernel;
auto new_conv = conv->parent()->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), {new_input, new_kernel}));
+ conv->CloneWithNewOperands(conv->shape(), operands));
new_conv->set_window(new_conv_window);
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 9abfa1db6a..621b155240 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -873,7 +873,7 @@ class NormalizeDescriptor {
// Describes a kind of non-linearity (threshold-like mathematical function).
enum class ActivationMode {
- kNone,
+ kNone = 0,
kSigmoid,
// Rectified linear activation: f(x) = x < 0 ? 0 : x
kRelu,
@@ -885,6 +885,8 @@ enum class ActivationMode {
kTanh,
// Like ReluX, but passes all values in the range [-X,X].
kBandPass,
+
+ kNumActivationModes, // Always in the end.
};
// Returns a string representation of the given activation mode.
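The kNumActivationModes sentinel added at the end of the enum lets activation modes that arrive as raw integers (as they do via CudnnConvBackendConfig.activation_mode) be range-checked, which cudnn_convolution_runner.cc does above. Below is a minimal sketch of such a check, assuming the stream_executor namespace is in scope; IsValidActivationMode is a hypothetical name.

// Returns true if `mode` denotes a valid ActivationMode enumerator.
bool IsValidActivationMode(int64 mode) {
  return mode >= 0 &&
         mode < static_cast<int64>(dnn::ActivationMode::kNumActivationModes);
}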