diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op.h')
-rw-r--r-- | tensorflow/core/kernels/cast_op.h | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h new file mode 100644 index 0000000000..d066206abc --- /dev/null +++ b/tensorflow/core/kernels/cast_op.h @@ -0,0 +1,71 @@ +#ifndef TENSORFLOW_KERNELS_CAST_OP_H_ +#define TENSORFLOW_KERNELS_CAST_OP_H_ + +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/port.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template <typename Device, typename Tout, typename Tin> +void Cast(const Device& d, typename TTypes<Tout>::Flat o, + typename TTypes<Tin>::ConstFlat i) { + o.device(d) = i.template cast<Tout>(); +} + +template <typename Device, typename Tout, typename Tin> +struct CastFunctor { + void operator()(const Device& d, typename TTypes<Tout>::Flat o, + typename TTypes<Tin>::ConstFlat i); +}; + +} // end namespace functor +} // end namespace tensorflow + +namespace Eigen { +namespace internal { + +// Specialized cast op impls for bfloat16. +template <> +struct scalar_cast_op< ::tensorflow::bfloat16, float> { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef float result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()( + const ::tensorflow::bfloat16& a) const { + static_assert(::tensorflow::port::kLittleEndian, ""); + float ret; + uint16_t* p = reinterpret_cast<uint16_t*>(&ret); + p[0] = 0; + p[1] = a.value; + return ret; + } +}; + +template <> +struct functor_traits<scalar_cast_op< ::tensorflow::bfloat16, float> > { + enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; +}; + +template <> +struct scalar_cast_op<float, ::tensorflow::bfloat16> { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef ::tensorflow::bfloat16 result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()( + const float a) const { + static_assert(::tensorflow::port::kLittleEndian, ""); + const uint16_t* p = reinterpret_cast<const uint16_t*>(&a); + return ::tensorflow::bfloat16(p[1]); + } +}; + +template <> +struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16> > { + enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; +}; + +} // namespace internal +} // namespace Eigen + +#endif // TENSORFLOW_KERNELS_CAST_OP_H_ |