author     Tim Shen <timshen@google.com>                       2018-09-24 12:07:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-09-24 12:12:18 -0700
commit     f361fb8e4b4a9838e60a11ab45391c308bcb90da (patch)
tree       cd9d7bce33362a5e417f1df59ff6c55ef92eaaee
parent     28eeda839f124cf5ba648576e86214b38141e4ab (diff)
Further simplify the cuDNN wrappers. Instead of passing around CudnnConvParams,
just pass around the HloInstruction. This is based on the observation that most
code doesn't care about convolution semantics such as which operand is the
input vs. the filter vs. the output; in fact, only layout assignment and the
conv runner care about them.

PiperOrigin-RevId: 214307399
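
To illustrate the shape of the new API, here is a minimal sketch of the new call pattern, mirroring ConvolutionThunk::ExecuteOnStream in the diff below. It assumes the surrounding ConvolutionThunk members (cudnn_call_, operand_buffers_, result_buffer_, scratch_buffer_, buffer_allocations) and is illustrative only, not a standalone compilable unit:

    // Gather raw device buffers for the operands and the result.
    std::vector<se::DeviceMemoryBase> operand_se_buffers;
    for (const auto& buffer : operand_buffers_) {
      operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
    }
    se::DeviceMemoryBase result_buffer =
        buffer_allocations.GetDeviceAddress(result_buffer_);
    se::DeviceMemoryBase scratch =
        buffer_allocations.GetDeviceAddress(scratch_buffer_);

    // Old API (removed): RunCudnnConvolution(params, scratch, stream), where
    // `params` was a CudnnConvParams populated via PopulateCudnnConvParams.
    // New API: the cuDNN custom-call HloInstruction carries the semantics;
    // the runner derives CudnnConvParams internally from the instruction and
    // its backend config.
    TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_,
                                           absl::MakeSpan(operand_se_buffers),
                                           result_buffer, scratch, stream));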
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                                  |  6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc                  | 49
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 72
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h  |  2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc           | 92
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h            | 55
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc              | 81
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h               |  3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc                  | 55
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.h                   | 31

10 files changed, 219 insertions(+), 227 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index cbee4db06e..7231fd844e 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -371,7 +371,6 @@ cc_library(
hdrs = ["ir_emission_utils.h"],
deps = [
":backend_configs",
- ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -412,6 +411,8 @@ cc_library(
srcs = ["cudnn_convolution_runner.cc"],
hdrs = ["cudnn_convolution_runner.h"],
deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -420,8 +421,10 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -781,6 +784,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 85f3682a5a..4effea637d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -44,52 +44,23 @@ ConvolutionThunk::ConvolutionThunk(
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- CudnnConvParams params;
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
-
- switch (params.kind) {
- case CudnnConvKind::kForward:
- params.input_buf =
- buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
- params.filter_buf =
- buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
- params.output_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
- break;
- case CudnnConvKind::kBackwardInput:
- params.input_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
- params.filter_buf =
- buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
- params.output_buf =
- buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
- break;
- case CudnnConvKind::kBackwardFilter:
- params.input_buf =
- buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
- params.filter_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
- params.output_buf =
- buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
- break;
+ std::vector<se::DeviceMemoryBase> operand_se_buffers;
+ for (const auto& buffer : operand_buffers_) {
+ operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
}
+ se::DeviceMemoryBase result_buffer =
+ buffer_allocations.GetDeviceAddress(result_buffer_);
+
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_,
+ absl::MakeSpan(operand_se_buffers),
+ result_buffer, scratch, stream));
- // Figure out which of output/input/filter is the result produced by
- // this op, and write the result tuple.
- void* result_ptr = [&] {
- switch (params.kind) {
- case CudnnConvKind::kForward:
- return params.output_buf.opaque();
- case CudnnConvKind::kBackwardInput:
- return params.input_buf.opaque();
- case CudnnConvKind::kBackwardFilter:
- return params.filter_buf.opaque();
- }
- }();
- void* ptrs[] = {result_ptr, scratch.opaque()};
+ void* ptrs[] = {result_buffer.opaque(), scratch.opaque()};
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 9eee9ebbd7..391456576f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -146,19 +146,11 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// caching would speed up compilation a lot.
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- const HloCustomCallInstruction* instr) {
- CudnnConvParams params;
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
-
- const Shape& input_shape = *params.input_shape;
- const Shape& filter_shape = *params.filter_shape;
- const Shape& output_shape = *params.output_shape;
-
- CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
- CHECK_EQ(input_shape.element_type(), output_shape.element_type());
+ HloCustomCallInstruction* instr) {
// TODO(timshen): for now only check fp16. It can be expanded to other types,
// with some work on the HLO routines.
- const bool cross_check_enabled = input_shape.element_type() == xla::F16;
+ const bool cross_check_enabled =
+ instr->shape().tuple_shapes(0).element_type() == xla::F16;
// Don't run this function concurrently on the same GPU.
//
@@ -226,48 +218,43 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(params.input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(params.filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(params.output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- initialize_buffer(params.input_buf);
- initialize_buffer(params.filter_buf);
- initialize_buffer(params.output_buf);
-
- DeviceMemoryBase* result_buf = [&] {
- switch (params.kind) {
- case CudnnConvKind::kBackwardFilter:
- return &params.filter_buf;
- case CudnnConvKind::kBackwardInput:
- return &params.input_buf;
- case CudnnConvKind::kForward:
- return &params.output_buf;
- }
- }();
+ std::vector<se::DeviceMemoryBase> operand_buffers;
+ for (const auto* operand : instr->operands()) {
+ TF_ASSIGN_OR_RETURN(auto buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(operand->shape())));
+ initialize_buffer(buffer);
+ operand_buffers.push_back(buffer);
+ }
+ TF_ASSIGN_OR_RETURN(
+ auto result_buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
+ initialize_buffer(result_buffer);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
+ TF_ASSIGN_OR_RETURN(auto backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
optional<F16BufferComparator> comparator;
// Use the first algorithm that's supported as reference. There isn't a
// particular reason to use it, as any algorithm sufficies. It doesn't make
// this algorithm considered correct, though.
optional<AlgorithmDesc> first_algorithm;
- for (const AlgorithmDesc& alg : GetAlgorithms(params.kind, stream_exec_)) {
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
+ for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- params.algorithm = AlgorithmConfig(alg);
- bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
- &profile_result)
+ backend_config.set_algorithm(alg.algo_id());
+ backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled());
+ TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config));
+ bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers),
+ result_buffer, &scratch_allocator,
+ &stream, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
@@ -278,7 +265,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
.xla_gpu_crash_on_verification_failures();
if (comparator.has_value()) {
StatusOr<bool> result = comparator->CompareEqual(
- se::DeviceMemory<Eigen::half>(*result_buf));
+ se::DeviceMemory<Eigen::half>(result_buffer));
if (!result.ok()) {
LOG(ERROR) << "Unable to compare "
<< AlgorithmToString(*first_algorithm) << " against "
@@ -296,7 +283,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
}
} else if (cross_check_enabled) {
auto comp = F16BufferComparator::Create(
- se::DeviceMemory<Eigen::half>(*result_buf), compiler_, allocator,
+ se::DeviceMemory<Eigen::half>(result_buffer), compiler_, allocator,
&stream);
if (comp.ok()) {
comparator.emplace(comp.ConsumeValueOrDie());
@@ -370,7 +357,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
ShapeUtil::MakeShape(U8, {scratch_bytes})});
- CudnnConvBackendConfig backend_config;
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
backend_config.set_algorithm(algorithm);
backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index ce0189543c..aeda2fc7f8 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -50,7 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloModulePass {
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- const HloCustomCallInstruction* instr);
+ HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 3310ee848e..32d67084b3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -37,6 +39,20 @@ using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;
+struct CudnnConvParams {
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+};
+
// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time Allocate() is called.
class ScratchBufAllocator : public se::ScratchAllocator {
@@ -214,32 +230,80 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
return Status::OK();
}
-} // anonymous namespace
+// Returns the cudnn convolution parameters generated from conv, which must be a
+// custom-call to a cudnn convolution.
+StatusOr<CudnnConvParams> GetCudnnConvParams(
+ const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer) {
+ CudnnConvParams params;
-string CudnnConvKindToString(CudnnConvKind kind) {
- switch (kind) {
- case CudnnConvKind::kForward:
- return "forward";
- case CudnnConvKind::kBackwardFilter:
- return "backward_filter";
- case CudnnConvKind::kBackwardInput:
- return "backward_input";
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ const auto& target = conv->custom_call_target();
+ const auto& lhs_shape = conv->operand(0)->shape();
+ const auto& rhs_shape = conv->operand(1)->shape();
+ const auto& conv_result_shape = conv->shape().tuple_shapes(0);
+
+ params.window = &conv->window();
+ params.dnums = &conv->convolution_dimension_numbers();
+ params.feature_group_count = conv->feature_group_count();
+ params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params.kind = CudnnConvKind::kForward;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params.kind = CudnnConvKind::kBackwardInput;
+ params.input_shape = &conv_result_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &lhs_shape;
+ params.input_buf = result_buffer;
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = operand_buffers[0];
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params.kind = CudnnConvKind::kBackwardFilter;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &conv_result_shape;
+ params.output_shape = &rhs_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = result_buffer;
+ params.output_buf = operand_buffers[1];
+ } else {
+ return InternalError("Unexpected custom call target: %s", target);
}
+ return params;
}
-Status RunCudnnConvolution(CudnnConvParams params,
+} // anonymous namespace
+
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(params, &scratch_allocator, stream,
- profile_result);
+ return RunCudnnConvolution(conv, operand_buffers, result_buffer,
+ &scratch_allocator, stream, profile_result);
}
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = params.output_shape->element_type();
+ TF_ASSIGN_OR_RETURN(CudnnConvParams params,
+ GetCudnnConvParams(conv, operand_buffers, result_buffer));
+
+ PrimitiveType output_primitive_type =
+ conv->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
case F16:
return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index 381aa37a1b..61aec1cecc 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -27,52 +30,8 @@ namespace gpu {
// This file contains low-level routines for running cudnn convolutions.
-// Different types of convolutions supported by cudnn.
-//
-// A way to think about these is that a convolution is defined by three arrays
-// -- the "input", the "filter", and the "output" -- and given any two of these,
-// we can compute the third. For example, a backward-input convolution takes as
-// input a filter and an "output" and produces an "input" such that if one were
-// to do a forward convolution of "input" using filter, the result would be
-// something with the same shape as "output".
-//
-// This way of thinking is not correct if you look at the values produced. For
-// example, a backward-input convolution is not actually the mathematical
-// inverse of a forward convolution. But it's right as far as the shapes and
-// "connectivity" (i.e. which elements of the input affect which elements of
-// the output) are concerned.
-enum class CudnnConvKind {
- kForward, // input + filter => output
- kBackwardInput, // filter + output => input
- kBackwardFilter, // input + output => filter
-};
-
-struct CudnnConvParams {
- CudnnConvKind kind;
- const Shape* input_shape;
- const Shape* filter_shape;
- const Shape* output_shape;
- se::DeviceMemoryBase input_buf;
- se::DeviceMemoryBase filter_buf;
- se::DeviceMemoryBase output_buf;
- const Window* window;
- const ConvolutionDimensionNumbers* dnums;
- int64 feature_group_count;
- se::dnn::AlgorithmConfig algorithm;
-};
-
-// Converts a CudnnConvKind value to a string.
-string CudnnConvKindToString(CudnnConvKind kind);
-
// Calls into cudnn to run the specified convolution.
//
-// Note that depending on the value of CudnnConvKind, the result of this call
-// may be written into input_buf, filter_buf, or output_buf!
-//
-// At the moment convolution with half data type is implemented with cudnn
-// PSEUDO_HALF configuration, that is, the input values are half and the
-// internal computation type is float.
-//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
// theory the second one shouldn't be necessary -- users of this function could
@@ -83,11 +42,15 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index d033faee8d..06314e413e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -90,27 +92,32 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// operands and the output shape. Depending on the underlying algorithm, one of
// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints) {
- CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
- Shape input_shape;
- Shape filter_shape;
- Shape output_shape;
- const auto& target = instr->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->shape().tuple_shapes(0);
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_shape = instr->shape().tuple_shapes(0);
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->operand(0)->shape();
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->shape().tuple_shapes(0);
- output_shape = instr->operand(1)->shape();
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
+ Shape lhs_shape = instr->operand(0)->shape();
+ Shape rhs_shape = instr->operand(1)->shape();
+ Shape result_shape = instr->shape().tuple_shapes(0);
+
+ Shape* input_shape;
+ Shape* filter_shape;
+ Shape* output_shape;
+
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ input_shape = &lhs_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &result_shape;
+ break;
+ case CudnnConvKind::kBackwardInput:
+ input_shape = &result_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &lhs_shape;
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ input_shape = &lhs_shape;
+ filter_shape = &result_shape;
+ output_shape = &rhs_shape;
+ break;
}
{
@@ -127,8 +134,9 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
}
TF_ASSIGN_OR_RETURN(
- std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(),
- *output_shape.mutable_layout()),
+ std::tie(*input_shape->mutable_layout(),
+ *filter_shape->mutable_layout(),
+ *output_shape->mutable_layout()),
StreamExecutorConvLayoutsToXlaLayouts(
instr->convolution_dimension_numbers(), input, filter, output));
}
@@ -141,25 +149,10 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
instr, /*index=*/{0}));
// Set layouts of the instructions' shapes.
- if (target == kCudnnConvForwardCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(output_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(input_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf));
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
- }
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetBufferLayout(result_shape.layout(), *call_result_buf));
return Status::OK();
}
@@ -173,8 +166,8 @@ Status GpuLayoutAssignment::AddBackendConstraints(
++iterator) {
HloInstruction* instruction = *iterator;
if (IsCustomCallToDnnConvolution(*instruction)) {
- TF_RETURN_IF_ERROR(
- AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
+ TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
+ Cast<HloCustomCallInstruction>(instruction), constraints));
}
// For batched dot we require the default layout.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index ce24af1cf8..e2b96a81d4 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -47,7 +48,7 @@ class GpuLayoutAssignment : public LayoutAssignment {
private:
Status AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints);
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints);
se::StreamExecutor* stream_executor_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 22f43bc08b..b57ac5fd09 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -288,41 +288,30 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params) {
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
- const auto& target = custom_call->custom_call_target();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
-
- params->window = &custom_call->window();
- params->dnums = &custom_call->convolution_dimension_numbers();
- params->feature_group_count = custom_call->feature_group_count();
- params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
- backend_config.algorithm(), backend_config.tensor_ops_enabled()));
-
+StatusOr<CudnnConvKind> GetCudnnConvKind(
+ const HloCustomCallInstruction* instr) {
+ absl::string_view target = instr->custom_call_target();
if (target == kCudnnConvForwardCallTarget) {
- params->kind = CudnnConvKind::kForward;
- params->input_shape = &lhs_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &conv_result_shape;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- params->kind = CudnnConvKind::kBackwardInput;
- params->input_shape = &conv_result_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &lhs_shape;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- params->kind = CudnnConvKind::kBackwardFilter;
- params->input_shape = &lhs_shape;
- params->filter_shape = &conv_result_shape;
- params->output_shape = &rhs_shape;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
+ return CudnnConvKind::kForward;
+ }
+ if (target == kCudnnConvBackwardInputCallTarget) {
+ return CudnnConvKind::kBackwardInput;
+ }
+ if (target == kCudnnConvBackwardFilterCallTarget) {
+ return CudnnConvKind::kBackwardFilter;
+ }
+ return InternalError("Unexpected call target: %s", target);
+}
+
+string CudnnConvKindToString(CudnnConvKind kind) {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ return "forward";
+ case CudnnConvKind::kBackwardFilter:
+ return "backward_filter";
+ case CudnnConvKind::kBackwardInput:
+ return "backward_input";
}
- return Status::OK();
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 09c455cc1e..19bd3c6330 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -30,6 +29,31 @@ limitations under the License.
namespace xla {
namespace gpu {
+// Different types of convolutions supported by cudnn.
+//
+// A way to think about these is that a convolution is defined by three arrays
+// -- the "input", the "filter", and the "output" -- and given any two of these,
+// we can compute the third. For example, a backward-input convolution takes as
+// input a filter and an "output" and produces an "input" such that if one were
+// to do a forward convolution of "input" using filter, the result would be
+// something with the same shape as "output".
+//
+// This way of thinking is not correct if you look at the values produced. For
+// example, a backward-input convolution is not actually the mathematical
+// inverse of a forward convolution. But it's right as far as the shapes and
+// "connectivity" (i.e. which elements of the input affect which elements of
+// the output) are concerned.
+enum class CudnnConvKind {
+ kForward, // input + filter => output
+ kBackwardInput, // filter + output => input
+ kBackwardFilter, // input + output => filter
+};
+
+StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);
+
+// Converts a CudnnConvKind value to a string.
+string CudnnConvKindToString(CudnnConvKind kind);
+
constexpr int64 kWarpSize = 32;
// Returns true if `hlo` will be implemented as a call to BLAS gemm.
@@ -150,11 +174,6 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
-// Populates params using conv, which must be a custom-call to a cudnn
-// convolution. Does not modify any buffers in the params.
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params);
-
} // namespace gpu
} // namespace xla