diff options
author | 2016-05-04 13:49:36 -0800 | |
---|---|---|
committer | 2016-05-04 14:54:21 -0700 | |
commit | 094060ddad37291e3f09d1f9329c87041ead7436 (patch) | |
tree | 0e832659857a678cbe538561ff43f13b50e3522a /tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc | |
parent | 9d424acdde35b8ab2c067cd8ca58193740e46972 (diff) |
Added support for half floats to the sparse cross entropy operation
Change: 121522621
Diffstat (limited to 'tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc index d2cba1761d..6d093689c7 100644 --- a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc @@ -42,9 +42,11 @@ struct SparseXentFunctor<GPUDevice, T, Index> { } // end namespace functor // Instantiate the GPU implementation for float. -#define REGISTER(Index) \ - template struct functor::SparseXentFunctor<GPUDevice, float, Index>; \ - template class generator::SparseXentGradGenerator<float, Index>; +#define REGISTER(Index) \ + template struct functor::SparseXentFunctor<GPUDevice, float, Index>; \ + template class generator::SparseXentGradGenerator<float, Index>; \ + template struct functor::SparseXentFunctor<GPUDevice, Eigen::half, Index>; \ + template class generator::SparseXentGradGenerator<Eigen::half, Index>; REGISTER(int32) REGISTER(int64) #undef REGISTER |