diff options
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 45 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor.h | 7 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_test.cc | 40 | ||||
-rw-r--r-- | tensorflow/core/kernels/gather_op_test.cc | 13 |
4 files changed, 85 insertions, 20 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index c928eccec3..6d989fd1d6 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -714,34 +714,43 @@ void Tensor::FillDescription(TensorDescription* description) const { } } -gtl::InlinedVector<int64, 5> Tensor::ComputeFlatInnerDims( +gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims( int64 num_out_dims) const { - gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0); + if (num_out_dims == dims()) { + return shape_.dim_sizes(); + } + gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0); const int64 num_elements = NumElements(); - if (num_elements != 0) { - int64 prod_out_dims = 1; - for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) { - const int64 in_dim = out_dim + (dims() - num_out_dims); - out_dims[out_dim] = - (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim); - prod_out_dims *= out_dims[out_dim]; - } + int64 prod_out_dims = 1; + for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) { + const int64 in_dim = out_dim + (dims() - num_out_dims); + out_dims[out_dim] = (in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim); + prod_out_dims *= out_dims[out_dim]; + } + if (prod_out_dims != 0) { out_dims[0] = num_elements / prod_out_dims; + } else { + out_dims[0] = 0; } return out_dims; } -gtl::InlinedVector<int64, 5> Tensor::ComputeFlatOuterDims( +gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims( int64 num_out_dims) const { - gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0); + if (num_out_dims == dims()) { + return shape_.dim_sizes(); + } + gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0); const int64 num_elements = NumElements(); - if (num_elements != 0) { - int64 prod_out_dims = 1; - for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) { - out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim); - prod_out_dims *= out_dims[out_dim]; - } + int64 prod_out_dims = 1; + for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) { + out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim); + prod_out_dims *= out_dims[out_dim]; + } + if (prod_out_dims != 0) { out_dims[num_out_dims - 1] = num_elements / prod_out_dims; + } else { + out_dims[num_out_dims - 1] = 0; } return out_dims; } diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 5abc9c9f52..32d550bbce 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -361,8 +361,11 @@ class Tensor { void FillDimsAndValidateCompatibleShape( gtl::ArraySlice<int64> new_sizes, Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const; - gtl::InlinedVector<int64, 5> ComputeFlatInnerDims(int64 num_out_dims) const; - gtl::InlinedVector<int64, 5> ComputeFlatOuterDims(int64 num_out_dims) const; + + // TODO(rmlarsen): These shouldn't hardcode '4' so that it lines up with + // TensorShape's InlineVector. + gtl::InlinedVector<int64, 4> ComputeFlatInnerDims(int64 num_out_dims) const; + gtl::InlinedVector<int64, 4> ComputeFlatOuterDims(int64 num_out_dims) const; TensorShape shape_; TensorBuffer* buf_; diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index ecc0467103..a26a392dfc 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -267,6 +267,46 @@ TEST(Tensor_Float, Reshape) { EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f); EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f); } + + Tensor zero_t(DT_FLOAT, TensorShape({3, 0, 2, 0, 5})); + { + auto flat_outer_dims = zero_t.flat_outer_dims<float>(); + EXPECT_EQ(3, flat_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_outer_dims.dimension(1)); + } + { + auto flat_outer_dims = zero_t.flat_outer_dims<float, 3>(); + EXPECT_EQ(3, flat_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_outer_dims.dimension(1)); + EXPECT_EQ(0, flat_outer_dims.dimension(2)); + } + { + auto flat_outer_dims = zero_t.flat_outer_dims<float, 5>(); + EXPECT_EQ(3, flat_outer_dims.dimension(0)); + EXPECT_EQ(0, flat_outer_dims.dimension(1)); + EXPECT_EQ(2, flat_outer_dims.dimension(2)); + EXPECT_EQ(0, flat_outer_dims.dimension(3)); + EXPECT_EQ(5, flat_outer_dims.dimension(4)); + } + { + auto flat_inner_dims = zero_t.flat_inner_dims<float>(); + EXPECT_EQ(0, flat_inner_dims.dimension(0)); + EXPECT_EQ(5, flat_inner_dims.dimension(1)); + } + { + auto flat_inner_dims = zero_t.flat_inner_dims<float, 3>(); + EXPECT_EQ(0, flat_inner_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_dims.dimension(1)); + EXPECT_EQ(5, flat_inner_dims.dimension(2)); + } + { + auto flat_inner_dims = zero_t.flat_inner_dims<float, 5>(); + EXPECT_EQ(3, flat_inner_dims.dimension(0)); + EXPECT_EQ(0, flat_inner_dims.dimension(1)); + EXPECT_EQ(2, flat_inner_dims.dimension(2)); + EXPECT_EQ(0, flat_inner_dims.dimension(3)); + EXPECT_EQ(5, flat_inner_dims.dimension(4)); + } } TEST(Tensor_Scalar, Basics) { diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index bbf2683ff1..062e3863d9 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -78,6 +78,19 @@ TEST_F(GatherOpTest, Simple_TwoD32) { test::ExpectTensorEqual<float>(expected, *GetOutput(0)); } +TEST_F(GatherOpTest, ZeroSize_TwoD32) { + MakeOp(DT_INT32); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 0}), {}); + AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output. + Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 0})); + test::ExpectTensorEqual<float>(expected, *GetOutput(0)); +} + TEST_F(GatherOpTest, Simple_TwoD64) { MakeOp(DT_INT64); |