diff options
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 43 |
1 files changed, 14 insertions, 29 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index ecb9810d83..d049da1c9d 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -902,42 +902,27 @@ void Tensor::FillDescription(TensorDescription* description) const { } gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims( - int64 num_out_dims) const { - if (num_out_dims == dims()) { - return shape_.dim_sizes(); - } + gtl::ArraySlice<int64> orig, int64 num_out_dims) { gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0); - const int64 num_elements = NumElements(); - int64 prod_out_dims = 1; - for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) { - const int64 in_dim = out_dim + (dims() - num_out_dims); - out_dims[out_dim] = (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim); - prod_out_dims *= out_dims[out_dim]; - } - if (prod_out_dims != 0) { - out_dims[0] = num_elements / prod_out_dims; - } else { - out_dims[0] = 0; + int64 offset = orig.size() - num_out_dims; + for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) { + const int64 in_dim = out_dim + offset; + out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim]; + } + for (int64 in_dim = 0; in_dim < offset; ++in_dim) { + out_dims[0] *= orig[in_dim]; } return out_dims; } gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims( - int64 num_out_dims) const { - if (num_out_dims == dims()) { - return shape_.dim_sizes(); - } + gtl::ArraySlice<int64> orig, int64 num_out_dims) { gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0); - const int64 num_elements = NumElements(); - int64 prod_out_dims = 1; - for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) { - out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim); - prod_out_dims *= out_dims[out_dim]; - } - if (prod_out_dims != 0) { - out_dims[num_out_dims - 1] = num_elements / prod_out_dims; - } else { - out_dims[num_out_dims - 1] = 0; + for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) { + out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim]; + } + for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) { + out_dims[num_out_dims - 1] *= orig[in_dim]; } return out_dims; } |