diff options
Diffstat (limited to 'tensorflow/core/public/tensor_shape.h')
-rw-r--r-- | tensorflow/core/public/tensor_shape.h | 239 |
1 files changed, 239 insertions, 0 deletions
diff --git a/tensorflow/core/public/tensor_shape.h b/tensorflow/core/public/tensor_shape.h new file mode 100644 index 0000000000..a889b8b17d --- /dev/null +++ b/tensorflow/core/public/tensor_shape.h @@ -0,0 +1,239 @@ +#ifndef TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_ +#define TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_ + +#include <string> + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +class TensorShapeIter; // Declared below + +/// Manages the dimensions of a Tensor and their sizes. +class TensorShape { + public: + /// \brief Construct a TensorShape from the provided sizes.. + /// REQUIRES: dim_sizes[i] >= 0 + explicit TensorShape(gtl::ArraySlice<int64> dim_sizes); + TensorShape(std::initializer_list<int64> dim_sizes) + : TensorShape(gtl::ArraySlice<int64>(dim_sizes)) {} + + /// REQUIRES: IsValid(proto) + explicit TensorShape(const TensorShapeProto& proto); + + /// Create a tensor shape with no dimensions and one element, which you can + /// then call AddDim() on. + TensorShape(); + + /// Returns true iff "proto" is a valid tensor shape. + static bool IsValid(const TensorShapeProto& proto); + + /// Clear a tensor shape + void Clear(); + + /// \brief Add a dimension to the end ("inner-most"). + /// REQUIRES: size >= 0 + void AddDim(int64 size); + + /// Appends all the dimensions from shape. + void AppendShape(const TensorShape& shape); + + /// \brief Insert a dimension somewhere in the TensorShape. + /// REQUIRES: "0 <= d <= dims()" + /// REQUIRES: size >= 0 + void InsertDim(int d, int64 size); + + /// \brief Modifies the size of the dimension 'd' to be 'size' + /// REQUIRES: "0 <= d < dims()" + /// REQUIRES: size >= 0 + void set_dim(int d, int64 size); + + /// \brief Removes dimension 'd' from the TensorShape. + /// REQUIRES: "0 <= d < dims()" + void RemoveDim(int d); + + /// Return the number of dimensions in the tensor. + int dims() const { return dim_sizes_.size(); } + + /// \brief Returns the number of elements in dimension "d". + /// REQUIRES: "0 <= d < dims()" + // TODO(mdevin): Rename to dimension() to match Eigen::Tensor::dimension()? + int64 dim_size(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return dim_sizes_[d]; + } + + /// Returns sizes of all dimensions. + gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; } + + /// \brief Returns the number of elements in the tensor. + /// + /// We use int64 and + /// not size_t to be compatible with Eigen::Tensor which uses ptr_fi + int64 num_elements() const { return num_elements_; } + + /// Returns true if *this and b have the same sizes. Ignores dimension names. + bool IsSameSize(const TensorShape& b) const; + bool operator==(const TensorShape& b) const { return IsSameSize(b); } + + /// Fill *proto from *this. + void AsProto(TensorShapeProto* proto) const; + + /// Fill *dsizes from *this. + template <int NDIMS> + Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const; + + /// Same as AsEigenDSizes() but allows for NDIMS > dims() -- in which case we + /// pad the rest of the sizes with 1. + template <int NDIMS> + Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const; + + /// For iterating through the dimensions. + TensorShapeIter begin() const; + TensorShapeIter end() const; + + /// For error messages. + string DebugString() const; + // TODO(vrv): Remove this, this is the same as DebugString(). + string ShortDebugString() const; + + private: + /// Recalculates the dimensions of this tensor after they are modified. + void recompute_dims(); + + // TODO(josh11b): Maybe use something from the Eigen Tensor library + /// for the sizes. + gtl::InlinedVector<int64, 4> dim_sizes_; + + /// total number of elements (avoids recomputing it each time). + int64 num_elements_; +}; + +struct TensorShapeDim { + explicit TensorShapeDim(int64 s) : size(s) {} + int size; +}; + +class TensorShapeIter { + public: + TensorShapeIter(const TensorShape* shape, int d) : shape_(shape), d_(d) {} + bool operator==(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ == rhs.d_; + } + bool operator!=(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ != rhs.d_; + } + void operator++() { ++d_; } + TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } + + private: + const TensorShape* shape_; + int d_; +}; + +// In some places, allow shape (1,) to be treated as a scalar and shape () to be +// treated as a vector. This flag is for temporary backwards compatibility +// only, and will be changed to strict within Google around November 15, 2015. +#if defined(PLATFORM_GOOGLE) +// TODO(irving): Become strict on November 15, 2015. +static const bool kAllowLegacyScalars = true; +#else +// For open source (outside Google), we are strict. +static const bool kAllowLegacyScalars = false; +#endif + +/// \brief Static helper routines for TensorShape. Includes a few common +/// predicates on a tensor shape. +class TensorShapeUtils { + public: + static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; } + + static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; } + + // Allow either scalars or (if allowing legacy scalars) shape (1,). + static bool IsLegacyScalar(const TensorShape& shape) { + return shape.dims() == 0 || + (kAllowLegacyScalars && shape.dims() == 1 && shape.dim_size(0) == 1); + } + + // Allow rank 1 or (if allowing legacy scalars) rank 0. + static bool IsLegacyVector(const TensorShape& shape) { + return shape.dims() == 1 || (kAllowLegacyScalars && shape.dims() == 0); + } + + static bool IsVectorOrHigher(const TensorShape& shape) { + return shape.dims() >= 1; + } + + static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; } + + static bool IsMatrixOrHigher(const TensorShape& shape) { + return shape.dims() >= 2; + } + + /// \brief Returns a TensorShape whose dimensions are dims[0], dims[1], ..., + /// dims[n-1]. + template <typename T> + static TensorShape MakeShape(const T* dims, int n) { + TensorShape shape; + for (int i = 0; i < n; ++i) shape.AddDim(dims[i]); + return shape; + } + + static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) { + string result = "["; + bool first = true; + for (const TensorShape& shape : shapes) { + strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString()); + first = false; + } + strings::StrAppend(&result, "]"); + return result; + } + + static bool StartsWith(const TensorShape& shape0, const TensorShape& shape1); +}; + +// TODO(josh11b): Add TensorStrides once we support strides +// struct TensorStrides { +// gtl::InlinedVector<int, 4> strides_; +// }; + +// ---------------------------------------------------------------------------- +// Template method implementation details below +// ---------------------------------------------------------------------------- + +template <int NDIMS> +Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const { + CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS + << " for a tensor of " << dims() << " dimensions"; + return AsEigenDSizesWithPadding<NDIMS>(); +} + +template <int NDIMS> +Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding() + const { + CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS + << " for a tensor of " << dims() << " dimensions"; + Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; + for (int d = 0; d < dims(); d++) { + dsizes[d] = dim_size(d); + } + for (int d = dims(); d < NDIMS; d++) { + dsizes[d] = 1; + } + return dsizes; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_ |