aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc38
1 files changed, 25 insertions, 13 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 63ab367086..3a77ba769c 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -132,30 +132,39 @@ string ToString(cudnnStatus_t status) {
}
template <typename T>
-cudnnDataType_t GetCudnnDataType();
+cudnnDataType_t GetCudnnDataType(
+ dnn::DataLayout = dnn::DataLayout::kBatchDepthYX);
template <>
-cudnnDataType_t GetCudnnDataType<double>() {
+cudnnDataType_t GetCudnnDataType<double>(dnn::DataLayout) {
return CUDNN_DATA_DOUBLE;
}
template <>
-cudnnDataType_t GetCudnnDataType<float>() {
+cudnnDataType_t GetCudnnDataType<float>(dnn::DataLayout) {
return CUDNN_DATA_FLOAT;
}
template <>
-cudnnDataType_t GetCudnnDataType<Eigen::half>() {
+cudnnDataType_t GetCudnnDataType<Eigen::half>(dnn::DataLayout) {
return CUDNN_DATA_HALF;
}
template <>
-cudnnDataType_t GetCudnnDataType<int8>() {
- return CUDNN_DATA_INT8;
+cudnnDataType_t GetCudnnDataType<int8>(dnn::DataLayout layout) {
+ switch (layout) {
+ case dnn::DataLayout::kYXDepthBatch:
+ case dnn::DataLayout::kYXBatchDepth:
+ case dnn::DataLayout::kBatchYXDepth:
+ case dnn::DataLayout::kBatchDepthYX:
+ return CUDNN_DATA_INT8;
+ case dnn::DataLayout::kBatchDepthYX4:
+ return CUDNN_DATA_INT8x4;
+ }
}
template <>
-cudnnDataType_t GetCudnnDataType<int32>() {
+cudnnDataType_t GetCudnnDataType<int32>(dnn::DataLayout) {
return CUDNN_DATA_INT32;
}
@@ -2518,12 +2527,15 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
"Relu or None activation.");
}
- CudnnTensorDescriptor conv_input_nd(conv_input_descriptor,
- GetCudnnDataType<ElementType>());
- CudnnTensorDescriptor output_nd(output_descriptor,
- GetCudnnDataType<ElementType>());
- CudnnFilterDescriptor filter(filter_descriptor,
- GetCudnnDataType<ElementType>());
+ CudnnTensorDescriptor conv_input_nd(
+ conv_input_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnTensorDescriptor output_nd(
+ output_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnFilterDescriptor filter(
+ filter_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
CudnnConvolutionDescriptor conv(convolution_descriptor,
GetCudnnDataType<AccumulatorType>());