about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
author: Benoit Steiner <benoit.steiner.goog@gmail.com> 2016-05-04 13:49:36 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-05-04 14:54:21 -0700
commit: 094060ddad37291e3f09d1f9329c87041ead7436 (patch)
tree: 0e832659857a678cbe538561ff43f13b50e3522a /tensorflow/core/kernels
parent: 9d424acdde35b8ab2c067cd8ca58193740e46972 (diff)
Added support for half floats to the sparse cross entropy operation
Change: 121522621
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- tensorflow/core/kernels/sparse_xent_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/sparse_xent_op.h | 2
-rw-r--r-- tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc | 8
3 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 53b47aae2c..6e27c30a70 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -96,10 +96,14 @@ REGISTER(CPU, float, int32)
REGISTER(CPU, float, int64)
REGISTER(CPU, double, int32)
REGISTER(CPU, double, int64)
+REGISTER(CPU, Eigen::half, int32)
+REGISTER(CPU, Eigen::half, int64)
#if GOOGLE_CUDA
REGISTER(GPU, float, int32)
REGISTER(GPU, float, int64)
+REGISTER(GPU, Eigen::half, int32)
+REGISTER(GPU, Eigen::half, int64)
#endif // GOOGLE_CUDA
#undef REGISTER
diff --git a/tensorflow/core/kernels/sparse_xent_op.h b/tensorflow/core/kernels/sparse_xent_op.h
index 5fc81c6db2..c83160a7cd 100644
--- a/tensorflow/core/kernels/sparse_xent_op.h
+++ b/tensorflow/core/kernels/sparse_xent_op.h
@@ -64,7 +64,7 @@ class SparseXentLossGenerator {
int batch = coords[0];
int depth = coords[1];
return (labels_(batch) == depth)
- ? (std::log(sum_exp_logits_(batch)) - logits_(coords))
+ ? (Eigen::numext::log(sum_exp_logits_(batch)) - logits_(coords))
: T(0.0);
};
diff --git a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
index d2cba1761d..6d093689c7 100644
--- a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc
@@ -42,9 +42,11 @@ struct SparseXentFunctor<GPUDevice, T, Index> {
} // end namespace functor
// Instantiate the GPU implementation for float.
-#define REGISTER(Index) \
- template struct functor::SparseXentFunctor<GPUDevice, float, Index>; \
- template class generator::SparseXentGradGenerator<float, Index>;
+#define REGISTER(Index) \
+ template struct functor::SparseXentFunctor<GPUDevice, float, Index>; \
+ template class generator::SparseXentGradGenerator<float, Index>; \
+ template struct functor::SparseXentFunctor<GPUDevice, Eigen::half, Index>; \
+ template class generator::SparseXentGradGenerator<Eigen::half, Index>;
REGISTER(int32)
REGISTER(int64)
#undef REGISTER