aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/pooling_ops_common.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/pooling_ops_common.cc')
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.cc13
1 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index 3fe16c66b8..37747a3199 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#if GOOGLE_CUDA
@@ -127,8 +128,7 @@ namespace functor {
typename TTypes<T, 4>::Tensor out); \
extern template struct TransformDepth<GPUDevice, T, Eigen::DenseIndex>;
-DECLARE_GPU_SPEC(float);
-DECLARE_GPU_SPEC(Eigen::half);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC)
#undef DECLARE_GPU_SPEC
} // namespace functor
@@ -373,10 +373,11 @@ void DnnPoolingGradOp<T>::Compute(
}
}
-template class DnnPoolingOp<Eigen::half>;
-template class DnnPoolingOp<float>;
-template class DnnPoolingGradOp<Eigen::half>;
-template class DnnPoolingGradOp<float>;
+#define DEFINE_DNN_OPS(T) \
+ template class DnnPoolingOp<T>; \
+ template class DnnPoolingGradOp<T>;
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_DNN_OPS)
+#undef DEFINE_DNN_OPS
#endif // GOOGLE_CUDA