From a1ffaf3620801af2a7559b0ee393f962fb6ed7ae Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 18 Sep 2018 12:40:49 -0700 Subject: [SE] Restore int8x4 data types if that's the requested DataLayout for fused conv This broke in a recent refactoring. PiperOrigin-RevId: 213497416 --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 38 +++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) (limited to 'tensorflow/stream_executor') diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 63ab367086..3a77ba769c 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -132,30 +132,39 @@ string ToString(cudnnStatus_t status) { } template -cudnnDataType_t GetCudnnDataType(); +cudnnDataType_t GetCudnnDataType( + dnn::DataLayout = dnn::DataLayout::kBatchDepthYX); template <> -cudnnDataType_t GetCudnnDataType() { +cudnnDataType_t GetCudnnDataType(dnn::DataLayout) { return CUDNN_DATA_DOUBLE; } template <> -cudnnDataType_t GetCudnnDataType() { +cudnnDataType_t GetCudnnDataType(dnn::DataLayout) { return CUDNN_DATA_FLOAT; } template <> -cudnnDataType_t GetCudnnDataType() { +cudnnDataType_t GetCudnnDataType(dnn::DataLayout) { return CUDNN_DATA_HALF; } template <> -cudnnDataType_t GetCudnnDataType() { - return CUDNN_DATA_INT8; +cudnnDataType_t GetCudnnDataType(dnn::DataLayout layout) { + switch (layout) { + case dnn::DataLayout::kYXDepthBatch: + case dnn::DataLayout::kYXBatchDepth: + case dnn::DataLayout::kBatchYXDepth: + case dnn::DataLayout::kBatchDepthYX: + return CUDNN_DATA_INT8; + case dnn::DataLayout::kBatchDepthYX4: + return CUDNN_DATA_INT8x4; + } } template <> -cudnnDataType_t GetCudnnDataType() { +cudnnDataType_t GetCudnnDataType(dnn::DataLayout) { return CUDNN_DATA_INT32; } @@ -2518,12 +2527,15 @@ port::Status CudnnSupport::DoFusedConvolveImpl( "Relu or None activation."); } - CudnnTensorDescriptor conv_input_nd(conv_input_descriptor, - GetCudnnDataType()); - CudnnTensorDescriptor output_nd(output_descriptor, - GetCudnnDataType()); - CudnnFilterDescriptor filter(filter_descriptor, - GetCudnnDataType()); + CudnnTensorDescriptor conv_input_nd( + conv_input_descriptor, + GetCudnnDataType(conv_input_descriptor.layout())); + CudnnTensorDescriptor output_nd( + output_descriptor, + GetCudnnDataType(conv_input_descriptor.layout())); + CudnnFilterDescriptor filter( + filter_descriptor, + GetCudnnDataType(conv_input_descriptor.layout())); CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType()); CudnnConvolutionDescriptor conv(convolution_descriptor, GetCudnnDataType()); -- cgit v1.2.3