diff options
-rw-r--r-- | unsupported/Eigen/CXX11/Tensor | 2 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h | 1 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorRef.h | 360 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h | 40 | ||||
-rw-r--r-- | unsupported/test/CMakeLists.txt | 1 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_ref.cpp | 192 |
6 files changed, 596 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index 47447f446..c36db96ec 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -76,6 +76,8 @@ #include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" + #include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" #include "Eigen/src/Core/util/ReenableStupidWarnings.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index 67f478822..a72e11215 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -15,6 +15,7 @@ namespace Eigen { template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor; template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize; template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap; +template<typename PlainObjectType> class TensorRef; template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase; template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h new file mode 100644 index 000000000..db2027a5f --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h @@ -0,0 +1,360 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H +#define EIGEN_CXX11_TENSOR_TENSOR_REF_H + +namespace Eigen { + +namespace internal { + +template <typename Dimensions, typename Scalar> +class TensorLazyBaseEvaluator { + public: + TensorLazyBaseEvaluator() : m_refcount(0) { } + virtual ~TensorLazyBaseEvaluator() { } + + virtual const Dimensions& dimensions() const = 0; + virtual const Scalar* data() const = 0; + + virtual const Scalar coeff(DenseIndex index) const = 0; + virtual Scalar& coeffRef(DenseIndex index) = 0; + + void incrRefCount() { ++m_refcount; } + void decrRefCount() { --m_refcount; } + int refCount() const { return m_refcount; } + + private: + // No copy, no assigment; + TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other); + TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other); + + int m_refcount; +}; + +static char dummy[8]; + +template <typename Dimensions, typename Expr, typename Device> +class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> { + public: + // typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions; + typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar; + + TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device) { + m_dims = m_impl.dimensions(); + m_impl.evalSubExprsIfNeeded(NULL); + } + virtual ~TensorLazyEvaluatorReadOnly() { + m_impl.cleanup(); + } + + virtual const Dimensions& dimensions() const { + return m_dims; + } + virtual const Scalar* data() const { + return m_impl.data(); + } + + virtual const Scalar coeff(DenseIndex index) const { + return m_impl.coeff(index); + } + virtual Scalar& coeffRef(DenseIndex index) { + eigen_assert(false && "can't reference the coefficient of a rvalue"); + return *reinterpret_cast<Scalar*>(dummy); + }; + + protected: + TensorEvaluator<Expr, Device> m_impl; + Dimensions m_dims; +}; + +template <typename Dimensions, typename Expr, typename Device> +class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> { + public: + typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base; + typedef typename Base::Scalar Scalar; + + TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) { + } + virtual ~TensorLazyEvaluatorWritable() { + } + + virtual Scalar& coeffRef(DenseIndex index) { + return this->m_impl.coeffRef(index); + } +}; + +template <typename Dimensions, typename Expr, typename Device> +class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value), + TensorLazyEvaluatorWritable<Dimensions, Expr, Device>, + TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type { + public: + typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value), + TensorLazyEvaluatorWritable<Dimensions, Expr, Device>, + TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base; + typedef typename Base::Scalar Scalar; + + TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) { + } + virtual ~TensorLazyEvaluator() { + } +}; + +} // namespace internal + + +/** \class TensorRef + * \ingroup CXX11_Tensor_Module + * + * \brief A reference to a tensor expression + * The expression will be evaluated lazily (as much as possible). + * + */ +template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> > +{ + public: + typedef TensorRef<PlainObjectType> Self; + typedef typename PlainObjectType::Base Base; + typedef typename Eigen::internal::nested<Self>::type Nested; + typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind; + typedef typename internal::traits<PlainObjectType>::Index Index; + typedef typename internal::traits<PlainObjectType>::Scalar Scalar; + typedef typename internal::packet_traits<Scalar>::type Packet; + typedef typename NumTraits<Scalar>::Real RealScalar; + typedef typename Base::CoeffReturnType CoeffReturnType; + typedef Scalar* PointerType; + typedef PointerType PointerArgType; + + static const Index NumIndices = PlainObjectType::NumIndices; + typedef typename PlainObjectType::Dimensions Dimensions; + + enum { + IsAligned = false, + PacketAccess = false, + }; + + EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) { + } + + template <typename Expression> + EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) { + m_evaluator->incrRefCount(); + } + + template <typename Expression> + EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) { + unrefEvaluator(); + m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice()); + m_evaluator->incrRefCount(); + return *this; + } + + ~TensorRef() { + unrefEvaluator(); + } + + TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) { + eigen_assert(m_evaluator->refCount() > 0); + m_evaluator->incrRefCount(); + } + + TensorRef& operator = (const TensorRef& other) { + if (this != &other) { + unrefEvaluator(); + m_evaluator = other.m_evaluator; + eigen_assert(m_evaluator->refCount() > 0); + m_evaluator->incrRefCount(); + } + return *this; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index index) const + { + return m_evaluator->coeff(index); + } + +#ifdef EIGEN_HAS_VARIADIC_TEMPLATES + template<typename... IndexTypes> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const + { + const std::size_t NumIndices = (sizeof...(otherIndices) + 1); + const array<Index, NumIndices> indices{{firstIndex, otherIndices...}}; + return coeff(indices); + } +#else + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const + { + array<Index, 2> indices; + indices[0] = i0; + indices[1] = i1; + return coeff(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const + { + array<Index, 3> indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + return coeff(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const + { + array<Index, 4> indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + return coeff(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const + { + array<Index, 5> indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + indices[4] = i4; + return coeff(indices); + } +#endif + + template <std::size_t NumIndices> EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const + { + const Dimensions& dims = this->dimensions(); + Index index = 0; + if (PlainObjectType::Options&RowMajor) { + index += indices[0]; + for (int i = 1; i < NumIndices; ++i) { + index = index * dims[i] + indices[i]; + } + } else { + index += indices[NumIndices-1]; + for (int i = NumIndices-2; i >= 0; --i) { + index = index * dims[i] + indices[i]; + } + } + return m_evaluator->coeff(index); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Scalar coeff(Index index) const + { + return m_evaluator->coeff(index); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) + { + return m_evaluator->coeffRef(index); + } + + private: + EIGEN_STRONG_INLINE void unrefEvaluator() { + if (m_evaluator) { + m_evaluator->decrRefCount(); + if (m_evaluator->refCount() == 0) { + delete m_evaluator; + } + } + } + + internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator; +}; + + +// evaluator for rvalues +template<typename Derived, typename Device> +struct TensorEvaluator<const TensorRef<Derived>, Device> +{ + typedef typename Derived::Index Index; + typedef typename Derived::Scalar Scalar; + typedef typename Derived::Packet Packet; + typedef typename Derived::Scalar CoeffReturnType; + typedef typename Derived::Packet PacketReturnType; + typedef typename Derived::Dimensions Dimensions; + + enum { + IsAligned = false, + PacketAccess = false, + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&) + : m_ref(m) + { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + return true; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { + return m_ref.coeff(index); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + return m_ref.coeffRef(index); + } + + Scalar* data() const { return m_ref.data(); } + + protected: + TensorRef<Derived> m_ref; +}; + + +// evaluator for lvalues +template<typename Derived, typename Device> +struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device> +{ + typedef typename Derived::Index Index; + typedef typename Derived::Scalar Scalar; + typedef typename Derived::Packet Packet; + typedef typename Derived::Scalar CoeffReturnType; + typedef typename Derived::Packet PacketReturnType; + typedef typename Derived::Dimensions Dimensions; + + typedef TensorEvaluator<const TensorRef<Derived>, Device> Base; + + enum { + IsAligned = false, + PacketAccess = false, + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d) + { } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { + return this->m_ref.coeffRef(index); + } +}; + + + +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h index 5940a8cf1..5c0f78489 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h @@ -84,6 +84,20 @@ struct traits<TensorMap<PlainObjectType, Options_> > }; }; +template<typename PlainObjectType> +struct traits<TensorRef<PlainObjectType> > + : public traits<PlainObjectType> +{ + typedef traits<PlainObjectType> BaseTraits; + typedef typename BaseTraits::Scalar Scalar; + typedef typename BaseTraits::StorageKind StorageKind; + typedef typename BaseTraits::Index Index; + enum { + Options = BaseTraits::Options, + Flags = ((BaseTraits::Flags | LvalueBit) & ~AlignedBit) | (Options&Aligned ? AlignedBit : 0), + }; +}; + template<typename _Scalar, std::size_t NumIndices_, int Options> struct eval<Tensor<_Scalar, NumIndices_, Options>, Eigen::Dense> @@ -121,6 +135,19 @@ struct eval<const TensorMap<PlainObjectType, Options>, Eigen::Dense> typedef const TensorMap<PlainObjectType, Options>& type; }; +template<typename PlainObjectType> +struct eval<TensorRef<PlainObjectType>, Eigen::Dense> +{ + typedef const TensorRef<PlainObjectType>& type; +}; + +template<typename PlainObjectType> +struct eval<const TensorRef<PlainObjectType>, Eigen::Dense> +{ + typedef const TensorRef<PlainObjectType>& type; +}; + + template <typename Scalar_, std::size_t NumIndices_, int Options_> struct nested<Tensor<Scalar_, NumIndices_, Options_>, 1, typename eval<Tensor<Scalar_, NumIndices_, Options_> >::type> { @@ -145,6 +172,7 @@ struct nested<const TensorFixedSize<Scalar_, Dimensions, Options>, 1, typename e typedef const TensorFixedSize<Scalar_, Dimensions, Options>& type; }; + template <typename PlainObjectType, int Options> struct nested<TensorMap<PlainObjectType, Options>, 1, typename eval<TensorMap<PlainObjectType, Options> >::type> { @@ -157,6 +185,18 @@ struct nested<const TensorMap<PlainObjectType, Options>, 1, typename eval<Tensor typedef const TensorMap<PlainObjectType, Options>& type; }; +template <typename PlainObjectType> +struct nested<TensorRef<PlainObjectType>, 1, typename eval<TensorRef<PlainObjectType> >::type> +{ + typedef const TensorRef<PlainObjectType>& type; +}; + +template <typename PlainObjectType> +struct nested<const TensorRef<PlainObjectType>, 1, typename eval<TensorRef<PlainObjectType> >::type> +{ + typedef const TensorRef<PlainObjectType>& type; +}; + } // end namespace internal } // end namespace Eigen diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index a7ef2b402..2b5395013 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -126,5 +126,6 @@ if(EIGEN_TEST_CXX11) ei_add_test(cxx11_tensor_striding "-std=c++0x") # ei_add_test(cxx11_tensor_device "-std=c++0x") ei_add_test(cxx11_tensor_thread_pool "-std=c++0x") + ei_add_test(cxx11_tensor_ref "-std=c++0x") ei_add_test(cxx11_tensor_io "-std=c++0x") endif() diff --git a/unsupported/test/cxx11_tensor_ref.cpp b/unsupported/test/cxx11_tensor_ref.cpp new file mode 100644 index 000000000..4ff94a059 --- /dev/null +++ b/unsupported/test/cxx11_tensor_ref.cpp @@ -0,0 +1,192 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include <Eigen/CXX11/Tensor> + +using Eigen::Tensor; +using Eigen::RowMajor; + +static void test_simple_lvalue_ref() +{ + Tensor<int, 1> input(6); + input.setRandom(); + + TensorRef<Tensor<int, 1>> ref3(input); + TensorRef<Tensor<int, 1>> ref4 = input; + + VERIFY_IS_EQUAL(ref3.data(), input.data()); + VERIFY_IS_EQUAL(ref4.data(), input.data()); + + for (int i = 0; i < 6; ++i) { + VERIFY_IS_EQUAL(ref3(i), input(i)); + VERIFY_IS_EQUAL(ref4(i), input(i)); + } + + for (int i = 0; i < 6; ++i) { + ref3.coeffRef(i) = i; + } + for (int i = 0; i < 6; ++i) { + VERIFY_IS_EQUAL(input(i), i); + } + for (int i = 0; i < 6; ++i) { + ref4.coeffRef(i) = -i * 2; + } + for (int i = 0; i < 6; ++i) { + VERIFY_IS_EQUAL(input(i), -i*2); + } +} + + +static void test_simple_rvalue_ref() +{ + Tensor<int, 1> input1(6); + input1.setRandom(); + Tensor<int, 1> input2(6); + input2.setRandom(); + + TensorRef<Tensor<int, 1>> ref3(input1 + input2); + TensorRef<Tensor<int, 1>> ref4 = input1 + input2; + + VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data()); + VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data()); + VERIFY_IS_NOT_EQUAL(ref3.data(), input2.data()); + VERIFY_IS_NOT_EQUAL(ref4.data(), input2.data()); + + for (int i = 0; i < 6; ++i) { + VERIFY_IS_EQUAL(ref3(i), input1(i) + input2(i)); + VERIFY_IS_EQUAL(ref4(i), input1(i) + input2(i)); + } +} + + +static void test_multiple_dims() +{ + Tensor<float, 3> input(3,5,7); + input.setRandom(); + + TensorRef<Tensor<float, 3>> ref(input); + VERIFY_IS_EQUAL(ref.data(), input.data()); + VERIFY_IS_EQUAL(ref.dimension(0), 3); + VERIFY_IS_EQUAL(ref.dimension(1), 5); + VERIFY_IS_EQUAL(ref.dimension(2), 7); + + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(ref(i,j,k), input(i,j,k)); + } + } + } +} + + +static void test_slice() +{ + Tensor<float, 5> tensor(2,3,5,7,11); + tensor.setRandom(); + + Eigen::DSizes<ptrdiff_t, 5> indices(1,2,3,4,5); + Eigen::DSizes<ptrdiff_t, 5> sizes(1,1,1,1,1); + TensorRef<Tensor<float, 5>> slice = tensor.slice(indices, sizes); + VERIFY_IS_EQUAL(slice(0,0,0,0,0), tensor(1,2,3,4,5)); + + Eigen::DSizes<ptrdiff_t, 5> indices2(1,1,3,4,5); + Eigen::DSizes<ptrdiff_t, 5> sizes2(1,1,2,2,3); + slice = tensor.slice(indices2, sizes2); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k)); + } + } + } + + Eigen::DSizes<ptrdiff_t, 5> indices3(0,0,0,0,0); + Eigen::DSizes<ptrdiff_t, 5> sizes3(2,3,1,1,1); + slice = tensor.slice(indices3, sizes3); + VERIFY_IS_EQUAL(slice.data(), tensor.data()); +} + + +static void test_ref_of_ref() +{ + Tensor<float, 3> input(3,5,7); + input.setRandom(); + + TensorRef<Tensor<float, 3>> ref(input); + TensorRef<Tensor<float, 3>> ref_of_ref(ref); + TensorRef<Tensor<float, 3>> ref_of_ref2; + ref_of_ref2 = ref; + + VERIFY_IS_EQUAL(ref_of_ref.data(), input.data()); + VERIFY_IS_EQUAL(ref_of_ref.dimension(0), 3); + VERIFY_IS_EQUAL(ref_of_ref.dimension(1), 5); + VERIFY_IS_EQUAL(ref_of_ref.dimension(2), 7); + + VERIFY_IS_EQUAL(ref_of_ref2.data(), input.data()); + VERIFY_IS_EQUAL(ref_of_ref2.dimension(0), 3); + VERIFY_IS_EQUAL(ref_of_ref2.dimension(1), 5); + VERIFY_IS_EQUAL(ref_of_ref2.dimension(2), 7); + + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(ref_of_ref(i,j,k), input(i,j,k)); + VERIFY_IS_EQUAL(ref_of_ref2(i,j,k), input(i,j,k)); + } + } + } +} + + +static void test_ref_in_expr() +{ + Tensor<float, 3> input(3,5,7); + input.setRandom(); + TensorRef<Tensor<float, 3>> input_ref(input); + + Tensor<float, 3> result(3,5,7); + result.setRandom(); + TensorRef<Tensor<float, 3>> result_ref(result); + + Tensor<float, 3> bias(3,5,7); + bias.setRandom(); + + result_ref = input_ref + bias; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(result_ref(i,j,k), input(i,j,k) + bias(i,j,k)); + VERIFY_IS_NOT_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k)); + } + } + } + + result = result_ref; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 5; ++j) { + for (int k = 0; k < 7; ++k) { + VERIFY_IS_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k)); + } + } + } +} + + +void test_cxx11_tensor_ref() +{ + CALL_SUBTEST(test_simple_lvalue_ref()); + CALL_SUBTEST(test_simple_rvalue_ref()); + CALL_SUBTEST(test_multiple_dims()); + CALL_SUBTEST(test_slice()); + CALL_SUBTEST(test_ref_of_ref()); + CALL_SUBTEST(test_ref_in_expr()); +} |