diff options
Diffstat (limited to 'tensorflow/core/kernels/cudnn_pooling_gpu.cc')
-rw-r--r-- | tensorflow/core/kernels/cudnn_pooling_gpu.cc | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc index 66f9249234..5939ecdf62 100644 --- a/tensorflow/core/kernels/cudnn_pooling_gpu.cc +++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc @@ -18,6 +18,7 @@ limitations under the License. #include <array> +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/conv_3d.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" @@ -242,8 +243,11 @@ void DnnPooling3dGradOp<T>::Compute( } } -template class DnnPooling3dOp<float>; -template class DnnPooling3dGradOp<float>; +#define DEFINE_DNN_OPS(T) \ + template class DnnPooling3dOp<T>; \ + template class DnnPooling3dGradOp<T>; +TF_CALL_float(DEFINE_DNN_OPS) TF_CALL_half(DEFINE_DNN_OPS) +#undef DEFINE_DNN_OPS #endif // GOOGLE_CUDA |