diff options
Diffstat (limited to 'tensorflow/core/framework/tensor.h')
-rw-r--r-- | tensorflow/core/framework/tensor.h | 41 |
1 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 103da4c1b3..753548de84 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -304,6 +304,15 @@ class Tensor { template <typename T, size_t NDIMS = 2> typename TTypes<T, NDIMS>::Tensor flat_outer_dims(); + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the + /// first 'begin' Tensor dimensions into the first dimension of the result and + /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last + /// dimension of the result. If 'begin' < 0 then the the |'begin'| leading + /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then + /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added. + template <typename T, size_t NDIMS = 3> + typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin); + template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes); @@ -386,6 +395,9 @@ class Tensor { template <typename T, size_t NDIMS = 2> typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const; + template <typename T, size_t NDIMS = 3> + typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin) const; + /// Render the first `max_entries` values in `*this` into a string. string SummarizeValue(int64 max_entries) const; @@ -429,10 +441,11 @@ class Tensor { gtl::ArraySlice<int64> new_sizes, Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const; - // TODO(rmlarsen): These shouldn't hardcode '4' so that it lines up with // TensorShape's InlineVector. - gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(int64 num_out_dims) const; - gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(int64 num_out_dims) const; + static gtl::InlinedVector<int64, 4> ComputeFlatInnerDims( + gtl::ArraySlice<int64> orig, int64 num_out_dims); + static gtl::InlinedVector<int64, 4> ComputeFlatOuterDims( + gtl::ArraySlice<int64> orig, int64 num_out_dims); TensorShape shape_; TensorBuffer* buf_; @@ -638,22 +651,36 @@ typename TTypes<T>::ConstScalar Tensor::scalar() const { template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() { - return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS)); + return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); } template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() { - return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS)); + return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template <typename T, size_t NDIMS> +typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64 begin) { + gtl::InlinedVector<int64,4> flat_outer = ComputeFlatOuterDims( + shape_.dim_sizes(), begin + NDIMS); + return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS)); } template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const { - return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS)); + return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); } template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const { - return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS)); + return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template <typename T, size_t NDIMS> +typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64 begin) const { + gtl::InlinedVector<int64,4> flat_outer = ComputeFlatOuterDims( + shape_.dim_sizes(), begin + NDIMS); + return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS)); } inline Tensor::Tensor(const Tensor& other) |