aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-09-24 12:07:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 12:12:18 -0700
commitf361fb8e4b4a9838e60a11ab45391c308bcb90da (patch)
treecd9d7bce33362a5e417f1df59ff6c55ef92eaaee /tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
parent28eeda839f124cf5ba648576e86214b38141e4ab (diff)
Further simplify the cuDNN wrappers. Instead of passing around
CudnnConvParams, just pass around the HloInstruction. This is based on the observation that most code doesn't care about the convolution semantics like which operand is input vs filter vs output. In fact, only layout assignment and conv runner care about them. PiperOrigin-RevId: 214307399
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc92
1 files changed, 78 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 3310ee848e..32d67084b3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -37,6 +39,20 @@ using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;
+struct CudnnConvParams {
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+};
+
// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time Allocate() is called.
class ScratchBufAllocator : public se::ScratchAllocator {
@@ -214,32 +230,80 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
return Status::OK();
}
-} // anonymous namespace
+// Returns the cudnn convolution parameters generated from conv, which must be a
+// custom-call to a cudnn convolution.
+StatusOr<CudnnConvParams> GetCudnnConvParams(
+ const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer) {
+ CudnnConvParams params;
-string CudnnConvKindToString(CudnnConvKind kind) {
- switch (kind) {
- case CudnnConvKind::kForward:
- return "forward";
- case CudnnConvKind::kBackwardFilter:
- return "backward_filter";
- case CudnnConvKind::kBackwardInput:
- return "backward_input";
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ const auto& target = conv->custom_call_target();
+ const auto& lhs_shape = conv->operand(0)->shape();
+ const auto& rhs_shape = conv->operand(1)->shape();
+ const auto& conv_result_shape = conv->shape().tuple_shapes(0);
+
+ params.window = &conv->window();
+ params.dnums = &conv->convolution_dimension_numbers();
+ params.feature_group_count = conv->feature_group_count();
+ params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params.kind = CudnnConvKind::kForward;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params.kind = CudnnConvKind::kBackwardInput;
+ params.input_shape = &conv_result_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &lhs_shape;
+ params.input_buf = result_buffer;
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = operand_buffers[0];
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params.kind = CudnnConvKind::kBackwardFilter;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &conv_result_shape;
+ params.output_shape = &rhs_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = result_buffer;
+ params.output_buf = operand_buffers[1];
+ } else {
+ return InternalError("Unexpected custom call target: %s", target);
}
+ return params;
}
-Status RunCudnnConvolution(CudnnConvParams params,
+} // anonymous namespace
+
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(params, &scratch_allocator, stream,
- profile_result);
+ return RunCudnnConvolution(conv, operand_buffers, result_buffer,
+ &scratch_allocator, stream, profile_result);
}
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = params.output_shape->element_type();
+ TF_ASSIGN_OR_RETURN(CudnnConvParams params,
+ GetCudnnConvParams(conv, operand_buffers, result_buffer));
+
+ PrimitiveType output_primitive_type =
+ conv->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
case F16:
return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,