aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/tensor.cc45
-rw-r--r--tensorflow/core/framework/tensor.h7
-rw-r--r--tensorflow/core/framework/tensor_test.cc40
-rw-r--r--tensorflow/core/kernels/gather_op_test.cc13
4 files changed, 85 insertions, 20 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index c928eccec3..6d989fd1d6 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -714,34 +714,43 @@ void Tensor::FillDescription(TensorDescription* description) const {
}
}
-gtl::InlinedVector<int64, 5> Tensor::ComputeFlatInnerDims(
+gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
int64 num_out_dims) const {
- gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
+ if (num_out_dims == dims()) {
+ return shape_.dim_sizes();
+ }
+ gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
const int64 num_elements = NumElements();
- if (num_elements != 0) {
- int64 prod_out_dims = 1;
- for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) {
- const int64 in_dim = out_dim + (dims() - num_out_dims);
- out_dims[out_dim] =
- (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim);
- prod_out_dims *= out_dims[out_dim];
- }
+ int64 prod_out_dims = 1;
+ for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) {
+ const int64 in_dim = out_dim + (dims() - num_out_dims);
+ out_dims[out_dim] = (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim);
+ prod_out_dims *= out_dims[out_dim];
+ }
+ if (prod_out_dims != 0) {
out_dims[0] = num_elements / prod_out_dims;
+ } else {
+ out_dims[0] = 0;
}
return out_dims;
}
-gtl::InlinedVector<int64, 5> Tensor::ComputeFlatOuterDims(
+gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
int64 num_out_dims) const {
- gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
+ if (num_out_dims == dims()) {
+ return shape_.dim_sizes();
+ }
+ gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
const int64 num_elements = NumElements();
- if (num_elements != 0) {
- int64 prod_out_dims = 1;
- for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) {
- out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim);
- prod_out_dims *= out_dims[out_dim];
- }
+ int64 prod_out_dims = 1;
+ for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) {
+ out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim);
+ prod_out_dims *= out_dims[out_dim];
+ }
+ if (prod_out_dims != 0) {
out_dims[num_out_dims - 1] = num_elements / prod_out_dims;
+ } else {
+ out_dims[num_out_dims - 1] = 0;
}
return out_dims;
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 5abc9c9f52..32d550bbce 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -361,8 +361,11 @@ class Tensor {
void FillDimsAndValidateCompatibleShape(
gtl::ArraySlice<int64> new_sizes,
Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
- gtl::InlinedVector<int64, 5> ComputeFlatInnerDims(int64 num_out_dims) const;
- gtl::InlinedVector<int64, 5> ComputeFlatOuterDims(int64 num_out_dims) const;
+
+ // TODO(rmlarsen): These shouldn't hardcode '4' so that it lines up with
+ // TensorShape's InlineVector.
+ gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(int64 num_out_dims) const;
+ gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(int64 num_out_dims) const;
TensorShape shape_;
TensorBuffer* buf_;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index ecc0467103..a26a392dfc 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -267,6 +267,46 @@ TEST(Tensor_Float, Reshape) {
EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f);
EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f);
}
+
+ Tensor zero_t(DT_FLOAT, TensorShape({3, 0, 2, 0, 5}));
+ {
+ auto flat_outer_dims = zero_t.flat_outer_dims<float>();
+ EXPECT_EQ(3, flat_outer_dims.dimension(0));
+ EXPECT_EQ(0, flat_outer_dims.dimension(1));
+ }
+ {
+ auto flat_outer_dims = zero_t.flat_outer_dims<float, 3>();
+ EXPECT_EQ(3, flat_outer_dims.dimension(0));
+ EXPECT_EQ(0, flat_outer_dims.dimension(1));
+ EXPECT_EQ(0, flat_outer_dims.dimension(2));
+ }
+ {
+ auto flat_outer_dims = zero_t.flat_outer_dims<float, 5>();
+ EXPECT_EQ(3, flat_outer_dims.dimension(0));
+ EXPECT_EQ(0, flat_outer_dims.dimension(1));
+ EXPECT_EQ(2, flat_outer_dims.dimension(2));
+ EXPECT_EQ(0, flat_outer_dims.dimension(3));
+ EXPECT_EQ(5, flat_outer_dims.dimension(4));
+ }
+ {
+ auto flat_inner_dims = zero_t.flat_inner_dims<float>();
+ EXPECT_EQ(0, flat_inner_dims.dimension(0));
+ EXPECT_EQ(5, flat_inner_dims.dimension(1));
+ }
+ {
+ auto flat_inner_dims = zero_t.flat_inner_dims<float, 3>();
+ EXPECT_EQ(0, flat_inner_dims.dimension(0));
+ EXPECT_EQ(0, flat_inner_dims.dimension(1));
+ EXPECT_EQ(5, flat_inner_dims.dimension(2));
+ }
+ {
+ auto flat_inner_dims = zero_t.flat_inner_dims<float, 5>();
+ EXPECT_EQ(3, flat_inner_dims.dimension(0));
+ EXPECT_EQ(0, flat_inner_dims.dimension(1));
+ EXPECT_EQ(2, flat_inner_dims.dimension(2));
+ EXPECT_EQ(0, flat_inner_dims.dimension(3));
+ EXPECT_EQ(5, flat_inner_dims.dimension(4));
+ }
}
TEST(Tensor_Scalar, Basics) {
diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc
index bbf2683ff1..062e3863d9 100644
--- a/tensorflow/core/kernels/gather_op_test.cc
+++ b/tensorflow/core/kernels/gather_op_test.cc
@@ -78,6 +78,19 @@ TEST_F(GatherOpTest, Simple_TwoD32) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
+TEST_F(GatherOpTest, ZeroSize_TwoD32) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 0}), {});
+ AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2});
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 0}));
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
TEST_F(GatherOpTest, Simple_TwoD64) {
MakeOp(DT_INT64);