aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_types.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor_types.h')
-rw-r--r--tensorflow/core/framework/tensor_types.h92
1 files changed, 92 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h
new file mode 100644
index 0000000000..077d86d442
--- /dev/null
+++ b/tensorflow/core/framework/tensor_types.h
@@ -0,0 +1,92 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// Helper to define Tensor types given that the scalar is of type T.
+template <typename T, int NDIMS = 1>
+struct TTypes {
+ // Rank-<NDIMS> tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor>,
+ Eigen::Aligned> Tensor;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor>,
+ Eigen::Aligned> ConstTensor;
+
+ // Unaligned Rank-<NDIMS> tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor> >
+ UnalignedTensor;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor> >
+ UnalignedConstTensor;
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
+ Eigen::Aligned> Tensor32Bit;
+
+ // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
+ typedef Eigen::TensorMap<
+ Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor>,
+ Eigen::Aligned> Scalar;
+ typedef Eigen::TensorMap<
+ Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor>,
+ Eigen::Aligned> ConstScalar;
+
+ // Unaligned Scalar tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::TensorFixedSize<
+ T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedScalar;
+ typedef Eigen::TensorMap<Eigen::TensorFixedSize<
+ const T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedConstScalar;
+
+ // Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
+ Flat;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+ Eigen::Aligned> ConstFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
+ Vec;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+ Eigen::Aligned> ConstVec;
+
+ // Unaligned Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
+ UnalignedConstFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedVec;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
+ UnalignedConstVec;
+
+ // Rank-2 tensor (matrix) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>
+ Matrix;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Aligned> ConstMatrix;
+
+ // Unaligned Rank-2 tensor (matrix) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor> >
+ UnalignedMatrix;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor> >
+ UnalignedConstMatrix;
+};
+
+typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
+
+template <typename DSizes>
+Eigen::DSizes<Index32, DSizes::count> To32BitDims(const DSizes& in) {
+ Eigen::DSizes<Index32, DSizes::count> out;
+ for (int i = 0; i < DSizes::count; ++i) {
+ out[i] = in[i];
+ }
+ return out;
+}
+
+template <typename TensorType>
+typename TTypes<typename TensorType::Scalar,
+ TensorType::NumIndices>::Tensor32Bit
+To32Bit(TensorType in) {
+ typedef typename TTypes<typename TensorType::Scalar,
+ TensorType::NumIndices>::Tensor32Bit RetType;
+ return RetType(in.data(), To32BitDims(in.dimensions()));
+}
+
+} // namespace tensorflow
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_