aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-04-07 11:42:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-07 11:45:04 -0700
commit1cd76c209ce6f74298843568a7fc397c2e6f958f (patch)
treec4647ef54eaba837b9a5a1a05b0cf029aaec7b36 /tensorflow/stream_executor/dnn.h
parente7ea87f97e03360719d132a71acc1eb2f93c249f (diff)
[XLA:GPU] Eliminate the guard around Winograd non-fused convolutions with cudnn7.
Adds DnnSupport::GetVersion() and uses this to unguard Winograd non-fused convolutions if you're using cudnn7. PiperOrigin-RevId: 192010450
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 43cfd313c1..3c47d2c2e8 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <functional>
#include <limits>
#include <memory>
+#include <tuple>
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
@@ -885,6 +886,12 @@ class DnnSupport {
virtual port::Status Init() = 0;
+ // Gets the version of the backing library, as a {major, minor, patch} tuple.
+ virtual port::StatusOr<std::tuple<int, int, int>> GetVersion() {
+ return port::UnimplementedError(
+ "DnnSupport::GetVersion not implemented on this platform.");
+ }
+
// Performs a single-precision forward batch normalization operation onto
// the stream.
//