diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 0cb01161b0..f79b113f8f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr<bool> RunOnComputation(HloComputation* computation); StatusOr<bool> RunOnInstruction(HloInstruction* instr); StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr); + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null |