aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/sparse/sparse_tensor_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/sparse/sparse_tensor_test.cc')
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor_test.cc46
1 files changed, 45 insertions, 1 deletions
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc
index 5edd6cb1d8..efdd97fd3d 100644
--- a/tensorflow/core/util/sparse/sparse_tensor_test.cc
+++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <string>
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -26,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace sparse {
@@ -612,6 +612,50 @@ TEST(SparseTensorTest, Split) {
EXPECT_EQ(st_list[1].indices().matrix<int64>()(0, 1), 0);
}
+TEST(SparseTensorTest, Slice) {
+ const int N = 4;
+ const int DIM = 2;
+
+ Tensor ids(DT_INT64, TensorShape({N, DIM}));
+ Tensor vals(DT_INT64, TensorShape({N}));
+
+ ids.matrix<int64>()(0, 0) = 0;
+ ids.matrix<int64>()(0, 1) = 0;
+ ids.matrix<int64>()(1, 0) = 1;
+ ids.matrix<int64>()(1, 1) = 1;
+ ids.matrix<int64>()(2, 0) = 1;
+ ids.matrix<int64>()(2, 1) = 2;
+ ids.matrix<int64>()(3, 0) = 3;
+ ids.matrix<int64>()(3, 1) = 0;
+
+ vals.vec<int64>()(0) = 1;
+ vals.vec<int64>()(1) = 2;
+ vals.vec<int64>()(2) = 3;
+ vals.vec<int64>()(3) = 4;
+
+ SparseTensor st(ids, vals, TensorShape({4, 3}));
+
+ std::vector<int64> start(2, 0);
+ std::vector<int64> size(2);
+ size[0] = 2;
+ size[1] = 3;
+
+ SparseTensor slice = SparseTensor::Slice<int64>(st, start, size);
+
+ EXPECT_EQ(TensorShape(slice.shape()), TensorShape({2, 3}));
+ EXPECT_EQ(slice.values().NumElements(), 3);
+ EXPECT_EQ(slice.values().vec<int64>()(0), 1);
+ EXPECT_EQ(slice.values().vec<int64>()(1), 2);
+ EXPECT_EQ(slice.values().vec<int64>()(2), 3);
+ EXPECT_EQ(slice.indices().NumElements(), 6);
+ EXPECT_EQ(slice.indices().matrix<int64>()(0, 0), 0);
+ EXPECT_EQ(slice.indices().matrix<int64>()(0, 1), 0);
+ EXPECT_EQ(slice.indices().matrix<int64>()(1, 0), 1);
+ EXPECT_EQ(slice.indices().matrix<int64>()(1, 1), 1);
+ EXPECT_EQ(slice.indices().matrix<int64>()(2, 0), 1);
+ EXPECT_EQ(slice.indices().matrix<int64>()(2, 1), 2);
+}
+
TEST(SparseTensorTest, Dim0SparseTensorToDenseTensor) {
Tensor ix(DT_INT64, TensorShape({1, 0}));
Tensor vals(DT_INT32, TensorShape({1}));