diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-09-18 12:40:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 12:48:41 -0700 |
commit | a1ffaf3620801af2a7559b0ee393f962fb6ed7ae (patch) | |
tree | 215572c619acb0b8c5b55db85799110c792f4104 /tensorflow/stream_executor | |
parent | 0c8a8289da120ee353c4fba5decb0bea9014e0a7 (diff) |
[SE] Restore int8x4 data types if that's the requested DataLayout for fused conv
This broke in a recent refactoring.
PiperOrigin-RevId: 213497416
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_dnn.cc | 38 |
1 file changed, 25 insertions(+), 13 deletions(-)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 63ab367086..3a77ba769c 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -132,30 +132,39 @@ string ToString(cudnnStatus_t status) {
 }
 
 template <typename T>
-cudnnDataType_t GetCudnnDataType();
+cudnnDataType_t GetCudnnDataType(
+    dnn::DataLayout = dnn::DataLayout::kBatchDepthYX);
 
 template <>
-cudnnDataType_t GetCudnnDataType<double>() {
+cudnnDataType_t GetCudnnDataType<double>(dnn::DataLayout) {
   return CUDNN_DATA_DOUBLE;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<float>() {
+cudnnDataType_t GetCudnnDataType<float>(dnn::DataLayout) {
   return CUDNN_DATA_FLOAT;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<Eigen::half>() {
+cudnnDataType_t GetCudnnDataType<Eigen::half>(dnn::DataLayout) {
   return CUDNN_DATA_HALF;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<int8>() {
-  return CUDNN_DATA_INT8;
+cudnnDataType_t GetCudnnDataType<int8>(dnn::DataLayout layout) {
+  switch (layout) {
+    case dnn::DataLayout::kYXDepthBatch:
+    case dnn::DataLayout::kYXBatchDepth:
+    case dnn::DataLayout::kBatchYXDepth:
+    case dnn::DataLayout::kBatchDepthYX:
+      return CUDNN_DATA_INT8;
+    case dnn::DataLayout::kBatchDepthYX4:
+      return CUDNN_DATA_INT8x4;
+  }
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<int32>() {
+cudnnDataType_t GetCudnnDataType<int32>(dnn::DataLayout) {
   return CUDNN_DATA_INT32;
 }
 
@@ -2518,12 +2527,15 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
         "Relu or None activation.");
   }
 
-  CudnnTensorDescriptor conv_input_nd(conv_input_descriptor,
-                                      GetCudnnDataType<ElementType>());
-  CudnnTensorDescriptor output_nd(output_descriptor,
-                                  GetCudnnDataType<ElementType>());
-  CudnnFilterDescriptor filter(filter_descriptor,
-                               GetCudnnDataType<ElementType>());
+  CudnnTensorDescriptor conv_input_nd(
+      conv_input_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+  CudnnTensorDescriptor output_nd(
+      output_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+  CudnnFilterDescriptor filter(
+      filter_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
   CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
   CudnnConvolutionDescriptor conv(convolution_descriptor,
                                   GetCudnnDataType<AccumulatorType>());