diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-06-20 10:46:45 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-06-20 10:46:45 -0700 |
commit | de32f8d656c3ea7855ced77457ea661e43d417b7 (patch) | |
tree | 747902dd5ac134ae9b7522d37e7ef4992d997524 | |
parent | b055590e9135ffe762775ec919e490513b6974fa (diff) |
Fixed the printing of rank-0 tensors
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorIO.h | 62 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_io.cpp | 16 |
2 files changed, 59 insertions, 19 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h index 38a833f82..3db692ac6 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h @@ -17,34 +17,58 @@ template<> struct significant_decimals_impl<std::string> : significant_decimals_default_impl<std::string, true> {}; -} +// Print the tensor as a 2d matrix +template <typename Tensor, int Rank> +struct TensorPrinter { + static void run (std::ostream& os, const Tensor& tensor) { + typedef typename internal::remove_const<typename Tensor::Scalar>::type Scalar; + typedef typename Tensor::Index Index; + const Index total_size = internal::array_prod(tensor.dimensions()); + const Index first_dim = Eigen::internal::array_get<0>(tensor.dimensions()); + static const int layout = Tensor::Layout; + Map<const Array<Scalar, Dynamic, Dynamic, layout> > matrix(const_cast<Scalar*>(tensor.data()), first_dim, total_size/first_dim); + os << matrix; + } +}; + + +// Print the tensor as a vector +template <typename Tensor> +struct TensorPrinter<Tensor, 1> { + static void run (std::ostream& os, const Tensor& tensor) { + typedef typename internal::remove_const<typename Tensor::Scalar>::type Scalar; + typedef typename Tensor::Index Index; + const Index total_size = internal::array_prod(tensor.dimensions()); + Map<const Array<Scalar, Dynamic, 1> > array(const_cast<Scalar*>(tensor.data()), total_size); + os << array; + } +}; + + +// Print the tensor as a scalar +template <typename Tensor> +struct TensorPrinter<Tensor, 0> { + static void run (std::ostream& os, const Tensor& tensor) { + os << tensor.coeff(0); + } +}; +} + template <typename T> std::ostream& operator << (std::ostream& os, const TensorBase<T, ReadOnlyAccessors>& expr) { + typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator; + typedef typename Evaluator::Dimensions Dimensions; + // Evaluate the expression if needed TensorForcedEvalOp<const T> eval = expr.eval(); - TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> tensor(eval, DefaultDevice()); + Evaluator tensor(eval, DefaultDevice()); tensor.evalSubExprsIfNeeded(NULL); - typedef typename internal::remove_const<typename T::Scalar>::type Scalar; - typedef typename T::Index Index; - typedef typename TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice>::Dimensions Dimensions; - const Index total_size = internal::array_prod(tensor.dimensions()); - - // Print the tensor as a 1d vector or a 2d matrix. + // Print the result static const int rank = internal::array_size<Dimensions>::value; - if (rank == 0) { - os << tensor.coeff(0); - } else if (rank == 1) { - Map<const Array<Scalar, Dynamic, 1> > array(const_cast<Scalar*>(tensor.data()), total_size); - os << array; - } else { - const Index first_dim = Eigen::internal::array_get<0>(tensor.dimensions()); - static const int layout = TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice>::Layout; - Map<const Array<Scalar, Dynamic, Dynamic, layout> > matrix(const_cast<Scalar*>(tensor.data()), first_dim, total_size/first_dim); - os << matrix; - } + internal::TensorPrinter<Evaluator, rank>::run(os, tensor); // Cleanup. tensor.cleanup(); diff --git a/unsupported/test/cxx11_tensor_io.cpp b/unsupported/test/cxx11_tensor_io.cpp index 8bbcf7089..8267dcadd 100644 --- a/unsupported/test/cxx11_tensor_io.cpp +++ b/unsupported/test/cxx11_tensor_io.cpp @@ -14,6 +14,20 @@ template<int DataLayout> +static void test_output_0d() +{ + Tensor<int, 0, DataLayout> tensor; + tensor() = 123; + + std::stringstream os; + os << tensor; + + std::string expected("123"); + VERIFY_IS_EQUAL(std::string(os.str()), expected); +} + + +template<int DataLayout> static void test_output_1d() { Tensor<int, 1, DataLayout> tensor(5); @@ -101,6 +115,8 @@ static void test_output_const() void test_cxx11_tensor_io() { + CALL_SUBTEST(test_output_0d<ColMajor>()); + CALL_SUBTEST(test_output_0d<RowMajor>()); CALL_SUBTEST(test_output_1d<ColMajor>()); CALL_SUBTEST(test_output_1d<RowMajor>()); CALL_SUBTEST(test_output_2d<ColMajor>()); |