diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op_impl_bfloat.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_bfloat.cc | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
index bfa7ba0d47..96aae15608 100644
--- a/tensorflow/core/kernels/cast_op_impl_bfloat.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
@@ -22,20 +22,19 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype) {
   CURRY_TYPES3(CAST_CASE, CPUDevice, bfloat16);
   return nullptr;
 }
 
 #if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype) {
   if (dst_dtype == DT_FLOAT) {
-    return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+    return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out,
+              bool truncate) {
       functor::CastFunctor<GPUDevice, float, bfloat16> func;
       func(ctx->eigen_device<GPUDevice>(), out->flat<float>(),
-           inp.flat<bfloat16>());
+           inp.flat<bfloat16>(), truncate);
     };
   }
   return nullptr;