aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-04 13:49:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-04 14:54:21 -0700
commit094060ddad37291e3f09d1f9329c87041ead7436 (patch)
tree0e832659857a678cbe538561ff43f13b50e3522a /tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
parent9d424acdde35b8ab2c067cd8ca58193740e46972 (diff)
Added support for half floats to the sparse cross entropy operation
Change: 121522621
Diffstat (limited to 'tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
index d2cba1761d..6d093689c7 100644
--- a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
@@ -42,9 +42,11 @@ struct SparseXentFunctor<GPUDevice, T, Index> {
} // end namespace functor
// Instantiate the GPU implementation for float.
-#define REGISTER(Index) \
- template struct functor::SparseXentFunctor<GPUDevice, float, Index>; \
- template class generator::SparseXentGradGenerator<float, Index>;
+#define REGISTER(Index) \
+ template struct functor::SparseXentFunctor<GPUDevice, float, Index>; \
+ template class generator::SparseXentGradGenerator<float, Index>; \
+ template struct functor::SparseXentFunctor<GPUDevice, Eigen::half, Index>; \
+ template class generator::SparseXentGradGenerator<Eigen::half, Index>;
REGISTER(int32)
REGISTER(int64)
#undef REGISTER