-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc   12
-rw-r--r--  tensorflow/core/kernels/conv_grad_filter_ops.cc              3
-rw-r--r--  tensorflow/core/kernels/conv_grad_input_ops.cc               3
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops_3d.cc                  7
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc                          3
-rw-r--r--  tensorflow/core/kernels/conv_ops_3d.cc                       3
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h                      16
-rw-r--r--  tensorflow/core/kernels/conv_ops_test.cc                    42
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                         3
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_test.py             23
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc                 22
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h                   3
-rw-r--r--  tensorflow/stream_executor/dnn.cc                            6
-rw-r--r--  tensorflow/stream_executor/dnn.h                             6
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc         12
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h           5
16 files changed, 135 insertions(+), 34 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index f6b7fe1e8e..b2197c6a1f 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -258,15 +258,21 @@ tensorflow::Status ConvolutionThunk::Convolve(
std::vector<se::dnn::AlgorithmType> ConvolutionThunk::GetAlgorithms(
se::StreamExecutor* stream_exec) const {
std::vector<se::dnn::AlgorithmType> algorithms;
+ // TODO(yangzihao): Currently we disable the use of winograd nonfused in
+ // XLA by default. We should pass in the conv parameters and enable it when
+ // ShouldIncludeWinogradNonfusedAlgo() returns true.
switch (convolution_kind_) {
case ConvolutionKind::kBackwardFilter:
- CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(&algorithms));
+ CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
+ /*with_winograd_nonfused=*/false, &algorithms));
break;
case ConvolutionKind::kBackwardInput:
- CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(&algorithms));
+ CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
+ /*with_winograd_nonfused=*/false, &algorithms));
break;
case ConvolutionKind::kForward:
- CHECK(stream_exec->GetConvolveAlgorithms(&algorithms));
+ CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false,
+ &algorithms));
break;
}
return algorithms;
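
For context on the XLA hunk above: the TODO suggests deriving the flag from the convolution parameters instead of hard-coding false. A minimal sketch of that direction, using hypothetical stand-in types (ConvParams, AlgorithmType) rather than XLA's real ones:

    #include <vector>

    using AlgorithmType = long long;  // stand-in for se::dnn::AlgorithmType

    // Hypothetical parameter bundle; XLA would supply its own equivalent of
    // ConvParameters::ShouldIncludeWinogradNonfusedAlgo() from conv_ops_gpu.h.
    struct ConvParams {
      bool ShouldIncludeWinogradNonfusedAlgo() const { return true; }
    };

    std::vector<AlgorithmType> GetAlgorithmsSketch(const ConvParams& params) {
      // Derive the flag per call instead of passing
      // /*with_winograd_nonfused=*/false unconditionally.
      const bool with_winograd_nonfused =
          params.ShouldIncludeWinogradNonfusedAlgo();
      std::vector<AlgorithmType> algorithms;
      // A real implementation would forward with_winograd_nonfused to
      // stream_exec->GetConvolveAlgorithms() and friends, as below.
      (void)with_winograd_nonfused;
      return algorithms;
    }
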
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 20394cad43..98c2ea1362 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -776,7 +776,8 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms));
+ CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 9a50431a2f..a94b1bea4b 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -856,7 +856,8 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms));
+ CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 61b1e0fd3f..b4d0bf2cfa 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -656,7 +656,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(&algorithms));
+ CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -1020,11 +1021,11 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
using perftools::gputools::dnn::ProfileResult;
using perftools::gputools::dnn::kDefaultAlgorithm;
AlgorithmConfig algorithm_config;
-
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(&algorithms));
+ CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index b3803778c8..8c75b312ef 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -668,7 +668,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch(
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
+ CHECK(stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index dfcb2cfbe2..58f8e3b2cd 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -392,7 +392,8 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
+ CHECK(stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 8917824bfa..b268f8dbd2 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -145,6 +145,22 @@ class ConvParameters {
// clang-format on
}
+ // TODO(yangzihao): The purpose of this function is to disable the winograd
+ // nonfused conv algorithm for certain input parameters so as to avoid a bug
+ // in cuDNN v5 and cuDNN v6. Remove this once we switch to cuDNN v7.
+ template <typename T>
+ bool ShouldIncludeWinogradNonfusedAlgo() const {
+ int64 total_size = 16 * std::ceil(batch_ / 16.0) *
+ std::max(in_depths_, out_depths_) * in_[0] * in_[1] *
+ sizeof(T);
+ int64 threshold = 1LL << 31;
+ return total_size < threshold;
+ }
+
private:
typedef std::tuple<int64, int64, SpatialArray, int64, SpatialArray,
SpatialArray, SpatialArray, DataType, int>
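
To make the heuristic above concrete, here is a standalone re-computation (not TF code) of the size formula and threshold; the two parameter sets match the test in conv_ops_test.cc below:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    // total_size = 16 * ceil(batch / 16.0) * max(in_depths, out_depths)
    //              * in_rows * in_cols * sizeof(T), compared against 2^31.
    int64_t WinogradNonfusedSize(int64_t batch, int64_t in_depths,
                                 int64_t out_depths, int64_t rows,
                                 int64_t cols, int64_t elem_size) {
      return static_cast<int64_t>(16 * std::ceil(batch / 16.0)) *
             std::max(in_depths, out_depths) * rows * cols * elem_size;
    }

    int main() {
      const int64_t threshold = 1LL << 31;  // 2,147,483,648
      // Small case: 16 * 1 * 128 * 300 * 300 * 4 = 737,280,000 < 2^31.
      std::cout << (WinogradNonfusedSize(1, 32, 128, 300, 300, 4) < threshold)
                << "\n";  // prints 1: winograd nonfused stays enabled
      // Large case: 16 * 1 * 768 * 300 * 300 * 4 = 4,423,680,000 >= 2^31.
      std::cout << (WinogradNonfusedSize(1, 128, 768, 300, 300, 4) < threshold)
                << "\n";  // prints 0: winograd nonfused is excluded
    }
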
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index cd9aa4a53e..88ba433050 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -28,8 +28,50 @@ limitations under the License.
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+
namespace tensorflow {
+#if GOOGLE_CUDA
+
+TEST(ConvParameters, WinogradNonfusedAlgoSize) {
+ ConvParameters conv_params_small = {
+ 1, // batch
+ 32, // in_depths
+ {{300, // in_rows
+ 300}}, // in_cols
+ 128, // out_depths
+ {{3, // filter_rows
+ 3}}, // filter_cols
+ {{1, // stride_rows
+ 1}}, // stride_cols
+ {{0, // padding_rows
+ 0}}, // padding_cols
+ DT_FLOAT, // tensor datatype
+ 0, // device_id
+ };
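+  // With the formula above: 16 * ceil(1/16.0) * max(32, 128) * 300 * 300 *
+  // sizeof(float) = 737,280,000 bytes, which is below the 2^31 threshold.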
+ EXPECT_TRUE(conv_params_small.ShouldIncludeWinogradNonfusedAlgo<float>());
+
+ ConvParameters conv_params_large = {
+ 1, // batch
+ 128, // in_depths
+ {{300, // in_rows
+ 300}}, // in_cols
+ 768, // out_depths
+ {{3, // filter_rows
+ 3}}, // filter_cols
+ {{1, // stride_rows
+ 1}}, // stride_cols
+ {{0, // padding_rows
+ 0}}, // padding_cols
+ DT_FLOAT, // tensor datatype
+ 0, // device_id
+ };
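+  // Here 16 * 1 * max(128, 768) * 300 * 300 * 4 = 4,423,680,000 bytes,
+  // which exceeds 2^31, so winograd nonfused should be excluded.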
+ EXPECT_FALSE(conv_params_large.ShouldIncludeWinogradNonfusedAlgo<float>());
+}
+
+#endif // GOOGLE_CUDA
+
class FusedResizePadConvOpTest : public OpsTestBase {
protected:
void HandwrittenConv() {
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 419823263c..7dff3d5189 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2071,7 +2071,7 @@ cuda_py_test(
cuda_py_test(
name = "conv_ops_test",
- size = "medium",
+ size = "large",
srcs = ["conv_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2089,6 +2089,7 @@ cuda_py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
],
+ shard_count = 4,
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index db0adfc794..b9a853e5d9 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -189,7 +189,7 @@ class Conv2DTest(test.TestCase):
# numbers from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
- with self.test_session(use_gpu=use_gpu) as sess:
+ with self.test_session(use_gpu=use_gpu):
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
strides = [1] + strides + [1]
@@ -378,7 +378,7 @@ class Conv2DTest(test.TestCase):
expected=[50, 60])
# TODO this currently fails.
- #self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
+ # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
# filter_in_sizes=[2, 2, 1, 1],
# strides=[4, 4], padding="SAME",
# expected=[72, 112, 392, 432])
@@ -424,7 +424,7 @@ class Conv2DTest(test.TestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(data_format, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
+ with self.test_session(use_gpu=use_gpu):
if data_format == "NCHW":
new_input_sizes = test_util.NHWCToNCHW(input_sizes)
else:
@@ -580,7 +580,7 @@ class Conv2DTest(test.TestCase):
x2 = np.random.rand(*output_sizes).astype(np.float32)
def _GetVal(data_format, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
+ with self.test_session(use_gpu=use_gpu):
t0 = constant_op.constant(x0, shape=input_sizes)
t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
t2 = constant_op.constant(x2, shape=output_sizes)
@@ -1444,4 +1444,19 @@ if __name__ == "__main__":
GetInceptionBackFilterTest(input_size_, filter_size_, output_size_,
[stride_, stride_], padding_))
+ # TODO(b/35359731)
+ # Fwd, BckInput, and BackFilter tests that verify that, for certain input
+ # parameter sets, the winograd nonfused algorithm is excluded from conv
+ # autotune. If the winograd nonfused algorithm were instead offered to conv
+ # autotune and the cuDNN version were smaller than 7, the following tests
+ # would fail. (For these shapes, total_size = 16 * 768 * 400 * 400 * 4
+ # bytes = 7,864,320,000 >= 2^31, so the algorithm is excluded.)
+ ishape = [1, 400, 400, 128]
+ fshape = [3, 3, 128, 768]
+ oshape = [1, 400, 400, 768]
+ setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused",
+ GetInceptionFwdTest(ishape, fshape, 1, "SAME"))
+ setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused",
+ GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME"))
+ setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused",
+ GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME"))
test.main()
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 8a0fbcd8a8..e1674745c8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1966,12 +1966,12 @@ bool CudnnSupport::DoConvolveImpl(
}
// A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms.
-// Doing so by default make a few TensorFlow test cases to fail. Users can
-// explicitly enable them through an env-var "TF_ENABLE_WINOGRAD_NONFUSED=1".
+// By default it is turned on; users can explicitly disable it through the
+// env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
// https://github.com/tensorflow/tensorflow/pull/4901
-// TODO(yangzihao): for certain shapes, setting default flag to be true will
-// cause bug and return negative tensor shapes. Will flip the default flag when
-// the bug is fixed.
+// TODO(yangzihao): The winograd_nonfused bug will only be fixed in cuDNN v7.
+// For earlier cuDNN versions, we have added code to avoid using winograd
+// nonfused for certain input parameter sets.
template <bool DefaultFlag>
class WinogradNonfused {
public:
@@ -1997,6 +1997,7 @@ class WinogradNonfused {
};
bool CudnnSupport::GetConvolveAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) {
out_algorithms->assign({
// clang-format off
@@ -2012,7 +2013,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
// clang-format on
});
#if CUDNN_VERSION >= 5100
- if (WinogradNonfused<false>::IsEnabled()) {
+ if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
}
#endif
@@ -2020,6 +2021,7 @@ bool CudnnSupport::GetConvolveAlgorithms(
}
bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) {
out_algorithms->assign({
// clang-format off
@@ -2033,7 +2035,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
// clang-format on
});
#if CUDNN_VERSION >= 5100
- if (WinogradNonfused<false>::IsEnabled()) {
+ if (WinogradNonfused<true>::IsEnabled() && with_winograd_nonfused) {
out_algorithms->push_back(
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
}
@@ -2042,6 +2044,7 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
}
bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) {
out_algorithms->assign({
// clang-format off
@@ -2053,11 +2056,12 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
});
#if CUDNN_VERSION >= 5100
#if CUDNN_VERSION >= 5110
- static constexpr bool kDefaultFlagWinogradNonfused = false;
+ static constexpr bool kDefaultFlagWinogradNonfused = true;
#else
static constexpr bool kDefaultFlagWinogradNonfused = false;
#endif
- if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled()) {
+ if (WinogradNonfused<kDefaultFlagWinogradNonfused>::IsEnabled() &&
+ with_winograd_nonfused) {
out_algorithms->push_back(
// Based on cudnn.h, the following is not implemented.
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
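
The WinogradNonfused helper class is only shown in outline by these hunks. A minimal standalone sketch of such an env-var gate, assuming a plain getenv read (the actual TF implementation may differ in details):

    #include <cstdlib>
    #include <string>

    // DefaultFlag is used unless TF_ENABLE_WINOGRAD_NONFUSED is set;
    // an explicit "0" disables, and any other value enables.
    template <bool DefaultFlag>
    bool WinogradNonfusedEnabled() {
      const char* value = std::getenv("TF_ENABLE_WINOGRAD_NONFUSED");
      if (value == nullptr) return DefaultFlag;
      return std::string(value) != "0";
    }
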
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index b280b73c70..2c8ed9a335 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -104,12 +104,15 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* workspace_allocator) override;
bool GetConvolveAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) override;
bool GetConvolveBackwardDataAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) override;
bool GetConvolveBackwardFilterAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType>* out_algorithms) override;
bool DoBatchNormalizationForward(
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index e834119bf8..ee78066f95 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -23,17 +23,17 @@ namespace gputools {
namespace dnn {
bool DnnSupport::GetConvolveAlgorithms(
- std::vector<AlgorithmType>* out_algorithms) {
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
return false;
}
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
- std::vector<AlgorithmType>* out_algorithms) {
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
return false;
}
bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
- std::vector<AlgorithmType>* out_algorithms) {
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms) {
return false;
}
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index c5805064f3..8e56933ba3 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -952,7 +952,7 @@ class DnnSupport {
// Return a list of algorithms supported by the forward convolution pass.
virtual bool GetConvolveAlgorithms(
- std::vector<AlgorithmType>* out_algorithms);
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
// Enqueues a double-precision convolution operation onto the stream.
// See DoConvolve above for argument details.
@@ -1056,7 +1056,7 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// data.
virtual bool GetConvolveBackwardDataAlgorithms(
- std::vector<AlgorithmType>* out_algorithms);
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor,
@@ -1104,7 +1104,7 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// filters.
virtual bool GetConvolveBackwardFilterAlgorithms(
- std::vector<AlgorithmType>* out_algorithms);
+ bool with_winograd_nonfused, std::vector<AlgorithmType>* out_algorithms);
virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index fe5da12639..b3eefe0299 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -285,30 +285,36 @@ bool StreamExecutor::SupportsDnn() const {
}
bool StreamExecutor::GetConvolveAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveAlgorithms(out_algorithms);
+ return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused,
+ out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveBackwardDataAlgorithms(out_algorithms);
+ return dnn_support->GetConvolveBackwardDataAlgorithms(with_winograd_nonfused,
+ out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveBackwardFilterAlgorithms(out_algorithms);
+ return dnn_support->GetConvolveBackwardFilterAlgorithms(
+ with_winograd_nonfused, out_algorithms);
}
bool StreamExecutor::GetBlasGemmAlgorithms(
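
The StreamExecutor changes are pure plumbing: the new flag is threaded through to the DNN backend, and the call still fails gracefully when no DNN support is present. A compilable sketch of that delegation pattern with stub types (not the real StreamExecutor headers):

    #include <vector>

    using AlgorithmType = long long;  // stand-in for dnn::AlgorithmType

    struct DnnSupportStub {
      bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                                 std::vector<AlgorithmType>* out) {
        out->assign({0, 1, 2});  // placeholder algorithm ids
        if (with_winograd_nonfused) out->push_back(3);
        return true;
      }
    };

    struct StreamExecutorSketch {
      DnnSupportStub* dnn = nullptr;  // null when the platform lacks DNN support
      bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                                 std::vector<AlgorithmType>* out) {
        if (dnn == nullptr) return false;  // mirrors the !dnn_support early-out
        return dnn->GetConvolveAlgorithms(with_winograd_nonfused, out);
      }
    };
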
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 5c52afa794..3dbeddd5d4 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -342,15 +342,18 @@ class StreamExecutor {
bool SupportsDnn() const;
// Get the list of supported algorithms for the forward convolution operation.
- bool GetConvolveAlgorithms(std::vector<dnn::AlgorithmType> *out_algorithms);
+ bool GetConvolveAlgorithms(bool with_winograd_nonfused,
+ std::vector<dnn::AlgorithmType> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on data.
bool GetConvolveBackwardDataAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on the
// filter.
bool GetConvolveBackwardFilterAlgorithms(
+ bool with_winograd_nonfused,
std::vector<dnn::AlgorithmType> *out_algorithms);
// Get the list of supported algorithms for BLAS gemm.