diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-05-14 16:07:13 -0700 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-05-14 16:07:13 -0700 |
commit | b8c8e5f436743ac6a6a5ed0ad7ec5cce7dd00248 (patch) | |
tree | 8dd48fd88fa87fd9c492bfcbf5fd3ed4ea8cb150 /unsupported/Eigen/CXX11 | |
parent | 6118c6ff4f898dad99564ffdeb99036beb2ff0ea (diff) |
Add vectorized clip functor for Eigen Tensors.
Diffstat (limited to 'unsupported/Eigen/CXX11')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 6 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h | 19 |
2 files changed, 25 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 1d459a3a0..c8dc3bbc5 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -210,6 +210,12 @@ class TensorBase<Derived, ReadOnlyAccessors> } EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clip_op<Scalar>, const Derived> + clip(Scalar min, Scalar max) const { + return unaryExpr(internal::scalar_clip_op<Scalar>(min, max)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived> conjugate() const { return unaryExpr(internal::scalar_conjugate_op<Scalar>()); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h index 5dcc3794c..4cee38339 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h @@ -487,6 +487,25 @@ struct functor_traits<GaussianGenerator<T, Index, NumDims> > { }; }; +template <typename Scalar> +struct scalar_clip_op { + EIGEN_DEVICE_FUNC inline scalar_clip_op(const Scalar& _min, const Scalar& _max) : m_min(_min), m_max(_max) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& x) const { + return numext::mini(numext::maxi(x, m_min), m_max); + } + template <typename Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& x) const { + return internal::pmin(internal::pmax(x, pset1<Packet>(m_min)), pset1<Packet>(m_max)); + } + const Scalar m_min; + const Scalar m_max; +}; +template<typename Scalar> +struct functor_traits<scalar_clip_op<Scalar> > +{ enum { Cost = 2 * NumTraits<Scalar>::AddCost, PacketAccess = (packet_traits<Scalar>::HasMin && packet_traits<Scalar>::HasMax)}; }; + } // end namespace internal } // end namespace Eigen |