aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2018-08-10 16:53:36 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2018-08-10 16:53:36 -0700
commitf2209d06e428e0691de71f30fc2db4cb29191cd2 (patch)
tree37d7294a61f80c87389e8e930700a549554afe51 /unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
parentcfaedb38cd662def3b5684a20965b3bc1b0d6a3f (diff)
Add block evaluation to CwiseUnaryOp and add PreferBlockAccess enum to all evaluators
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h141
1 files changed, 141 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
index 877603421..4a3e1ac17 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h
@@ -382,6 +382,147 @@ class TensorBlockWriter : public TensorBlockIO<Scalar, StorageIndex, NumDims,
};
/**
+ * \class TensorBlockCwiseUnaryOp
+ * \ingroup CXX11_Tensor_Module
+ *
+ * \brief Carries out a cwise unary op on a number of coefficients.
+ *
+ * Run() reads num_coeff strided coefficients from the input array, applies
+ * the unary functor to them, and writes the results to the strided output
+ * array.
+ */
+struct TensorBlockCwiseUnaryOp {
+ template <typename StorageIndex, typename UnaryFunctor,
+ typename OutputScalar, typename InputScalar>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Run(
+ const UnaryFunctor& functor, const StorageIndex num_coeff,  // op to apply; number of coefficients to process
+ const StorageIndex output_index, const StorageIndex output_stride,  // offset of the first output coeff; step between output coeffs
+ OutputScalar* output_data, const StorageIndex input_index,  // output buffer; offset of the first input coeff
+ const StorageIndex input_stride, const InputScalar* input_data) {  // step between input coeffs; input buffer
+ typedef const Eigen::Array<InputScalar, Dynamic, 1> Input;  // read-only, dynamically sized column of input coeffs
+ typedef Eigen::Array<OutputScalar, Dynamic, 1> Output;  // writable, dynamically sized column of output coeffs
+
+ typedef Eigen::Map<Input, 0, InnerStride<>> InputMap;  // view over existing memory; the inner stride is supplied at run time
+ typedef Eigen::Map<Output, 0, InnerStride<>> OutputMap;  // same for the output; option 0 means no pointer alignment is assumed
+
+ const InputScalar* input_base = &input_data[input_index];  // first coefficient to read
+ OutputScalar* output_base = &output_data[output_index];  // first coefficient to write
+
+ const InputMap input(input_base, num_coeff, InnerStride<>(input_stride));  // num_coeff inputs, input_stride elements apart
+ OutputMap output(output_base, num_coeff, InnerStride<>(output_stride));  // num_coeff outputs, output_stride elements apart
+
+ output = Eigen::CwiseUnaryOp<UnaryFunctor, InputMap>(input, functor);  // assigning the unary expression to the map stores the results in output_data
+ }
+};
+
+/**
+ * \class TensorBlockCwiseUnaryIO
+ * \ingroup CXX11_Tensor_Module
+ *
+ * \brief Tensor block IO class for carrying out cwise unary ops.
+ *
+ * Run() walks a whole block and applies the unary op one strided inner run at a time.
+ */
+template <typename UnaryFunctor, typename StorageIndex, typename OutputScalar,
+ int NumDims, int Layout>
+struct TensorBlockCwiseUnaryIO {
+ typedef typename internal::TensorBlock<OutputScalar, StorageIndex, NumDims,
+ Layout>::Dimensions Dimensions;
+
+ struct BlockIteratorState {  // bookkeeping for one dimension that is iterated outside the inner run
+ StorageIndex output_stride, output_span;  // output step per increment; stride * (size - 1), subtracted when this dim wraps
+ StorageIndex input_stride, input_span;  // the same two quantities for the input
+ StorageIndex size, count;  // extent of this dim in the block; current position within it
+ };
+
+ template <typename InputScalar>
+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Run(
+ const UnaryFunctor& functor, const Dimensions& block_sizes,  // op to apply; per-dimension extents of the block
+ const Dimensions& block_strides, OutputScalar* output_data,  // per-dimension strides used for the output; output buffer
+ const array<StorageIndex, NumDims>& input_strides,  // per-dimension strides used for the input
+ const InputScalar* input_data) {  // input buffer
+ // Find the innermost dimension whose size is not 1. This is the effective
+ // inner dim. If all dimensions are of size 1, fall back to using the actual
+ // innermost dim (index 0 in inner-to-outer order) to avoid out-of-bound access.
+ int num_size_one_inner_dims = 0;  // leading size-1 dims skipped; the merge loop below also adds the dims it absorbs
+ for (int i = 0; i < NumDims; ++i) {
+ const int dim = cond<Layout>()(i, NumDims - i - 1);  // i-th dim in inner-to-outer order; cond<> is defined elsewhere, presumably selecting by ColMajor/RowMajor - confirm
+ if (block_sizes[dim] != 1) {
+ num_size_one_inner_dims = i;
+ break;
+ }
+ }
+ // Calculate the inner dim and the length of one inner run.
+ const int inner_dim =
+ NumDims == 0 ? 1  // placeholder only: every later use of inner_dim as an index is guarded by NumDims == 0
+ : cond<Layout>()(num_size_one_inner_dims,
+ NumDims - num_size_one_inner_dims - 1);
+ StorageIndex inner_dim_size = NumDims == 0 ? 1 : block_sizes[inner_dim];  // coefficients handled per TensorBlockCwiseUnaryOp::Run() call
+ for (int i = num_size_one_inner_dims + 1; i < NumDims; ++i) {  // try the dims just outside the inner dim, in order
+ const int dim = cond<Layout>()(i, NumDims - i - 1);
+ // Merge multiple inner dims into one for larger inner dim size (i.e.
+ // fewer calls to TensorBlockCwiseUnaryOp::Run()).
+ if (inner_dim_size == block_strides[dim] &&  // the run so far must be exactly as long as this dim's output stride
+ block_strides[dim] == input_strides[dim]) {  // and the input must use the same stride for this dim
+ inner_dim_size *= block_sizes[dim];
+ ++num_size_one_inner_dims;  // this dim now belongs to the inner run, so the iterator setup below starts after it
+ } else {
+ break;  // stop at the first dim that cannot be merged
+ }
+ }
+
+ StorageIndex output_index = 0, input_index = 0;  // offsets of the current inner run in output_data / input_data
+
+ const StorageIndex output_stride =
+ NumDims == 0 ? 1 : block_strides[inner_dim];  // step between consecutive output coeffs within a run
+ const StorageIndex input_stride =
+ NumDims == 0 ? 1 : input_strides[inner_dim];  // step between consecutive input coeffs within a run
+
+ const int at_least_1_dim = NumDims <= 1 ? 1 : NumDims - 1;  // keeps the array below non-empty when NumDims <= 1
+ array<BlockIteratorState, at_least_1_dim> block_iter_state;
+
+ // Initialize block iterator state for the dims outside the inner run. Squeeze away any dimension of size 1.
+ int num_squeezed_dims = 0;  // despite the name, this counts the iterator states that are kept (filled in)
+ for (int i = num_size_one_inner_dims; i < NumDims - 1; ++i) {
+ const int dim = cond<Layout>()(i + 1, NumDims - i - 2);  // the (i + 1)-th dim in inner-to-outer order
+ const StorageIndex size = block_sizes[dim];
+ if (size == 1) {
+ continue;  // a size-1 dim never advances, so it gets no iterator state
+ }
+ BlockIteratorState& state = block_iter_state[num_squeezed_dims];
+ state.output_stride = block_strides[dim];
+ state.input_stride = input_strides[dim];
+ state.size = size;
+ state.output_span = state.output_stride * (size - 1);  // total output offset accumulated after size - 1 steps
+ state.input_span = state.input_stride * (size - 1);  // total input offset accumulated after size - 1 steps
+ state.count = 0;
+ ++num_squeezed_dims;
+ }
+
+ // Compute cwise unary op, one inner run per iteration.
+ const StorageIndex block_total_size =
+ NumDims == 0 ? 1 : block_sizes.TotalSize();
+ for (StorageIndex i = 0; i < block_total_size; i += inner_dim_size) {  // each pass covers inner_dim_size coefficients
+ TensorBlockCwiseUnaryOp::Run(functor, inner_dim_size, output_index,
+ output_stride, output_data, input_index,
+ input_stride, input_data);
+ // Update index: advance the innermost iterated dim, carrying into outer dims when one wraps.
+ for (int j = 0; j < num_squeezed_dims; ++j) {
+ auto& state = block_iter_state[j];
+ if (++state.count < state.size) {
+ output_index += state.output_stride;
+ input_index += state.input_stride;
+ break;  // no wrap, so outer dims keep their positions
+ }
+ state.count = 0;  // this dim wrapped: reset it and let the next j advance the outer dim
+ output_index -= state.output_span;  // rewind the offset accumulated along this dim
+ input_index -= state.input_span;
+ }
+ }
+ }
+};
+
+/**
* \class TensorBlockCwiseBinaryOp
* \ingroup CXX11_Tensor_Module
*