aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/convolution_thunk.h')
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h55
1 files changed, 21 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 68d67c40c5..d7d1f91fba 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -32,7 +33,7 @@ limitations under the License.
namespace xla {
namespace gpu {
-// This class stores everything that StreamExecutor needs to launch a BNN
+// This class stores everything that StreamExecutor needs to launch a DNN
// convolution. It is generated by IrEmitter.
//
// This is thread-compatible.
@@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that
- // we should use the default (i.e. baseline) cudnn algorithm.
- //
// Note that "output" here doesn't refer to the output from running this
// thunk, but rather to the "output" of a hypothetical forward convolution
// that corresponds to this input+filter+output triple. That is, the result
// generated by this thunk is "output" for forward convs, "input" for
// backward-input convs, and "filter" for backward-filter convs.
- //
- // Semantics of null hlo_instruction argument are as in Thunk.
- ConvolutionThunk(CudnnConvKind convolution_kind,
- const BufferAllocation::Slice& input_buffer,
- const BufferAllocation::Slice& filter_buffer,
- const BufferAllocation::Slice& output_buffer,
- const BufferAllocation::Slice& tuple_result_buffer,
- const BufferAllocation::Slice& scratch_buffer,
- const Shape& input_shape, const Shape& filter_shape,
- const Shape& output_shape, const Window& window,
- const ConvolutionDimensionNumbers& dim_nums,
- int64 feature_group_count, int64 algorithm,
- bool tensor_ops_enabled, const HloInstruction* hlo);
+ ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
+ BufferAllocation::Slice input_slice,
+ BufferAllocation::Slice filter_slice,
+ BufferAllocation::Slice output_slice,
+ BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ input_buffer_(std::move(input_slice)),
+ filter_buffer_(std::move(filter_slice)),
+ output_buffer_(std::move(output_slice)),
+ scratch_buffer_(std::move(scratch_slice)),
+ tuple_result_buffer_(std::move(tuple_result_slice)) {}
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk {
HloExecutionProfiler* profiler) override;
private:
- const CudnnConvKind convolution_kind_;
-
- const BufferAllocation::Slice input_buffer_;
- const BufferAllocation::Slice filter_buffer_;
- const BufferAllocation::Slice output_buffer_;
- const BufferAllocation::Slice tuple_result_buffer_;
- const BufferAllocation::Slice scratch_buffer_;
-
- const Shape input_shape_;
- const Shape filter_shape_;
- const Shape output_shape_;
-
- const Window window_;
- const ConvolutionDimensionNumbers dim_nums_;
- int64 feature_group_count_;
- int64 algorithm_;
- bool tensor_ops_enabled_;
+ const HloCustomCallInstruction* cudnn_call_;
+ BufferAllocation::Slice input_buffer_;
+ BufferAllocation::Slice filter_buffer_;
+ BufferAllocation::Slice output_buffer_;
+ BufferAllocation::Slice scratch_buffer_;
+ BufferAllocation::Slice tuple_result_buffer_;
};
} // namespace gpu