author    Benjamin Kramer <kramerb@google.com>    2018-09-18 12:40:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-18 12:48:41 -0700
commit    a1ffaf3620801af2a7559b0ee393f962fb6ed7ae (patch)
tree      215572c619acb0b8c5b55db85799110c792f4104 /tensorflow/stream_executor
parent    0c8a8289da120ee353c4fba5decb0bea9014e0a7 (diff)
[SE] Restore int8x4 data types if that's the requested DataLayout for fused conv
This broke in a recent refactoring.

PiperOrigin-RevId: 213497416
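In short, GetCudnnDataType<int8>() previously returned CUDNN_DATA_INT8 unconditionally; the patch threads the tensor's dnn::DataLayout through so that the vectorized kBatchDepthYX4 layout (cuDNN's NCHW_VECT_C) once again maps to CUDNN_DATA_INT8x4. A condensed, standalone sketch of that mapping follows; the function name PickInt8CudnnType, the includes, and the namespace alias are illustrative scaffolding, only the layout-to-type mapping itself comes from the patch:

#include <cudnn.h>                           // cudnnDataType_t, CUDNN_DATA_INT8*
#include "tensorflow/stream_executor/dnn.h"  // dnn::DataLayout

namespace se = stream_executor;

// Chooses the cuDNN element type for an int8 tensor based on its layout.
cudnnDataType_t PickInt8CudnnType(se::dnn::DataLayout layout) {
  if (layout == se::dnn::DataLayout::kBatchDepthYX4) {
    // Vectorized layout (cuDNN NCHW_VECT_C): four int8 values packed per element.
    return CUDNN_DATA_INT8x4;
  }
  // All other (unvectorized) layouts keep plain int8.
  return CUDNN_DATA_INT8;
}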
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_dnn.cc | 38
1 file changed, 25 insertions(+), 13 deletions(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 63ab367086..3a77ba769c 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -132,30 +132,39 @@ string ToString(cudnnStatus_t status) {
 }
 
 template <typename T>
-cudnnDataType_t GetCudnnDataType();
+cudnnDataType_t GetCudnnDataType(
+    dnn::DataLayout = dnn::DataLayout::kBatchDepthYX);
 
 template <>
-cudnnDataType_t GetCudnnDataType<double>() {
+cudnnDataType_t GetCudnnDataType<double>(dnn::DataLayout) {
   return CUDNN_DATA_DOUBLE;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<float>() {
+cudnnDataType_t GetCudnnDataType<float>(dnn::DataLayout) {
   return CUDNN_DATA_FLOAT;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<Eigen::half>() {
+cudnnDataType_t GetCudnnDataType<Eigen::half>(dnn::DataLayout) {
   return CUDNN_DATA_HALF;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<int8>() {
-  return CUDNN_DATA_INT8;
+cudnnDataType_t GetCudnnDataType<int8>(dnn::DataLayout layout) {
+  switch (layout) {
+    case dnn::DataLayout::kYXDepthBatch:
+    case dnn::DataLayout::kYXBatchDepth:
+    case dnn::DataLayout::kBatchYXDepth:
+    case dnn::DataLayout::kBatchDepthYX:
+      return CUDNN_DATA_INT8;
+    case dnn::DataLayout::kBatchDepthYX4:
+      return CUDNN_DATA_INT8x4;
+  }
+}
 
 template <>
-cudnnDataType_t GetCudnnDataType<int32>() {
+cudnnDataType_t GetCudnnDataType<int32>(dnn::DataLayout) {
   return CUDNN_DATA_INT32;
 }
 
@@ -2518,12 +2527,15 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
                         "Relu or None activation.");
   }
 
-  CudnnTensorDescriptor conv_input_nd(conv_input_descriptor,
-                                      GetCudnnDataType<ElementType>());
-  CudnnTensorDescriptor output_nd(output_descriptor,
-                                  GetCudnnDataType<ElementType>());
-  CudnnFilterDescriptor filter(filter_descriptor,
-                               GetCudnnDataType<ElementType>());
+  CudnnTensorDescriptor conv_input_nd(
+      conv_input_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+  CudnnTensorDescriptor output_nd(
+      output_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+  CudnnFilterDescriptor filter(
+      filter_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
   CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
   CudnnConvolutionDescriptor conv(convolution_descriptor,
                                   GetCudnnDataType<AccumulatorType>());
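Note that all three descriptors rebuilt in this hunk (conv_input_nd, output_nd, and filter) take their cuDNN element type from conv_input_descriptor.layout(): an int8 fused convolution whose input uses kBatchDepthYX4 therefore gets CUDNN_DATA_INT8x4 input, output, and filter descriptors, while bias_nd and conv still call GetCudnnDataType without a layout argument and fall back to the new kBatchDepthYX default.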