aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-10-21 11:28:28 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-10-21 11:28:28 -0700
commitb178cc347968675bdae942dbdcb7de9ed9daa564 (patch)
treef2e931dd8bf7d8f6710264ac366f4673b01e736a
parent5ca2e25967ee82df3c2347223ad8a5cde5070eb6 (diff)
Added some syntactic sugar to make it simpler to compare a tensor to a scalar.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h32
-rw-r--r--unsupported/test/CMakeLists.txt1
-rw-r--r--unsupported/test/cxx11_tensor_sugar.cpp38
3 files changed, 71 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index c00f67950..1b85f5ef5 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -279,6 +279,38 @@ class TensorBase<Derived, ReadOnlyAccessors>
return binaryExpr(other.derived(), std::not_equal_to<Scalar>());
}
+ // comparisons and tests for Scalars
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ operator<(Scalar threshold) const {
+ return operator<(constant(threshold));
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::less_equal<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ operator<=(Scalar threshold) const {
+ return operator<=(constant(threshold));
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ operator>(Scalar threshold) const {
+ return operator>(constant(threshold));
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::greater_equal<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ operator>=(Scalar threshold) const {
+ return operator>=(constant(threshold));
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::equal_to<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ operator==(Scalar threshold) const {
+ return operator==(constant(threshold));
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
+ operator!=(Scalar threshold) const {
+ return operator!=(constant(threshold));
+ }
+
// Coefficient-wise ternary operators.
template<typename ThenDerived, typename ElseDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt
index 8865892e6..5a9ed5730 100644
--- a/unsupported/test/CMakeLists.txt
+++ b/unsupported/test/CMakeLists.txt
@@ -143,6 +143,7 @@ if(EIGEN_TEST_CXX11)
ei_add_test(cxx11_tensor_generator "-std=c++0x")
ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
ei_add_test(cxx11_tensor_custom_index "-std=c++0x")
+ ei_add_test(cxx11_tensor_sugar "-std=c++0x")
# These tests needs nvcc
# ei_add_test(cxx11_tensor_device "-std=c++0x")
diff --git a/unsupported/test/cxx11_tensor_sugar.cpp b/unsupported/test/cxx11_tensor_sugar.cpp
new file mode 100644
index 000000000..7848acc8b
--- /dev/null
+++ b/unsupported/test/cxx11_tensor_sugar.cpp
@@ -0,0 +1,38 @@
+#include "main.h"
+
+#include <Eigen/CXX11/Tensor>
+
+using Eigen::Tensor;
+using Eigen::RowMajor;
+
+static void test_comparison_sugar() {
+ // we already trust comparisons between tensors, we're simply checking that
+ // the sugared versions are doing the same thing
+ Tensor<int, 3> t(6, 7, 5);
+
+ t.setRandom();
+ // make sure we have at least one value == 0
+ t(0,0,0) = 0;
+
+ Tensor<bool,1> b;
+
+#define TEST_TENSOR_EQUAL(e1, e2) \
+ b = ((e1) == (e2)).all(); \
+ VERIFY(b(0))
+
+#define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0))
+
+ TEST_OP(==);
+ TEST_OP(!=);
+ TEST_OP(<=);
+ TEST_OP(>=);
+ TEST_OP(<);
+ TEST_OP(>);
+#undef TEST_OP
+#undef TEST_TENSOR_EQUAL
+}
+
+void test_cxx11_tensor_sugar()
+{
+ CALL_SUBTEST(test_comparison_sugar());
+}