diff options
Diffstat (limited to 'tensorflow/core/kernels/sparse_matmul_op.h')
-rw-r--r-- | tensorflow/core/kernels/sparse_matmul_op.h | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h index bff6a0c9b3..61bd6593c3 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.h +++ b/tensorflow/core/kernels/sparse_matmul_op.h @@ -255,13 +255,12 @@ EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) { } template <> EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) { - Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1); + Packet2d a = _mm512_extractf32x4_ps(a_in, 1); return _mm512_broadcastsd_pd(a); } template <> EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) { - Packet2d a = - _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3); + Packet2d a = _mm_permute_pd(_mm512_extractf32x4_ps(a_in, 1), 3); return _mm512_broadcastsd_pd(a); } template <> @@ -418,17 +417,14 @@ EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) { template <typename Packet> EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) { - return _mm512_castsi512_ps(_mm512_slli_epi32( - _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))), - 16)); + return _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(from)), + 16); } template <typename Packet> EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) { - return _mm512_castsi512_ps( - _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_castpd_si256( - _mm512_extractf64x4_pd(_mm512_castps_pd(from), 1))), - 16)); + return _mm512_slli_epi32( + _mm512_cvtepu16_epi32(_mm512_extractf64x4_pd(from, 1)), 16); } #endif |