aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_matmul_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/sparse_matmul_op.h')
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op.h16
1 files changed, 6 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index bff6a0c9b3..61bd6593c3 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -255,13 +255,12 @@ EIGEN_STRONG_INLINE Packet8d pbroadcast_second<Packet8d>(const Packet8d& a_in) {
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_third<Packet8d>(const Packet8d& a_in) {
- Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1);
+ Packet2d a = _mm512_extractf32x4_ps(a_in, 1);
return _mm512_broadcastsd_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth<Packet8d>(const Packet8d& a_in) {
- Packet2d a =
- _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3);
+ Packet2d a = _mm_permute_pd(_mm512_extractf32x4_ps(a_in, 1), 3);
return _mm512_broadcastsd_pd(a);
}
template <>
@@ -418,17 +417,14 @@ EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth<Packet8f>(const Packet8f& a) {
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) {
- return _mm512_castsi512_ps(_mm512_slli_epi32(
- _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))),
- 16));
+ return _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(from)),
+ 16);
}
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
- return _mm512_castsi512_ps(
- _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_castpd_si256(
- _mm512_extractf64x4_pd(_mm512_castps_pd(from), 1))),
- 16));
+ return _mm512_slli_epi32(
+ _mm512_cvtepu16_epi32(_mm512_extractf64x4_pd(from, 1)), 16);
}
#endif