aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r--tensorflow/core/framework/tensor.cc43
1 files changed, 14 insertions, 29 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index ecb9810d83..d049da1c9d 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -902,42 +902,27 @@ void Tensor::FillDescription(TensorDescription* description) const {
}
gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
- int64 num_out_dims) const {
- if (num_out_dims == dims()) {
- return shape_.dim_sizes();
- }
+ gtl::ArraySlice<int64> orig, int64 num_out_dims) {
gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
- const int64 num_elements = NumElements();
- int64 prod_out_dims = 1;
- for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) {
- const int64 in_dim = out_dim + (dims() - num_out_dims);
- out_dims[out_dim] = (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim);
- prod_out_dims *= out_dims[out_dim];
- }
- if (prod_out_dims != 0) {
- out_dims[0] = num_elements / prod_out_dims;
- } else {
- out_dims[0] = 0;
+ int64 offset = orig.size() - num_out_dims;
+ for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
+ const int64 in_dim = out_dim + offset;
+ out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
+ }
+ for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
+ out_dims[0] *= orig[in_dim];
}
return out_dims;
}
gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
- int64 num_out_dims) const {
- if (num_out_dims == dims()) {
- return shape_.dim_sizes();
- }
+ gtl::ArraySlice<int64> orig, int64 num_out_dims) {
gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
- const int64 num_elements = NumElements();
- int64 prod_out_dims = 1;
- for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) {
- out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim);
- prod_out_dims *= out_dims[out_dim];
- }
- if (prod_out_dims != 0) {
- out_dims[num_out_dims - 1] = num_elements / prod_out_dims;
- } else {
- out_dims[num_out_dims - 1] = 0;
+ for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
+ out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
+ }
+ for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
+ out_dims[num_out_dims - 1] *= orig[in_dim];
}
return out_dims;
}