From b8c8e5f436743ac6a6a5ed0ad7ec5cce7dd00248 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Mon, 14 May 2018 16:07:13 -0700 Subject: Add vectorized clip functor for Eigen Tensors. --- unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 6 ++++++ unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 19 +++++++++++++++++++ unsupported/test/cxx11_tensor_expr.cpp | 21 +++++++++++++++++++++ 3 files changed, 46 insertions(+) (limited to 'unsupported') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 1d459a3a0..c8dc3bbc5 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -209,6 +209,12 @@ class TensorBase return unaryExpr(internal::scalar_abs_op()); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + clip(Scalar min, Scalar max) const { + return unaryExpr(internal::scalar_clip_op(min, max)); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> conjugate() const { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index 5dcc3794c..4cee38339 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -487,6 +487,25 @@ struct functor_traits > { }; }; +template +struct scalar_clip_op { + EIGEN_DEVICE_FUNC inline scalar_clip_op(const Scalar& _min, const Scalar& _max) : m_min(_min), m_max(_max) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& x) const { + return numext::mini(numext::maxi(x, m_min), m_max); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& x) const { + return internal::pmin(internal::pmax(x, pset1(m_min)), pset1(m_max)); + } + const Scalar m_min; + const Scalar m_max; +}; +template +struct functor_traits > +{ enum { Cost = 2 * NumTraits::AddCost, PacketAccess = (packet_traits::HasMin && packet_traits::HasMax)}; }; + } // end namespace internal } // end namespace Eigen diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp index 129b4e659..30e5a9860 100644 --- a/unsupported/test/cxx11_tensor_expr.cpp +++ b/unsupported/test/cxx11_tensor_expr.cpp @@ -340,6 +340,26 @@ void test_minmax_nan_propagation_templ() { } } +static void test_clip() +{ + Tensor vec(6); + vec(0) = 4.0; + vec(1) = 8.0; + vec(2) = 15.0; + vec(3) = 16.0; + vec(4) = 23.0; + vec(5) = 42.0; + + float kMin = 20; + float kMax = 30; + + Tensor vec_clipped(6); + vec_clipped = vec.clip(kMin, kMax); + for (int i = 0; i < 6; ++i) { + VERIFY_IS_EQUAL(vec_clipped(i), std::min(std::max(vec(i), kMin), kMax)); + } +} + static void test_minmax_nan_propagation() { test_minmax_nan_propagation_templ(); @@ -356,5 +376,6 @@ void test_cxx11_tensor_expr() CALL_SUBTEST(test_functors()); CALL_SUBTEST(test_type_casting()); CALL_SUBTEST(test_select()); + CALL_SUBTEST(test_clip()); CALL_SUBTEST(test_minmax_nan_propagation()); } -- cgit v1.2.3