aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-05-14 16:07:13 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-05-14 16:07:13 -0700
commitb8c8e5f436743ac6a6a5ed0ad7ec5cce7dd00248 (patch)
tree8dd48fd88fa87fd9c492bfcbf5fd3ed4ea8cb150 /unsupported
parent6118c6ff4f898dad99564ffdeb99036beb2ff0ea (diff)
Add vectorized clip functor for Eigen Tensors.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h6
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h19
-rw-r--r--unsupported/test/cxx11_tensor_expr.cpp21
3 files changed, 46 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index 1d459a3a0..c8dc3bbc5 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -210,6 +210,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_clip_op<Scalar>, const Derived>
+ clip(Scalar min, Scalar max) const {
+ return unaryExpr(internal::scalar_clip_op<Scalar>(min, max));
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const Derived>
conjugate() const {
return unaryExpr(internal::scalar_conjugate_op<Scalar>());
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index 5dcc3794c..4cee38339 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -487,6 +487,25 @@ struct functor_traits<GaussianGenerator<T, Index, NumDims> > {
};
};
+template <typename Scalar>
+struct scalar_clip_op {
+ EIGEN_DEVICE_FUNC inline scalar_clip_op(const Scalar& _min, const Scalar& _max) : m_min(_min), m_max(_max) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x) const {
+ return numext::mini(numext::maxi(x, m_min), m_max);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& x) const {
+ return internal::pmin(internal::pmax(x, pset1<Packet>(m_min)), pset1<Packet>(m_max));
+ }
+ const Scalar m_min;
+ const Scalar m_max;
+};
+template<typename Scalar>
+struct functor_traits<scalar_clip_op<Scalar> >
+{ enum { Cost = 2 * NumTraits<Scalar>::AddCost, PacketAccess = (packet_traits<Scalar>::HasMin && packet_traits<Scalar>::HasMax)}; };
+
} // end namespace internal
} // end namespace Eigen
diff --git a/unsupported/test/cxx11_tensor_expr.cpp b/unsupported/test/cxx11_tensor_expr.cpp
index 129b4e659..30e5a9860 100644
--- a/unsupported/test/cxx11_tensor_expr.cpp
+++ b/unsupported/test/cxx11_tensor_expr.cpp
@@ -340,6 +340,26 @@ void test_minmax_nan_propagation_templ() {
}
}
+static void test_clip()
+{
+ Tensor<float, 1> vec(6);
+ vec(0) = 4.0;
+ vec(1) = 8.0;
+ vec(2) = 15.0;
+ vec(3) = 16.0;
+ vec(4) = 23.0;
+ vec(5) = 42.0;
+
+ float kMin = 20;
+ float kMax = 30;
+
+ Tensor<float, 1> vec_clipped(6);
+ vec_clipped = vec.clip(kMin, kMax);
+ for (int i = 0; i < 6; ++i) {
+ VERIFY_IS_EQUAL(vec_clipped(i), std::min(std::max(vec(i), kMin), kMax));
+ }
+}
+
static void test_minmax_nan_propagation()
{
test_minmax_nan_propagation_templ<float>();
@@ -356,5 +376,6 @@ void test_cxx11_tensor_expr()
CALL_SUBTEST(test_functors());
CALL_SUBTEST(test_type_casting());
CALL_SUBTEST(test_select());
+ CALL_SUBTEST(test_clip());
CALL_SUBTEST(test_minmax_nan_propagation());
}