diff options
author | avijit-nervana <avijit.chakraborty@intel.com> | 2018-09-14 09:21:08 -0700 |
---|---|---|
committer | avijit-nervana <avijit.chakraborty@intel.com> | 2018-09-14 09:21:08 -0700 |
commit | 41aaed7751690b0b3137dad2620656a698b3ceae (patch) | |
tree | 00fc1a7f6be0c3968f3e674a65ca4907110ddf2d /tensorflow/compiler/xla/service/gpu/convolution_thunk.h | |
parent | c26c5e1217944448f1f4c2b97626fc4d7d6406d3 (diff) | |
parent | 95338704198205c1bdec1e344e103f1daf05df68 (diff) |
Merge branch 'master' into avijit/add-cpu-backend
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/convolution_thunk.h')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 55 |
1 files changed, 21 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 68d67c40c5..d7d1f91fba 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -32,7 +33,7 @@ limitations under the License. namespace xla { namespace gpu { -// This class stores everything that StreamExecutor needs to launch a BNN +// This class stores everything that StreamExecutor needs to launch a DNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. @@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that - // we should use the default (i.e. baseline) cudnn algorithm. - // // Note that "output" here doesn't refer to the output from running this // thunk, but rather to the "output" of a hypothetical forward convolution // that corresponds to this input+filter+output triple. That is, the result // generated by this thunk is "output" for forward convs, "input" for // backward-input convs, and "filter" for backward-filter convs. - // - // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(CudnnConvKind convolution_kind, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, - const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, - int64 feature_group_count, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo); + ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + BufferAllocation::Slice input_slice, + BufferAllocation::Slice filter_slice, + BufferAllocation::Slice output_slice, + BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + input_buffer_(std::move(input_slice)), + filter_buffer_(std::move(filter_slice)), + output_buffer_(std::move(output_slice)), + scratch_buffer_(std::move(scratch_slice)), + tuple_result_buffer_(std::move(tuple_result_slice)) {} ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - const CudnnConvKind convolution_kind_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice filter_buffer_; - const BufferAllocation::Slice output_buffer_; - const BufferAllocation::Slice tuple_result_buffer_; - const BufferAllocation::Slice scratch_buffer_; - - const Shape input_shape_; - const Shape filter_shape_; - const Shape output_shape_; - - const Window window_; - const ConvolutionDimensionNumbers dim_nums_; - int64 feature_group_count_; - int64 algorithm_; - bool tensor_ops_enabled_; + const HloCustomCallInstruction* cudnn_call_; + BufferAllocation::Slice input_buffer_; + BufferAllocation::Slice filter_buffer_; + BufferAllocation::Slice output_buffer_; + BufferAllocation::Slice scratch_buffer_; + BufferAllocation::Slice tuple_result_buffer_; }; } // namespace gpu |