#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ #define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { // Helper to define Tensor types given that the scalar is of type T. template struct TTypes { // Rank- tensor of scalar type T. typedef Eigen::TensorMap, Eigen::Aligned> Tensor; typedef Eigen::TensorMap, Eigen::Aligned> ConstTensor; // Unaligned Rank- tensor of scalar type T. typedef Eigen::TensorMap > UnalignedTensor; typedef Eigen::TensorMap > UnalignedConstTensor; typedef Eigen::TensorMap, Eigen::Aligned> Tensor32Bit; // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. typedef Eigen::TensorMap< Eigen::TensorFixedSize, Eigen::RowMajor>, Eigen::Aligned> Scalar; typedef Eigen::TensorMap< Eigen::TensorFixedSize, Eigen::RowMajor>, Eigen::Aligned> ConstScalar; // Unaligned Scalar tensor of scalar type T. typedef Eigen::TensorMap, Eigen::RowMajor> > UnalignedScalar; typedef Eigen::TensorMap, Eigen::RowMajor> > UnalignedConstScalar; // Rank-1 tensor (vector) of scalar type T. typedef Eigen::TensorMap, Eigen::Aligned> Flat; typedef Eigen::TensorMap, Eigen::Aligned> ConstFlat; typedef Eigen::TensorMap, Eigen::Aligned> Vec; typedef Eigen::TensorMap, Eigen::Aligned> ConstVec; // Unaligned Rank-1 tensor (vector) of scalar type T. typedef Eigen::TensorMap > UnalignedFlat; typedef Eigen::TensorMap > UnalignedConstFlat; typedef Eigen::TensorMap > UnalignedVec; typedef Eigen::TensorMap > UnalignedConstVec; // Rank-2 tensor (matrix) of scalar type T. typedef Eigen::TensorMap, Eigen::Aligned> Matrix; typedef Eigen::TensorMap, Eigen::Aligned> ConstMatrix; // Unaligned Rank-2 tensor (matrix) of scalar type T. typedef Eigen::TensorMap > UnalignedMatrix; typedef Eigen::TensorMap > UnalignedConstMatrix; }; typedef typename TTypes::Tensor32Bit::Index Index32; template Eigen::DSizes To32BitDims(const DSizes& in) { Eigen::DSizes out; for (int i = 0; i < DSizes::count; ++i) { out[i] = in[i]; } return out; } template typename TTypes::Tensor32Bit To32Bit(TensorType in) { typedef typename TTypes::Tensor32Bit RetType; return RetType(in.data(), To32BitDims(in.dimensions())); } } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_