diff options
Diffstat (limited to 'tensorflow/core/kernels/pooling_ops_common.cc')
-rw-r--r-- | tensorflow/core/kernels/pooling_ops_common.cc | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 3fe16c66b8..37747a3199 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -17,6 +17,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #if GOOGLE_CUDA @@ -127,8 +128,7 @@ namespace functor { typename TTypes<T, 4>::Tensor out); \ extern template struct TransformDepth<GPUDevice, T, Eigen::DenseIndex>; -DECLARE_GPU_SPEC(float); -DECLARE_GPU_SPEC(Eigen::half); +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC) #undef DECLARE_GPU_SPEC } // namespace functor @@ -373,10 +373,11 @@ void DnnPoolingGradOp<T>::Compute( } } -template class DnnPoolingOp<Eigen::half>; -template class DnnPoolingOp<float>; -template class DnnPoolingGradOp<Eigen::half>; -template class DnnPoolingGradOp<float>; +#define DEFINE_DNN_OPS(T) \ + template class DnnPoolingOp<T>; \ + template class DnnPoolingGradOp<T>; +TF_CALL_GPU_NUMBER_TYPES(DEFINE_DNN_OPS) +#undef DEFINE_DNN_OPS #endif // GOOGLE_CUDA |