aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/Default/TypeCasting.h
diff options
context:
space:
mode:
authorGravatar Teng Lu <teng.lu@intel.com>2020-06-20 19:16:24 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-06-20 19:16:24 +0000
commit386d809bde475c65b7940f290efe80e6a05878c4 (patch)
treec38e161a53393d15be0ddb02a7a4e22dec738484 /Eigen/src/Core/arch/Default/TypeCasting.h
parent6b9c92fe7eff0dedb031cec38004c9c3667f3057 (diff)
Support BFloat16 in Eigen
Diffstat (limited to 'Eigen/src/Core/arch/Default/TypeCasting.h')
-rw-r--r--Eigen/src/Core/arch/Default/TypeCasting.h43
1 files changed, 43 insertions, 0 deletions
diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h
index b6df98468..fb8183b78 100644
--- a/Eigen/src/Core/arch/Default/TypeCasting.h
+++ b/Eigen/src/Core/arch/Default/TypeCasting.h
@@ -71,6 +71,49 @@ template<>
struct functor_traits<scalar_cast_op<Eigen::half, float> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+template<>
+struct scalar_cast_op<float, Eigen::bfloat16> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const {
+ return Eigen::bfloat16(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<float, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<int, Eigen::bfloat16> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef Eigen::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const {
+ return Eigen::bfloat16(static_cast<float>(a));
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, Eigen::bfloat16> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
+template<>
+struct scalar_cast_op<Eigen::bfloat16, float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const {
+ return static_cast<float>(a);
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<Eigen::bfloat16, float> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
}
}