diff options
Diffstat (limited to 'tensorflow/core/util/sparse/sparse_tensor_test.cc')
-rw-r--r-- | tensorflow/core/util/sparse/sparse_tensor_test.cc | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc index 5edd6cb1d8..efdd97fd3d 100644 --- a/tensorflow/core/util/sparse/sparse_tensor_test.cc +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include <string> #include <vector> -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -26,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace sparse { @@ -612,6 +612,50 @@ TEST(SparseTensorTest, Split) { EXPECT_EQ(st_list[1].indices().matrix<int64>()(0, 1), 0); } +TEST(SparseTensorTest, Slice) { + const int N = 4; + const int DIM = 2; + + Tensor ids(DT_INT64, TensorShape({N, DIM})); + Tensor vals(DT_INT64, TensorShape({N})); + + ids.matrix<int64>()(0, 0) = 0; + ids.matrix<int64>()(0, 1) = 0; + ids.matrix<int64>()(1, 0) = 1; + ids.matrix<int64>()(1, 1) = 1; + ids.matrix<int64>()(2, 0) = 1; + ids.matrix<int64>()(2, 1) = 2; + ids.matrix<int64>()(3, 0) = 3; + ids.matrix<int64>()(3, 1) = 0; + + vals.vec<int64>()(0) = 1; + vals.vec<int64>()(1) = 2; + vals.vec<int64>()(2) = 3; + vals.vec<int64>()(3) = 4; + + SparseTensor st(ids, vals, TensorShape({4, 3})); + + std::vector<int64> start(2, 0); + std::vector<int64> size(2); + size[0] = 2; + size[1] = 3; + + SparseTensor slice = SparseTensor::Slice<int64>(st, start, size); + + EXPECT_EQ(TensorShape(slice.shape()), TensorShape({2, 3})); + EXPECT_EQ(slice.values().NumElements(), 3); + EXPECT_EQ(slice.values().vec<int64>()(0), 1); + EXPECT_EQ(slice.values().vec<int64>()(1), 2); + EXPECT_EQ(slice.values().vec<int64>()(2), 3); + EXPECT_EQ(slice.indices().NumElements(), 6); + EXPECT_EQ(slice.indices().matrix<int64>()(0, 0), 0); + EXPECT_EQ(slice.indices().matrix<int64>()(0, 1), 0); + EXPECT_EQ(slice.indices().matrix<int64>()(1, 0), 1); + EXPECT_EQ(slice.indices().matrix<int64>()(1, 1), 1); + EXPECT_EQ(slice.indices().matrix<int64>()(2, 0), 1); + EXPECT_EQ(slice.indices().matrix<int64>()(2, 1), 2); +} + TEST(SparseTensorTest, Dim0SparseTensorToDenseTensor) { Tensor ix(DT_INT64, TensorShape({1, 0})); Tensor vals(DT_INT32, TensorShape({1})); |