diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index bc5d1ce94a..8b7749628a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) - : stream_exec_(stream_exec), allocator_(allocator) {} + DeviceMemoryAllocator* allocator, + Compiler* compiler) + : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} tensorflow::StringPiece name() const override { return "cudnn-convolution-algorithm-picker"; @@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { private: StatusOr<bool> RunOnComputation(HloComputation* computation); StatusOr<bool> RunOnInstruction(HloInstruction* instr); - tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm( + StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null + Compiler* compiler_; }; } // namespace gpu |