diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op.h')
-rw-r--r-- | tensorflow/core/kernels/cast_op.h | 166 |
1 file changed, 163 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h index 16d2e0e0a5..527ab528c9 100644 --- a/tensorflow/core/kernels/cast_op.h +++ b/tensorflow/core/kernels/cast_op.h @@ -24,8 +24,71 @@ limitations under the License. #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/types.h" +// Note that the GPU cast functor templates need to be instantiated unlike the +// CPU ones, and hence their specializations are different than that for CPUs. +#ifdef SPECIALIZE_FOR_GPUS +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ + template <typename Device> \ + struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \ + void operator()(const Device& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ + .template cast<OUT_TYPE>(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + } \ + }; \ + template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>; +#else +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ + template <> \ + struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \ + void operator()(const DEVICE& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ + .template cast<OUT_TYPE>(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + } \ + }; +#endif + +#define CAST_FUNCTORS(devname) \ + SPECIALIZE_CAST(devname, float, double) \ + SPECIALIZE_CAST(devname, float, std::complex<double>) \ + SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \ + SPECIALIZE_CAST(devname, std::complex<float>, double) \ + SPECIALIZE_CAST(devname, Eigen::half, double) \ + SPECIALIZE_CAST(devname, 
Eigen::half, float) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \ + SPECIALIZE_CAST(devname, bfloat16, float) \ + template <typename OUT_TYPE, typename IN_OUT> \ + struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \ + void operator()(const devname& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + }; + namespace tensorflow { +typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*, + bool trunc)> + CastFunctorType; + // Common base class of Cast kernels class CastOpBase : public OpKernel { public: @@ -36,8 +99,10 @@ class CastOpBase : public OpKernel { protected: DataType src_dtype_; DataType dst_dtype_; - std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; - + DataType external_src_dtype_; + DataType external_dst_dtype_; + bool use_truncation_; + CastFunctorType work_ = nullptr; Status Unimplemented(); TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); @@ -54,6 +119,23 @@ class CpuCastOp : public CastOpBase { namespace functor { +template <typename I> +constexpr int MantissaWidth() { + return std::numeric_limits<I>::digits; +} + +template <> +constexpr int MantissaWidth<Eigen::half>() { + // Remember, there's 1 hidden bit + return 10 + 1; +} + +template <> +constexpr int MantissaWidth<bfloat16>() { + // Remember, there's 1 hidden bit + return 7 + 1; +} + template <typename Device, typename Tout, typename Tin> void Cast(const Device& d, typename TTypes<Tout>::Flat o, typename TTypes<Tin>::ConstFlat i) { @@ -63,7 +145,85 @@ void Cast(const Device& d, typename TTypes<Tout>::Flat o, template <typename Device, typename Tout, typename Tin> struct CastFunctor { void operator()(const Device& d, typename TTypes<Tout>::Flat o, - typename TTypes<Tin>::ConstFlat i); + typename TTypes<Tin>::ConstFlat i, bool truncate 
= false); +}; + +// Only enable LSBZeroSetterHelper for 64 and 32 bit input data types. +// Specialize for others if needed in future. +template <typename I> +typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!std::isnan(t)) { + uint64_t* p = reinterpret_cast<uint64_t*>(&t); + *p &= (0xFFFFFFFFFFFFFFFF << n); + } +} + +template <typename I> +typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!std::isnan(t)) { + uint32_t* p = reinterpret_cast<uint32_t*>(&t); + *p &= (0xFFFFFFFF << n); + } +} + +// Set n least significant bits to 0 +template <typename I, typename O> +struct LSBZeroSetter { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I t = a; + LSBZeroSetterHelper(t, bits); + return t; + } +}; + +template <typename I, typename O> +struct LSBZeroSetter<std::complex<I>, std::complex<O>> { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( + const std::complex<I>& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = std::real(a); + I img = std::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex<I> toReturn(re, img); + return toReturn; + } +}; + +template <typename I, typename O> +struct LSBZeroSetter<std::complex<I>, O> { + 
EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + // Sets the 16 LSBits of the float to 0 + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( + const std::complex<I>& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = std::real(a); + I img = std::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex<I> toReturn(re, img); + return toReturn; + } }; } // end namespace functor |