about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Tim Shen <timshen@google.com> 2018-09-10 18:05:03 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-10 18:08:04 -0700
commit: de683c50d039676e36b6a718e4cc7ed2170a8a2f (patch)
tree: 0da940da1d16d0c14f648b208e233f45bb21dde1
parent: c300a579be9c4adb3736f3551b35826f3f27b0f8 (diff)
Simplify convolution_thunk's interface.
PiperOrigin-RevId: 212370999
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD | 3
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_thunk.cc | 54
-rw-r--r-- tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 55
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 38
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.h | 7
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 61
6 files changed, 96 insertions, 122 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index af953a2a16..aab8d0fdca 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -174,6 +174,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:while_loop_analysis",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
@@ -371,6 +372,8 @@ cc_library(
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
deps = [
+ ":backend_configs",
+ ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 9b567cf4a8..3a23ac1d63 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
@@ -30,63 +31,32 @@ namespace gpu {
using se::dnn::AlgorithmDesc;
-ConvolutionThunk::ConvolutionThunk(
- CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
- const Shape& filter_shape, const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count,
- int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo)
- : Thunk(Kind::kConvolution, hlo),
- convolution_kind_(convolution_kind),
- input_buffer_(input_buffer),
- filter_buffer_(filter_buffer),
- output_buffer_(output_buffer),
- tuple_result_buffer_(tuple_result_buffer),
- scratch_buffer_(scratch_buffer),
- input_shape_(input_shape),
- filter_shape_(filter_shape),
- output_shape_(output_shape),
- window_(window),
- dim_nums_(dim_nums),
- feature_group_count_(feature_group_count),
- algorithm_(algorithm),
- tensor_ops_enabled_(tensor_ops_enabled) {}
-
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- se::DeviceMemoryBase input_data =
- buffer_allocations.GetDeviceAddress(input_buffer_);
- se::DeviceMemoryBase filter_data =
- buffer_allocations.GetDeviceAddress(filter_buffer_);
- se::DeviceMemoryBase output_data =
- buffer_allocations.GetDeviceAddress(output_buffer_);
+ CudnnConvParams params;
+
+ params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
+ params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
+ params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- se::dnn::AlgorithmConfig algorithm_config(
- se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+ TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(
- {convolution_kind_, &input_shape_, &filter_shape_, &output_shape_,
- input_data, filter_data, output_data, &window_, &dim_nums_,
- feature_group_count_, algorithm_config},
- scratch, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
// Figure out which of output/input/filter is the result produced by
// this op, and write the result tuple.
void* result_ptr = [&] {
- switch (convolution_kind_) {
+ switch (params.kind) {
case CudnnConvKind::kForward:
- return output_data.opaque();
+ return params.output_buf.opaque();
case CudnnConvKind::kBackwardInput:
- return input_data.opaque();
+ return params.input_buf.opaque();
case CudnnConvKind::kBackwardFilter:
- return filter_data.opaque();
+ return params.filter_buf.opaque();
}
}();
void* ptrs[] = {result_ptr, scratch.opaque()};
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 68d67c40c5..d7d1f91fba 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -32,7 +33,7 @@ limitations under the License.
namespace xla {
namespace gpu {
-// This class stores everything that StreamExecutor needs to launch a BNN
+// This class stores everything that StreamExecutor needs to launch a DNN
// convolution. It is generated by IrEmitter.
//
// This is thread-compatible.
@@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that
- // we should use the default (i.e. baseline) cudnn algorithm.
- //
// Note that "output" here doesn't refer to the output from running this
// thunk, but rather to the "output" of a hypothetical forward convolution
// that corresponds to this input+filter+output triple. That is, the result
// generated by this thunk is "output" for forward convs, "input" for
// backward-input convs, and "filter" for backward-filter convs.
- //
- // Semantics of null hlo_instruction argument are as in Thunk.
- ConvolutionThunk(CudnnConvKind convolution_kind,
- const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer,
- const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums,
- int64 feature_group_count, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo);
+ ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
+ BufferAllocation::Slice input_slice,
+ BufferAllocation::Slice filter_slice,
+ BufferAllocation::Slice output_slice,
+ BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ input_buffer_(std::move(input_slice)),
+ filter_buffer_(std::move(filter_slice)),
+ output_buffer_(std::move(output_slice)),
+ scratch_buffer_(std::move(scratch_slice)),
+ tuple_result_buffer_(std::move(tuple_result_slice)) {}
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- const CudnnConvKind convolution_kind_;
-
- const BufferAllocation::Slice input_buffer_;
- const BufferAllocation::Slice filter_buffer_;
- const BufferAllocation::Slice output_buffer_;
- const BufferAllocation::Slice tuple_result_buffer_;
- const BufferAllocation::Slice scratch_buffer_;
-
- const Shape input_shape_;
- const Shape filter_shape_;
- const Shape output_shape_;
-
- const Window window_;
- const ConvolutionDimensionNumbers dim_nums_;
- int64 feature_group_count_;
- int64 algorithm_;
- bool tensor_ops_enabled_;
+ const HloCustomCallInstruction* cudnn_call_;
+ BufferAllocation::Slice input_buffer_;
+ BufferAllocation::Slice filter_buffer_;
+ BufferAllocation::Slice output_buffer_;
+ BufferAllocation::Slice scratch_buffer_;
+ BufferAllocation::Slice tuple_result_buffer_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 20d523abe0..22f43bc08b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params) {
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ custom_call->backend_config<CudnnConvBackendConfig>());
+ const auto& target = custom_call->custom_call_target();
+ const auto& lhs_shape = custom_call->operand(0)->shape();
+ const auto& rhs_shape = custom_call->operand(1)->shape();
+ const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
+
+ params->window = &custom_call->window();
+ params->dnums = &custom_call->convolution_dimension_numbers();
+ params->feature_group_count = custom_call->feature_group_count();
+ params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params->kind = CudnnConvKind::kForward;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &conv_result_shape;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params->kind = CudnnConvKind::kBackwardInput;
+ params->input_shape = &conv_result_shape;
+ params->filter_shape = &rhs_shape;
+ params->output_shape = &lhs_shape;
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params->kind = CudnnConvKind::kBackwardFilter;
+ params->input_shape = &lhs_shape;
+ params->filter_shape = &conv_result_shape;
+ params->output_shape = &rhs_shape;
+ } else {
+ LOG(FATAL) << "Unexpected custom call target: "
+ << custom_call->custom_call_target();
+ }
+ return Status::OK();
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 59c65fc268..09c455cc1e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,9 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".
@@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
+// Populates params using custom_call, which must be a custom-call to a cudnn
+// convolution. Does not modify any buffers in the params.
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+ CudnnConvParams* params);
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f91cc00d71..b669881026 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
const auto& target = custom_call->custom_call_target();
- std::unique_ptr<ConvolutionThunk> thunk;
+ BufferAllocation::Slice input_slice, filter_slice, output_slice;
+
if (target == kCudnnConvForwardCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kForward,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/conv_result_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/conv_result_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = rhs_slice;
+ output_slice = conv_result_slice;
} else if (target == kCudnnConvBackwardInputCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardInput,
- /*input_buffer=*/conv_result_slice,
- /*filter_buffer=*/rhs_slice,
- /*output_buffer=*/lhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/conv_result_shape,
- /*filter_shape=*/rhs_shape,
- /*output_shape=*/lhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = conv_result_slice;
+ filter_slice = rhs_slice;
+ output_slice = lhs_slice;
} else if (target == kCudnnConvBackwardFilterCallTarget) {
- thunk = absl::make_unique<ConvolutionThunk>(
- CudnnConvKind::kBackwardFilter,
- /*input_buffer=*/lhs_slice,
- /*filter_buffer=*/conv_result_slice,
- /*output_buffer=*/rhs_slice,
- /*tuple_result_buffer=*/tuple_result_slice,
- /*scratch_buffer=*/scratch_slice,
- /*input_shape=*/lhs_shape,
- /*filter_shape=*/conv_result_shape,
- /*output_shape=*/rhs_shape, //
- custom_call->window(), custom_call->convolution_dimension_numbers(),
- custom_call->feature_group_count(), backend_config.algorithm(),
- backend_config.tensor_ops_enabled(), custom_call);
+ input_slice = lhs_slice;
+ filter_slice = conv_result_slice;
+ output_slice = rhs_slice;
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
}
- thunk_sequence_->emplace_back(std::move(thunk));
+ thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
+ Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
+ output_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}