aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor.h')
-rw-r--r--tensorflow/core/framework/tensor.h41
1 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 103da4c1b3..753548de84 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -304,6 +304,15 @@ class Tensor {
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::Tensor flat_outer_dims();
+ /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the
+ /// first 'begin' Tensor dimensions into the first dimension of the result and
+ /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last
+ /// dimension of the result. If 'begin' < 0 then the the |'begin'| leading
+ /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then
+ /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added.
+ template <typename T, size_t NDIMS = 3>
+ typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin);
+
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
@@ -386,6 +395,9 @@ class Tensor {
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;
+ template <typename T, size_t NDIMS = 3>
+ typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin) const;
+
/// Render the first `max_entries` values in `*this` into a string.
string SummarizeValue(int64 max_entries) const;
@@ -429,10 +441,11 @@ class Tensor {
gtl::ArraySlice<int64> new_sizes,
Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
- // TODO(rmlarsen): These shouldn't hardcode '4' so that it lines up with
// TensorShape's InlineVector.
- gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(int64 num_out_dims) const;
- gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(int64 num_out_dims) const;
+ static gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(
+ gtl::ArraySlice<int64> orig, int64 num_out_dims);
+ static gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(
+ gtl::ArraySlice<int64> orig, int64 num_out_dims);
TensorShape shape_;
TensorBuffer* buf_;
@@ -638,22 +651,36 @@ typename TTypes<T>::ConstScalar Tensor::scalar() const {
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() {
- return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
+ return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() {
- return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
+ return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64 begin) {
+ gtl::InlinedVector<int64,4> flat_outer = ComputeFlatOuterDims(
+ shape_.dim_sizes(), begin + NDIMS);
+ return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const {
- return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
+ return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS));
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const {
- return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
+ return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS));
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64 begin) const {
+ gtl::InlinedVector<int64,4> flat_outer = ComputeFlatOuterDims(
+ shape_.dim_sizes(), begin + NDIMS);
+ return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS));
}
inline Tensor::Tensor(const Tensor& other)