aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cudnn_pooling_gpu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cudnn_pooling_gpu.cc')
-rw-r--r--tensorflow/core/kernels/cudnn_pooling_gpu.cc8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc
index 66f9249234..5939ecdf62 100644
--- a/tensorflow/core/kernels/cudnn_pooling_gpu.cc
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <array>
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
@@ -242,8 +243,11 @@ void DnnPooling3dGradOp<T>::Compute(
}
}
-template class DnnPooling3dOp<float>;
-template class DnnPooling3dGradOp<float>;
+#define DEFINE_DNN_OPS(T) \
+ template class DnnPooling3dOp<T>; \
+ template class DnnPooling3dGradOp<T>;
+TF_CALL_float(DEFINE_DNN_OPS) TF_CALL_half(DEFINE_DNN_OPS)
+#undef DEFINE_DNN_OPS
#endif // GOOGLE_CUDA