aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h')
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index bc5d1ce94a..8b7749628a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
// memory while timing the various convolution algorithms. If it's null,
// we'll use the default allocator on the StreamExecutor.
CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec,
- DeviceMemoryAllocator* allocator)
- : stream_exec_(stream_exec), allocator_(allocator) {}
+ DeviceMemoryAllocator* allocator,
+ Compiler* compiler)
+ : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
tensorflow::StringPiece name() const override {
return "cudnn-convolution-algorithm-picker";
@@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
- tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
+ StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
+ Compiler* compiler_;
};
} // namespace gpu