aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-06-25 14:31:16 -0700
committerGravatar Antonio Sánchez <cantonios@google.com>2020-06-30 18:53:55 +0000
commit9cb8771e9c4a1f44ba59741c9fac495d1872bb25 (patch)
tree5348c34ac0673d09fe97aea29770e7b236e85510 /unsupported
parent145e51516fdac7b30d22c11c6878c2805fc3d724 (diff)
Fix tensor casts for large packets and casts to/from std::complex
The original tensor casts were only defined for `SrcCoeffRatio`:`TgtCoeffRatio` 1:1, 1:2, 2:1, 4:1. Here we add the missing 1:N and 8:1. We also add casting `Eigen::half` to/from `std::complex<T>`, which was previously missing, making it consistent with `Eigen::bfloat16`, and generalize the overload to work for any complex type. Tests were added to `basicstuff`, `packetmath`, and `cxx11_tensor_casts` to test all cast configurations.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h33
-rw-r--r--unsupported/test/cxx11_tensor_casts.cpp81
2 files changed, 107 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
index cdbafbbb1..44493906d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
@@ -51,7 +51,10 @@ struct nested<TensorConversionOp<TargetType, XprType>, 1, typename eval<TensorCo
template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int SrcCoeffRatio, int TgtCoeffRatio>
-struct PacketConverter {
+struct PacketConverter;
+
+template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
+struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 1> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
PacketConverter(const TensorEvaluator& impl)
: m_impl(impl) {}
@@ -109,7 +112,33 @@ struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 4, 1> {
};
template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket>
-struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, 2> {
+struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 8, 1> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ PacketConverter(const TensorEvaluator& impl)
+ : m_impl(impl) {}
+
+ template<int LoadMode, typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TgtPacket packet(Index index) const {
+ const int SrcPacketSize = internal::unpacket_traits<SrcPacket>::size;
+
+ SrcPacket src1 = m_impl.template packet<LoadMode>(index);
+ SrcPacket src2 = m_impl.template packet<LoadMode>(index + 1 * SrcPacketSize);
+ SrcPacket src3 = m_impl.template packet<LoadMode>(index + 2 * SrcPacketSize);
+ SrcPacket src4 = m_impl.template packet<LoadMode>(index + 3 * SrcPacketSize);
+ SrcPacket src5 = m_impl.template packet<LoadMode>(index + 4 * SrcPacketSize);
+ SrcPacket src6 = m_impl.template packet<LoadMode>(index + 5 * SrcPacketSize);
+ SrcPacket src7 = m_impl.template packet<LoadMode>(index + 6 * SrcPacketSize);
+ SrcPacket src8 = m_impl.template packet<LoadMode>(index + 7 * SrcPacketSize);
+ TgtPacket result = internal::pcast<SrcPacket, TgtPacket>(src1, src2, src3, src4, src5, src6, src7, src8);
+ return result;
+ }
+
+ private:
+ const TensorEvaluator& m_impl;
+};
+
+template <typename TensorEvaluator, typename SrcPacket, typename TgtPacket, int TgtCoeffRatio>
+struct PacketConverter<TensorEvaluator, SrcPacket, TgtPacket, 1, TgtCoeffRatio> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
PacketConverter(const TensorEvaluator& impl)
: m_impl(impl), m_maxIndex(impl.dimensions().TotalSize()) {}
diff --git a/unsupported/test/cxx11_tensor_casts.cpp b/unsupported/test/cxx11_tensor_casts.cpp
index c4fe9a798..45456f3ef 100644
--- a/unsupported/test/cxx11_tensor_casts.cpp
+++ b/unsupported/test/cxx11_tensor_casts.cpp
@@ -8,6 +8,7 @@
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
+#include "random_without_cast_overflow.h"
#include <Eigen/CXX11/Tensor>
@@ -104,12 +105,82 @@ static void test_small_to_big_type_cast()
}
}
+template <typename FromType, typename ToType>
+static void test_type_cast() {
+ Tensor<FromType, 2> ftensor(100, 200);
+ // Generate random values for a valid cast.
+ for (int i = 0; i < 100; ++i) {
+ for (int j = 0; j < 200; ++j) {
+ ftensor(i, j) = internal::random_without_cast_overflow<FromType,ToType>::value();
+ }
+ }
+
+ Tensor<ToType, 2> ttensor(100, 200);
+ ttensor = ftensor.template cast<ToType>();
+
+ for (int i = 0; i < 100; ++i) {
+ for (int j = 0; j < 200; ++j) {
+ const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j));
+ VERIFY_IS_APPROX(ttensor(i, j), ref);
+ }
+ }
+}
+
+template<typename Scalar, typename EnableIf = void>
+struct test_cast_runner {
+ static void run() {
+ test_type_cast<Scalar, bool>();
+ test_type_cast<Scalar, int8_t>();
+ test_type_cast<Scalar, int16_t>();
+ test_type_cast<Scalar, int32_t>();
+ test_type_cast<Scalar, int64_t>();
+ test_type_cast<Scalar, uint8_t>();
+ test_type_cast<Scalar, uint16_t>();
+ test_type_cast<Scalar, uint32_t>();
+ test_type_cast<Scalar, uint64_t>();
+ test_type_cast<Scalar, half>();
+ test_type_cast<Scalar, bfloat16>();
+ test_type_cast<Scalar, float>();
+ test_type_cast<Scalar, double>();
+ test_type_cast<Scalar, std::complex<float>>();
+ test_type_cast<Scalar, std::complex<double>>();
+ }
+};
+
+// Only certain types allow cast from std::complex<>.
+template<typename Scalar>
+struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
+ static void run() {
+ test_type_cast<Scalar, half>();
+ test_type_cast<Scalar, bfloat16>();
+ test_type_cast<Scalar, std::complex<float>>();
+ test_type_cast<Scalar, std::complex<double>>();
+ }
+};
+
EIGEN_DECLARE_TEST(cxx11_tensor_casts)
{
- CALL_SUBTEST(test_simple_cast());
- CALL_SUBTEST(test_vectorized_cast());
- CALL_SUBTEST(test_float_to_int_cast());
- CALL_SUBTEST(test_big_to_small_type_cast());
- CALL_SUBTEST(test_small_to_big_type_cast());
+ CALL_SUBTEST(test_simple_cast());
+ CALL_SUBTEST(test_vectorized_cast());
+ CALL_SUBTEST(test_float_to_int_cast());
+ CALL_SUBTEST(test_big_to_small_type_cast());
+ CALL_SUBTEST(test_small_to_big_type_cast());
+
+ CALL_SUBTEST(test_cast_runner<bool>::run());
+ CALL_SUBTEST(test_cast_runner<int8_t>::run());
+ CALL_SUBTEST(test_cast_runner<int16_t>::run());
+ CALL_SUBTEST(test_cast_runner<int32_t>::run());
+ CALL_SUBTEST(test_cast_runner<int64_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint8_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint16_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint32_t>::run());
+ CALL_SUBTEST(test_cast_runner<uint64_t>::run());
+ CALL_SUBTEST(test_cast_runner<half>::run());
+ CALL_SUBTEST(test_cast_runner<bfloat16>::run());
+ CALL_SUBTEST(test_cast_runner<float>::run());
+ CALL_SUBTEST(test_cast_runner<double>::run());
+ CALL_SUBTEST(test_cast_runner<std::complex<float>>::run());
+ CALL_SUBTEST(test_cast_runner<std::complex<double>>::run());
+
}