Diffstat (limited to 'tensorflow/core/util/sparse')
 tensorflow/core/util/sparse/README.md             | 222 ++++++
 tensorflow/core/util/sparse/dim_comparator.h      |  60 ++
 tensorflow/core/util/sparse/group_iterator.cc     |  49 ++
 tensorflow/core/util/sparse/group_iterator.h      | 120 +++
 tensorflow/core/util/sparse/sparse_tensor.h       | 353 ++++++++++
 tensorflow/core/util/sparse/sparse_tensor_test.cc | 467 +++++++++++++
 6 files changed, 1271 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/util/sparse/README.md b/tensorflow/core/util/sparse/README.md
new file mode 100644
index 0000000000..7b0799eb0e
--- /dev/null
+++ b/tensorflow/core/util/sparse/README.md
@@ -0,0 +1,222 @@
+SparseTensor
+============
+
+Sparse Tensors are stored as two dense tensors and a shape:
+
+* `indices`: a `Tensor` storing a matrix of `int64` indices.
+* `values`: a `Tensor` storing a vector with values of type `T`.
+* `shape`: a `TensorShape` storing the bounds of the underlying tensor.
+* `order`: (optional) a `gtl::InlinedVector<int64,8>` with the dimensions
+ along which the indices are ordered.
+
+Let
+
+ ix = indices.matrix<int64>()
+ vals = values.vec<T>()
+
+The shape of `ix` is `N x NDIMS`, and each row corresponds to the
+index of a single element of the sparse tensor.
+
+The length of `vals` must be `N`, and `vals(i)` corresponds to the
+value with index `ix(i,:)`.
+
+`shape` must be a `TensorShape` with `dims() == NDIMS`.
+The shape is the full shape of the dense tensor these indices
+represent.
+
+To be specific, the representation (pseudocode) is:
+
+ tensor[ix[i,:]] == vals[i] for i = 0, ..., N-1
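+
+For example (values chosen for illustration), the `3 x 3` matrix
+
+    [0 0 7]
+    [0 0 0]
+    [5 0 0]
+
+has the representation
+
+    indices = [[0, 2], [2, 0]]  // N = 2, NDIMS = 2
+    values  = [7, 5]
+    shape   = TensorShape({3, 3})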
+
+Ordering
+--------
+
+Indices need not be provided in order. For example, the following
+index matrix is ordered according to dimension order `{0, 1, 2}`:
+
+ [0 0 1]
+ [0 1 1]
+ [2 0 2]
+
+However, you can provide an unordered version:
+
+ [2 0 2]
+ [0 0 1]
+ [0 1 1]
+
+If the SparseTensor is constructed without a provided order, then the
+default order is `{-1, ..., -1}`. Certain operations will fail or crash
+when the order is not provided.
+
+Resorting the SparseTensor in-place (which resorts the underlying index and
+values tensors in-place) will update the order. The cost of reordering the
+matrix is `O(N*log(N))`, and it requires `O(N)` additional temporary space to
+store a reordering index. If an order is not provided at construction time
+and reordering is not performed, the following will happen:
+
+* `group()` will **raise an assertion failure**
+* `IndicesValid()` will **raise an assertion failure**
+
+To update the internal index ordering after construction, call
+`Reorder<T>()`, e.g., `Reorder<T>({0, 1, 2})`.
+After this step, all the above methods should work correctly.
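+
+For example, a minimal sketch (assuming `ix` holds `int64` indices and
+`vals` holds matching string values, as above):
+
+    SparseTensor st(ix, vals, shape);  // order defaults to {-1, ..., -1}
+    st.Reorder<string>({0, 1, 2});     // sort lexicographically by dims 0, 1, 2
+    CHECK(st.IndicesValid());          // safe now that an order is set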
+
+The method `IndicesValid()` checks to make sure:
+
+* `0 <= ix(i, d) < shape.dim_size(d)`
+* indices do not repeat
+* indices are in order
+
+Iterating
+---------
+
+### group({grouping dims})
+
+* provides an iterator that groups entries according to
+ dimensions you care about
+* may require a sort if your data isn't presorted in a way that's
+ compatible with grouping_dims
+* for each group, returns the group index (values of the group
+  dims for this iteration), the subset of indices in this group,
+  and the subset of values in this group. These are lazy outputs,
+  so to read them individually, copy them as in the example
+  below.
+
+#### **NOTE**
+`group({dim0, ..., dimk})` will **raise an assertion failure** if the
+order of the SparseTensor does not match the dimensions you wish to group by.
+You must either have your indices in the correct order and construct the
+SparseTensor with
+
+ order = {dim0, ..., dimk, ...}
+
+or call
+
+    Reorder<T>({dim0, ..., dimk, ...})
+
+to sort the SparseTensor before grouping.
+
+Example of grouping:
+
+    Tensor indices(DT_INT64, TensorShape({N, NDIMS}));
+    Tensor values(DT_STRING, TensorShape({N}));
+    TensorShape shape({dim0,...});
+    SparseTensor sp(indices, values, shape);
+    sp.Reorder<string>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
+ // group according to dims 1 and 2
+ for (const auto& g : sp.group({1, 2})) {
+ cout << "vals of ix[:, 1,2] for this group: "
+ << g.group()[0] << ", " << g.group()[1];
+ cout << "full indices of group:\n" << g.indices();
+ cout << "values of group:\n" << g.values();
+
+ TTypes<int64>::UnalignedMatrix g_ix = g.indices();
+ TTypes<string>::UnalignedVec g_v = g.values();
+      CHECK_EQ(g_ix.dimension(0), g_v.size()); // number of elements match.
+ }
+
+
+ToDense
+--------
+
+Converts the sparse tensor to a dense tensor. You must provide a pointer
+to the (preallocated) dense tensor. `ToDense()` will optionally
+preinitialize the tensor with default values (zeros for numeric types).
+
+Shape checking is performed, as is boundary checking.
+
+    Tensor indices(DT_INT64, TensorShape({N, NDIMS}));
+    Tensor values(DT_STRING, TensorShape({N}));
+    TensorShape shape({dim0,...});
+    SparseTensor sp(indices, values, shape);
+    CHECK(sp.IndicesValid()); // checks ordering & index bounds.
+
+    Tensor dense(DT_STRING, shape);
+    // Initialize unset entries to default values, then copy in the values.
+    CHECK(sp.ToDense<string>(&dense, true));
+
+
+Concat
+--------
+
+Concatenates multiple SparseTensors and returns a new SparseTensor.
+This concatenation is with respect to the "dense" versions of these
+SparseTensors. Concatenation is performed along dimension `order[0]`
+of all tensors. As a result, `shape[order[0]]` may differ across
+the inputs, but `shape[d]` for `d != order[0]` must match across all inputs.
+
+We call `order[0]` the **primary dimension**.
+
+**Prerequisites**
+
+* The inputs' ranks must all match.
+* The inputs' `order[0]` must all match.
+* The inputs' shapes must all match except for dimension `order[0]`.
+* The inputs' values must all be of the same type.
+
+If any of these prerequisites are not met, `Concat` will die with an
+assertion failure.
+
+Example:
+Concatenate two sparse matrices along columns.
+
+Matrix 1:
+
+ [0 0 1]
+ [2 0 0]
+ [3 0 4]
+
+Matrix 2:
+
+ [0 0 0 0 0]
+ [0 1 0 0 0]
+ [2 0 0 1 0]
+
+Concatenated Matrix:
+
+ [0 0 1 0 0 0 0 0]
+ [2 0 0 0 1 0 0 0]
+ [3 0 4 2 0 0 1 0]
+
+Expected input shapes, orders, and nonzero counts (`num_entries()`):
+
+ shape_1 = TensorShape({3, 3})
+    shape_2 = TensorShape({3, 5})
+ order_1 = {1, 0} // primary order is 1, columns
+ order_2 = {1, 0} // primary order is 1, must match
+ nnz_1 = 4
+ nnz_2 = 3
+
+Output shapes and orders:
+
+    conc_shape = TensorShape({3, 8})  // primary dim increased, others same
+ conc_order = {1, 0} // Orders match along all inputs
+ conc_nnz = 7 // Sum of nonzeros of inputs
+
+Coding Example:
+
+    Tensor ix1(DT_INT64, TensorShape({N1, 3}));
+    Tensor vals1(DT_STRING, TensorShape({N1}));
+    Tensor ix2(DT_INT64, TensorShape({N2, 3}));
+    Tensor vals2(DT_STRING, TensorShape({N2}));
+    Tensor ix3(DT_INT64, TensorShape({N3, 3}));
+    Tensor vals3(DT_STRING, TensorShape({N3}));
+
+ SparseTensor st1(ix1, vals1, TensorShape({10, 20, 5}), {1, 0, 2});
+ SparseTensor st2(ix2, vals2, TensorShape({10, 10, 5}), {1, 0, 2});
+ // For kicks, st3 indices are out of order, but order[0] matches so we
+ // can still concatenate along this dimension.
+ SparseTensor st3(ix3, vals3, TensorShape({10, 30, 5}), {1, 2, 0});
+
+ SparseTensor conc = SparseTensor::Concat<string>({st1, st2, st3});
+ Tensor ix_conc = conc.indices();
+ Tensor vals_conc = conc.values();
+    EXPECT_EQ(conc.num_entries(),
+              st1.num_entries() + st2.num_entries() + st3.num_entries());
+    EXPECT_EQ(conc.shape(), TensorShape({10, 60, 5}));
+    EXPECT_EQ(conc.order(), std::vector<int64>({-1, -1, -1}));
+
+ // Reorder st3 so all input tensors have the exact same orders.
+ st3.Reorder<string>({1, 0, 2});
+ SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3});
+    EXPECT_EQ(conc2.order(), std::vector<int64>({1, 0, 2}));
+ // All indices' orders matched, so output is in order.
+ EXPECT_TRUE(conc2.IndicesValid());
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h
new file mode 100644
index 0000000000..57473867cf
--- /dev/null
+++ b/tensorflow/core/util/sparse/dim_comparator.h
@@ -0,0 +1,60 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+/////////////////
+// DimComparator
+/////////////////
+//
+// Helper class, mainly used by the IndexSortOrder. This comparator
+// can be passed to e.g. std::sort, or any other sorter, to sort two
+// rows of an index matrix according to the dimension(s) of interest.
+// The dimensions to sort by are passed to the constructor as "order".
+//
+// Example: if given index matrix IX, two rows ai and bi, and order = {2,1}.
+// operator() compares
+// IX(ai,2) < IX(bi,2).
+// If IX(ai,2) == IX(bi,2), it compares
+// IX(ai,1) < IX(bi,1).
+//
+// This can be used to sort a vector of row indices into IX according to
+// the values in IX in particular columns (dimensions) of interest.
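+//
+// Example usage (a minimal sketch; assumes ix_t is a
+// TTypes<int64>::Matrix with at least 3 columns):
+//
+//   std::vector<int64> rows(ix_t.dimension(0));
+//   std::iota(rows.begin(), rows.end(), 0);  // rows = {0, 1, ..., N-1}
+//   std::sort(rows.begin(), rows.end(), DimComparator(ix_t, {2, 1}, dims));
+//   // rows now lists the row indices of ix_t sorted by column 2, then 1.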
+class DimComparator {
+ public:
+ typedef typename gtl::ArraySlice<int64> VarDimArray;
+
+ inline DimComparator(const TTypes<int64>::Matrix& ix,
+ const VarDimArray& order, int dims)
+ : ix_(ix), order_(order), dims_(dims) {
+ CHECK_GT(order.size(), 0) << "Must order using at least one index";
+ CHECK_LE(order.size(), dims_) << "Can only sort up to dims";
+ for (size_t d = 0; d < order.size(); ++d) {
+ CHECK_GE(order[d], 0);
+ CHECK_LT(order[d], dims);
+ }
+ }
+
+  inline bool operator()(const int64 i, const int64 j) const {
+    // Compare only the dimensions listed in order_, most significant first.
+    for (const int64 d : order_) {
+      if (ix_(i, d) < ix_(j, d)) return true;
+      if (ix_(i, d) > ix_(j, d)) return false;
+    }
+    return false;
+  }
+
+ const TTypes<int64>::Matrix ix_;
+ const VarDimArray order_;
+ const int dims_;
+};
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
new file mode 100644
index 0000000000..e153bcdbb4
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/core/util/sparse/group_iterator.h"
+
+namespace tensorflow {
+namespace sparse {
+
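+// Advances next_loc_ until it points one past the last row whose group
+// dimensions match those of the row at loc_, so that [loc_, next_loc_)
+// spans exactly one group (assuming the indices are sorted by the
+// grouping dimensions).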
+void GroupIterable::IteratorStep::UpdateEndOfGroup() {
+ ++next_loc_;
+ int64 N = iter_->ix_.dim_size(0);
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
+ ++next_loc_;
+ }
+}
+
+bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
+ CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
+ return (rhs.loc_ != loc_);
+}
+
+GroupIterable::IteratorStep& GroupIterable::IteratorStep::
+operator++() { // prefix ++
+ loc_ = next_loc_;
+ UpdateEndOfGroup();
+ return *this;
+}
+
+GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
+ int) { // postfix ++
+ IteratorStep lhs(*this);
+ ++(*this);
+ return lhs;
+}
+
+std::vector<int64> Group::group() const {
+ std::vector<int64> g;
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ for (const int d : iter_->group_dims_) {
+ g.push_back(ix_t(loc_, d));
+ }
+ return g;
+}
+
+TTypes<int64>::UnalignedConstMatrix Group::indices() const {
+ return TTypes<int64>::UnalignedConstMatrix(
+ &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+}
+
+} // namespace sparse
+} // namespace tensorflow
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
new file mode 100644
index 0000000000..8423d54f27
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -0,0 +1,120 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+class GroupIterable; // Predeclare GroupIterable for Group.
+
+// This class is returned when dereferencing a GroupIterable iterator.
+// It provides the methods group(), indices(), and values(), which
+// provide access into the underlying SparseTensor.
+class Group {
+ public:
+ Group(GroupIterable* iter, int64 loc, int64 next_loc)
+ : iter_(iter), loc_(loc), next_loc_(next_loc) {}
+
+ std::vector<int64> group() const;
+ TTypes<int64>::UnalignedConstMatrix indices() const;
+ template <typename T>
+ typename TTypes<T>::UnalignedVec values() const;
+
+ private:
+ GroupIterable* iter_;
+ int64 loc_;
+ int64 next_loc_;
+};
+
+/////////////////
+// GroupIterable
+/////////////////
+//
+// Returned when calling sparse_tensor.group({dim0, dim1, ...}).
+//
+// Please note: the sparse_tensor should already be ordered according
+// to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups.
+//
+// Allows grouping and iteration of the SparseTensor according to the
+// subset of dimensions provided to the group call.
+//
+// The actual grouping dimensions are stored in the
+// internal vector group_dims_. Iterators inside the iterable provide
+// the three methods:
+//
+// * group(): returns a vector with the current group dimension values.
+// * indices(): an unaligned matrix map (an Eigen TensorMap) over the
+//     indices in this group.
+// * values(): an unaligned vector map (an Eigen TensorMap) over the
+//     values in this group.
+//
+// To iterate across GroupIterable, see examples in README.md.
+//
+
+class GroupIterable {
+ public:
+ typedef gtl::ArraySlice<int64> VarDimArray;
+
+ GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
+ : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {}
+
+ class IteratorStep;
+
+ IteratorStep begin() { return IteratorStep(this, 0); }
+ IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
+
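+  // Returns true iff rows loc_a and loc_b of ix agree on every
+  // grouping dimension.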
+ template <typename TIX>
+ inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
+ bool matches = true;
+ for (int d : group_dims_) {
+ if (ix(loc_a, d) != ix(loc_b, d)) {
+ matches = false;
+ }
+ }
+ return matches;
+ }
+
+ class IteratorStep {
+ public:
+ IteratorStep(GroupIterable* iter, int64 loc)
+ : iter_(iter), loc_(loc), next_loc_(loc_) {
+ UpdateEndOfGroup();
+ }
+
+ void UpdateEndOfGroup();
+ bool operator!=(const IteratorStep& rhs) const;
+ IteratorStep& operator++(); // prefix ++
+ IteratorStep operator++(int); // postfix ++
+ Group operator*() const { return Group(iter_, loc_, next_loc_); }
+
+ private:
+ GroupIterable* iter_;
+ int64 loc_;
+ int64 next_loc_;
+ };
+
+ private:
+ friend class Group;
+ Tensor ix_;
+ Tensor vals_;
+ const int dims_;
+ const VarDimArray group_dims_;
+};
+
+// Implementation of Group::values<T>()
+template <typename T>
+typename TTypes<T>::UnalignedVec Group::values() const {
+ return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
+ next_loc_ - loc_);
+}
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
new file mode 100644
index 0000000000..dcb75e7f54
--- /dev/null
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -0,0 +1,353 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+
+#include <limits>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/sparse/dim_comparator.h"
+#include "tensorflow/core/util/sparse/group_iterator.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+class SparseTensor {
+ public:
+ typedef typename gtl::ArraySlice<int64> VarDimArray;
+
+ SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
+ : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
+
+ SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
+ const VarDimArray& order)
+ : ix_(ix),
+ vals_(vals),
+ shape_(shape),
+ order_(order.begin(), order.end()),
+ dims_(GetDimsFromIx(ix)) {
+ CHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: "
+ << ix.dtype();
+ CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
+ << "indices must be a matrix, but got: " << ix.shape().DebugString();
+ CHECK(TensorShapeUtils::IsVector(vals.shape()))
+ << "vals must be a vec, but got: " << vals.shape().DebugString();
+ CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
+ << "indices and values rows (indexing dimension) must match.";
+ }
+
+ std::size_t num_entries() const { return ix_.dim_size(0); }
+
+ const Tensor& indices() const { return ix_; }
+
+ const Tensor& values() const { return vals_; }
+
+ DataType dtype() const { return vals_.dtype(); }
+
+ bool IndicesValid() const {
+ const auto ix_t = ix_.matrix<int64>();
+ for (int64 ord : order_) {
+      CHECK_GE(ord, 0) << "Order was not provided. Provide an order at "
+          "construction time or call Reorder<T>()";
+ }
+
+ for (std::size_t n = 0; n < num_entries(); ++n) {
+ if (!IndexValid(ix_t, n)) return false;
+ }
+
+ return true;
+ }
+
+ // Returns the tensor shape (the dimensions of the "densified"
+ // tensor this tensor represents).
+ const TensorShape shape() const { return shape_; }
+
+ const VarDimArray order() const { return order_; }
+
+ // Resorts the indices and values according to the dimensions in order.
+ template <typename T>
+ void Reorder(const VarDimArray& order);
+
+ // Returns a group iterable that can be used for clumping indices
+ // and values according to the group indices of interest.
+ //
+ // Precondition: order()[0..group_ix.size()] == group_ix.
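+  // For example, if order() == {1, 0, 2}, then group({1}) and
+  // group({1, 0}) are valid, but group({0}) will CHECK-fail.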
+ //
+ // See the README.md in this directory for more usage information.
+ GroupIterable group(const VarDimArray& group_ix) {
+ CHECK_LE(group_ix.size(), dims_);
+ for (std::size_t di = 0; di < group_ix.size(); ++di) {
+ CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
+ CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
+ CHECK_EQ(group_ix[di], order_[di])
+ << "Group dimension does not match sorted order";
+ }
+ return GroupIterable(ix_, vals_, dims_, group_ix);
+ }
+
+ // Stores the sparse indices into the dense tensor out.
+ // Preconditions:
+ // out->shape().dims() == shape().dims()
+ // out->shape().dim_size(d) >= shape(d) for all d
+ //
+ // Returns true on success. False on failure (mismatched dimensions
+ // or out-of-bounds indices).
+ //
+  // If initialize is true, ToDense first overwrites all coefficients in
+  // out with their default value (T(), i.e. 0 for numeric types).
+ //
+ template <typename T>
+ bool ToDense(Tensor* out, bool initialize = true);
+
+  // Concat() concatenates all the tensors along their first order
+  // dimension (the primary dimension). All tensors must have identical
+  // shape except for the primary dimension. The first entry of every
+  // tensor's order must match.
+ //
+ // If all of the tensors have identical ordering, then the output
+ // will have this ordering. Otherwise the output is set as not
+ // having any order and a Reorder<T>() should be called on it before
+ // performing any subsequent operations.
+ template <typename T>
+ static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
+
+ private:
+ static int GetDimsFromIx(const Tensor& ix) {
+ CHECK(TensorShapeUtils::IsMatrix(ix.shape()));
+ return ix.dim_size(1);
+ }
+
+ static gtl::InlinedVector<int64, 8> UndefinedOrder(const TensorShape& shape) {
+ return gtl::InlinedVector<int64, 8>(shape.dims(), -1);
+ }
+
+ // Helper for IndicesValid()
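+  // Returns false if index n is out of bounds, identical to index n - 1,
+  // or lexicographically (with respect to order_) less than index n - 1.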
+ inline bool IndexValid(const TTypes<int64>::ConstMatrix& ix_t,
+ int64 n) const {
+ bool different = false;
+ bool bad_order = false;
+ bool valid = true;
+ if (n == 0) {
+ for (int di = 0; di < dims_; ++di) {
+ if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di))
+ valid = false;
+ }
+ different = true;
+ } else {
+ for (int di = 0; di < dims_; ++di) {
+ if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di))
+ valid = false;
+ int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
+ if (diff > 0) different = true;
+ if (!different && diff < 0) bad_order = true;
+ }
+ }
+    if (!valid) return false;      // Out of bounds.
+    if (!different) return false;  // Identical to the previous index.
+    if (bad_order) return false;   // Decreasing with respect to order_.
+ return true;
+ }
+
+ // Helper for ToDense<T>()
+ template <typename T>
+ bool ValidateAndInitializeToDense(Tensor* out, bool initialize);
+
+ Tensor ix_;
+ Tensor vals_;
+ TensorShape shape_;
+ gtl::InlinedVector<int64, 8> order_;
+ const int dims_;
+};
+
+// This operation updates the indices and values Tensor rows, so it is
+// an in-place algorithm. It requires O(N log N) time and O(N)
+// temporary space.
+template <typename T>
+void SparseTensor::Reorder(const VarDimArray& order) {
+ CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ << "Reorder requested with the wrong datatype";
+ CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
+ auto ix_t = ix_.matrix<int64>();
+ auto vals_t = vals_.vec<T>();
+
+ DimComparator sorter(ix_t, order, dims_);
+
+ std::vector<int64> reorder(num_entries());
+ std::iota(reorder.begin(), reorder.end(), 0);
+
+ // Sort to get order of indices
+ std::sort(reorder.begin(), reorder.end(), sorter);
+
+  // We have a forward reordering, but what we'll need is a
+  // permutation (the inverse). This can be calculated with O(1)
+  // additional space and O(N) time (INVPERM), but we just do the
+  // simple thing here.
+ std::vector<int64> permutation(reorder.size());
+ for (std::size_t n = 0; n < reorder.size(); ++n) {
+ permutation[reorder[n]] = n;
+ }
+
+ // Update indices & values by converting the permutations to
+ // a product of transpositions. Iterate over the cycles in the
+ // permutation, and convert each of those into a product of
+ // transpositions (swaps):
+ // https://en.wikipedia.org/wiki/Cyclic_permutation
+ // This is N swaps, 2*N comparisons.
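+  // For example, the permutation {2, 0, 1} (a single 3-cycle) is
+  // resolved with two swaps: rows (0, 2), then rows (0, 1).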
+ for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
+ while (n != permutation[n]) {
+ std::size_t r = permutation[n];
+ std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
+ std::swap(vals_t(n), vals_t(r));
+ std::swap(permutation[n], permutation[r]);
+ }
+ }
+
+ order_ = gtl::InlinedVector<int64, 8>(order.begin(), order.end());
+}
+
+template <typename T>
+bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
+ CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ << "ToDense requested with the wrong datatype";
+
+ CHECK_EQ(out->shape().dims(), dims_)
+ << "Incompatible dimensions between SparseTensor and output";
+
+ CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
+ << "Output must be type: " << DataTypeToEnum<T>::v()
+ << " but got: " << out->dtype();
+
+ // Make sure the dense output is the same rank and has room
+ // to hold the SparseTensor.
+ const auto& out_shape = out->shape();
+ if (shape_.dims() != out_shape.dims()) return false;
+ for (int d = 0; d < shape_.dims(); ++d) {
+ if (shape_.dim_size(d) > out_shape.dim_size(d)) return false;
+ }
+
+ if (initialize) {
+ auto out_t = out->flat<T>();
+ out_t.setConstant(T());
+ }
+
+ return true;
+}
+
+template <typename T>
+bool SparseTensor::ToDense(Tensor* out, bool initialize) {
+ if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;
+
+ auto out_t = out->flat<T>();
+ auto ix_t = ix_.matrix<int64>();
+ auto vals_t = vals_.vec<T>();
+
+ std::vector<int64> strides(dims_);
+ const auto& out_shape = out->shape();
+ strides[dims_ - 1] = 1;
+ for (int d = dims_ - 2; d >= 0; --d) {
+ strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
+ }
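+  // For example, with out_shape [4, 4, 5]: strides = {20, 5, 1}, so the
+  // index (i, j, k) maps to the row-major flat offset 20*i + 5*j + k.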
+
+ for (std::size_t n = 0; n < vals_t.dimension(0); ++n) {
+ bool invalid_dims = false;
+ int64 ix = 0;
+ for (int d = 0; d < dims_; ++d) {
+ const int64 ix_n_d = ix_t(n, d);
+ if (ix_n_d < 0 || ix_n_d >= out_shape.dim_size(d)) {
+ invalid_dims = true;
+ }
+ ix += strides[d] * ix_n_d;
+ }
+ if (invalid_dims) return false;
+ out_t(ix) = vals_t(n);
+ }
+ return true;
+}
+
+template <typename T>
+SparseTensor SparseTensor::Concat(
+ const gtl::ArraySlice<SparseTensor>& tensors) {
+ CHECK_GE(tensors.size(), 1) << "Cannot concat 0 SparseTensors";
+ const int dims = tensors[0].dims_;
+ CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
+ auto order_0 = tensors[0].order();
+ const int primary_dim = order_0[0];
+ gtl::InlinedVector<int64, 8> final_order(order_0.begin(), order_0.end());
+ TensorShape final_shape(tensors[0].shape());
+ final_shape.set_dim(primary_dim, 0); // We'll build this up as we go along.
+ int num_entries = 0;
+
+ bool fully_ordered = true;
+ for (const SparseTensor& st : tensors) {
+ CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
+ CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
+ << "Concat requested with the wrong data type";
+ CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
+ CHECK_EQ(st.order()[0], primary_dim)
+ << "All SparseTensors' order[0] must match. This is the concat dim.";
+ if (st.order() != final_order) fully_ordered = false;
+ const TensorShape st_shape = st.shape();
+ for (int d = 0; d < dims - 1; ++d) {
+ const int cdim = (d < primary_dim) ? d : d + 1;
+ CHECK_EQ(final_shape.dim_size(cdim), st_shape.dim_size(cdim))
+ << "All SparseTensors' shapes must match except on the concat dim. "
+ << "Concat dim: " << primary_dim
+ << ", mismatched shape at dim: " << cdim
+ << ". Expecting shape like: " << final_shape.DebugString()
+ << " but saw shape: " << st_shape.DebugString();
+ }
+
+ // Update dimension of final shape
+ final_shape.set_dim(primary_dim, final_shape.dim_size(primary_dim) +
+ st_shape.dim_size(primary_dim));
+
+ num_entries += st.num_entries(); // Update number of entries
+ }
+
+  // If the ordering is inconsistent among the inputs, set the final
+  // order to the undefined value (-1s).
+ if (!fully_ordered) {
+ final_order = UndefinedOrder(final_shape);
+ }
+
+ Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
+ Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));
+
+ auto ix_t = output_ix.matrix<int64>();
+ auto vals_t = output_vals.vec<T>();
+
+ Eigen::DenseIndex offset = 0;
+ int64 shape_offset = 0;
+ for (const SparseTensor& st : tensors) {
+ int st_num_entries = st.num_entries();
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_start(offset, 0);
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_size(st_num_entries, dims);
+ Eigen::DSizes<Eigen::DenseIndex, 1> vals_start(offset);
+ Eigen::DSizes<Eigen::DenseIndex, 1> vals_size(st_num_entries);
+
+ // Fill in indices & values.
+ ix_t.slice(ix_start, ix_size) = st.ix_.matrix<int64>();
+ vals_t.slice(vals_start, vals_size) = st.vals_.vec<T>();
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_start(offset, primary_dim);
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_size(st_num_entries, 1);
+ // The index associated with the primary dimension gets increased
+ // by the shapes of the previous concatted Tensors.
+ auto update_slice = ix_t.slice(ix_update_start, ix_update_size);
+ update_slice += update_slice.constant(shape_offset);
+
+ offset += st_num_entries;
+ shape_offset += st.shape().dim_size(primary_dim);
+ }
+
+ return SparseTensor(output_ix, output_vals, final_shape, final_order);
+}
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc
new file mode 100644
index 0000000000..47126b7187
--- /dev/null
+++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc
@@ -0,0 +1,467 @@
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+namespace {
+
+Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex>
+GetSimpleIndexTensor(int N, const int NDIM) {
+ Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex> ix(N, NDIM);
+ ix(0, 0) = 0;
+ ix(0, 1) = 0;
+ ix(0, 2) = 0;
+
+ ix(1, 0) = 3;
+ ix(1, 1) = 0;
+ ix(1, 2) = 0;
+
+ ix(2, 0) = 2;
+ ix(2, 1) = 0;
+ ix(2, 2) = 0;
+
+ ix(3, 0) = 0;
+ ix(3, 1) = 1;
+ ix(3, 2) = 0;
+
+ ix(4, 0) = 0;
+ ix(4, 1) = 0;
+ ix(4, 2) = 2;
+ return ix;
+}
+
+TEST(SparseTensorTest, DimComparatorSorts) {
+ std::size_t N = 5;
+ const int NDIM = 3;
+ auto ix = GetSimpleIndexTensor(N, NDIM);
+ TTypes<int64>::Matrix map(ix.data(), N, NDIM);
+
+ std::vector<int64> sorting(N);
+ for (std::size_t n = 0; n < N; ++n) sorting[n] = n;
+
+ // new order should be: {0, 4, 3, 2, 1}
+ std::vector<int64> order{0, 1, 2};
+ DimComparator sorter(map, order, NDIM);
+ std::sort(sorting.begin(), sorting.end(), sorter);
+
+ EXPECT_EQ(sorting, std::vector<int64>({0, 4, 3, 2, 1}));
+
+ // new order should be: {0, 3, 2, 1, 4}
+ std::vector<int64> order1{2, 0, 1};
+ DimComparator sorter1(map, order1, NDIM);
+ for (std::size_t n = 0; n < N; ++n) sorting[n] = n;
+ std::sort(sorting.begin(), sorting.end(), sorter1);
+
+ EXPECT_EQ(sorting, std::vector<int64>({0, 3, 2, 1, 4}));
+}
+
+TEST(SparseTensorTest, SparseTensorConstruction) {
+ int N = 5;
+ const int NDIM = 3;
+ auto ix_c = GetSimpleIndexTensor(N, NDIM);
+ Eigen::Tensor<string, 1, Eigen::RowMajor> vals_c(N);
+ vals_c(0) = "hi0";
+ vals_c(1) = "hi1";
+ vals_c(2) = "hi2";
+ vals_c(3) = "hi3";
+ vals_c(4) = "hi4";
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<string>();
+ vals_t = vals_c;
+ ix_t = ix_c;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid()); // Out of order
+
+  // Regardless of how the order is updated, as long as there are no
+  // duplicates the resulting indices are valid.
+ st.Reorder<string>({2, 0, 1});
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(vals_t(0), "hi0");
+ EXPECT_EQ(vals_t(1), "hi3");
+ EXPECT_EQ(vals_t(2), "hi2");
+ EXPECT_EQ(vals_t(3), "hi1");
+ EXPECT_EQ(vals_t(4), "hi4");
+
+ ix_t = ix_c;
+ vals_t = vals_c;
+ st.Reorder<string>({0, 1, 2});
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(vals_t(0), "hi0");
+ EXPECT_EQ(vals_t(1), "hi4");
+ EXPECT_EQ(vals_t(2), "hi3");
+ EXPECT_EQ(vals_t(3), "hi2");
+ EXPECT_EQ(vals_t(4), "hi1");
+
+ ix_t = ix_c;
+ vals_t = vals_c;
+ st.Reorder<string>({2, 1, 0});
+ EXPECT_TRUE(st.IndicesValid());
+}
+
+TEST(SparseTensorTest, EmptySparseTensorAllowed) {
+ int N = 0;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(st.order(), order);
+
+ std::vector<int64> new_order{1, 0, 2};
+ st.Reorder<string>(new_order);
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(st.order(), new_order);
+}
+
+TEST(SparseTensorTest, SortingWorksCorrectly) {
+ int N = 30;
+ const int NDIM = 4;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+ TensorShape shape({1000, 1000, 1000, 1000});
+ SparseTensor st(ix, vals, shape);
+
+ auto ix_t = ix.matrix<int64>();
+
+ for (int n = 0; n < 100; ++n) {
+ ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1));
+ ix_t = ix_t.abs() % 1000;
+ st.Reorder<string>({0, 1, 2, 3});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({3, 2, 1, 0});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({1, 0, 2, 3});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({3, 0, 2, 1});
+ EXPECT_TRUE(st.IndicesValid());
+ }
+}
+
+TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
+ int N = 2;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ Eigen::Tensor<int64, 2, Eigen::RowMajor> ix_orig(N, NDIM);
+ ix_orig(0, 0) = 0;
+ ix_orig(0, 1) = 0;
+ ix_orig(0, 2) = 0;
+
+ ix_orig(1, 0) = 0;
+ ix_orig(1, 1) = 0;
+ ix_orig(1, 2) = 0;
+
+ auto ix_t = ix.matrix<int64>();
+ ix_t = ix_orig;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid()); // two indices are identical
+
+ ix_orig(1, 2) = 1;
+ ix_t = ix_orig;
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid()); // second index now (0, 0, 1)
+
+ ix_orig(0, 2) = 1;
+ ix_t = ix_orig;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid()); // first index now (0, 0, 1)
+}
+
+TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+
+ ix.matrix<int64>() = ix_t;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+
+ ix_t(0, 0) = 11;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ ix_t(0, 0) = -1;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ ix_t(0, 0) = 0;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+}
+
+TEST(SparseTensorTest, SparseTensorToDenseTensor) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+ auto vals_t = vals.vec<string>();
+
+ ix.matrix<int64>() = ix_t;
+
+ vals_t(0) = "hi0";
+ vals_t(1) = "hi1";
+ vals_t(2) = "hi2";
+ vals_t(3) = "hi3";
+ vals_t(4) = "hi4";
+
+ TensorShape shape({4, 4, 5});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
+ st.ToDense<string>(&dense);
+
+ auto dense_t = dense.tensor<string, 3>();
+ Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
+ for (int n = 0; n < N; ++n) {
+ for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
+ EXPECT_EQ(dense_t(ix_n), vals_t(n));
+ }
+
+ // Spot checks on the others
+ EXPECT_EQ(dense_t(0, 0, 1), "");
+ EXPECT_EQ(dense_t(0, 0, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 4), "");
+}
+
+TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+ auto vals_t = vals.vec<string>();
+
+ ix.matrix<int64>() = ix_t;
+
+ vals_t(0) = "hi0";
+ vals_t(1) = "hi1";
+ vals_t(2) = "hi2";
+ vals_t(3) = "hi3";
+ vals_t(4) = "hi4";
+
+ TensorShape shape({4, 4, 5});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
+ st.ToDense<string>(&dense);
+
+ auto dense_t = dense.tensor<string, 3>();
+ Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
+ for (int n = 0; n < N; ++n) {
+ for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
+ EXPECT_EQ(dense_t(ix_n), vals_t(n));
+ }
+
+ // Spot checks on the others
+ EXPECT_EQ(dense_t(0, 0, 1), "");
+ EXPECT_EQ(dense_t(0, 0, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 4), "");
+ EXPECT_EQ(dense_t(9, 0, 0), "");
+ EXPECT_EQ(dense_t(9, 0, 9), "");
+ EXPECT_EQ(dense_t(9, 9, 9), "");
+}
+
+TEST(SparseTensorTest, SparseTensorGroup) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_INT32, TensorShape({N}));
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<int32>();
+
+ ix_t = GetSimpleIndexTensor(N, NDIM);
+
+ vals_t(0) = 1; // associated with ix (000)
+ vals_t(1) = 2; // associated with ix (300)
+ vals_t(2) = 3; // associated with ix (200)
+ vals_t(3) = 4; // associated with ix (010)
+ vals_t(4) = 5; // associated with ix (002)
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ st.Reorder<int32>(order);
+
+ std::vector<std::vector<int64> > groups;
+ std::vector<TTypes<int64>::UnalignedConstMatrix> grouped_indices;
+ std::vector<TTypes<int32>::UnalignedVec> grouped_values;
+
+ // Group by index 0
+ auto gi = st.group({0});
+
+ // All the hard work is right here!
+ for (const auto& g : gi) {
+ groups.push_back(g.group());
+ VLOG(1) << "Group: " << str_util::Join(g.group(), ",");
+ VLOG(1) << "Indices: " << g.indices();
+ VLOG(1) << "Values: " << g.values<int32>();
+
+ grouped_indices.push_back(g.indices());
+ grouped_values.push_back(g.values<int32>());
+ }
+
+ // Group by dimension 0, we have groups: 0--, 2--, 3--
+ EXPECT_EQ(groups.size(), 3);
+ EXPECT_EQ(groups[0], std::vector<int64>({0}));
+ EXPECT_EQ(groups[1], std::vector<int64>({2}));
+ EXPECT_EQ(groups[2], std::vector<int64>({3}));
+
+ std::vector<Eigen::Tensor<int64, 2, Eigen::RowMajor> > expected_indices;
+ std::vector<Eigen::Tensor<int32, 1, Eigen::RowMajor> > expected_vals;
+
+ // First group: 000, 002, 010
+ expected_indices.emplace_back(3, NDIM); // 3 x 3 tensor
+  expected_vals.emplace_back(3);           // vector of length 3
+ expected_indices[0].setZero();
+ expected_indices[0](1, 2) = 2; // 002
+ expected_indices[0](2, 1) = 1; // 010
+ expected_vals[0].setConstant(-1);
+ expected_vals[0](0) = 1; // val associated with ix 000
+ expected_vals[0](1) = 5; // val associated with ix 002
+ expected_vals[0](2) = 4; // val associated with ix 010
+
+ // Second group: 200
+ expected_indices.emplace_back(1, NDIM);
+ expected_vals.emplace_back(1);
+ expected_indices[1].setZero();
+ expected_indices[1](0, 0) = 2; // 200
+ expected_vals[1](0) = 3; // val associated with ix 200
+
+ // Third group: 300
+ expected_indices.emplace_back(1, NDIM);
+ expected_vals.emplace_back(1);
+ expected_indices[2].setZero();
+ expected_indices[2](0, 0) = 3; // 300
+ expected_vals[2](0) = 2; // val associated with ix 300
+
+ for (std::size_t gix = 0; gix < groups.size(); ++gix) {
+ // Compare indices
+ auto gi_t = grouped_indices[gix];
+ Eigen::Tensor<bool, 0, Eigen::RowMajor> eval =
+ (gi_t == expected_indices[gix]).all();
+ EXPECT_TRUE(eval()) << gix << " indices: " << gi_t << " vs. "
+ << expected_indices[gix];
+
+ // Compare values
+ auto gv_t = grouped_values[gix];
+ eval = (gv_t == expected_vals[gix]).all();
+ EXPECT_TRUE(eval()) << gix << " values: " << gv_t << " vs. "
+ << expected_vals[gix];
+ }
+}
+
+TEST(SparseTensorTest, Concat) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_c = GetSimpleIndexTensor(N, NDIM);
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<string>();
+
+ ix_t = ix_c;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid());
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+
+ SparseTensor concatted = SparseTensor::Concat<string>({st, st, st, st});
+ EXPECT_EQ(concatted.order(), st.order());
+ TensorShape expected_shape({40, 10, 10});
+ EXPECT_EQ(concatted.shape(), expected_shape);
+ EXPECT_EQ(concatted.num_entries(), 4 * N);
+ EXPECT_TRUE(concatted.IndicesValid());
+
+ auto conc_ix_t = concatted.indices().matrix<int64>();
+ auto conc_vals_t = concatted.values().vec<string>();
+
+ for (int n = 0; n < 4; ++n) {
+ for (int i = 0; i < N; ++i) {
+ // Dimensions match except the primary dim, which is offset by
+ // shape[order[0]]
+      EXPECT_EQ(conc_ix_t(n * N + i, 0), 10 * n + ix_t(i, 0));
+      EXPECT_EQ(conc_ix_t(n * N + i, 1), ix_t(i, 1));
+      EXPECT_EQ(conc_ix_t(n * N + i, 2), ix_t(i, 2));
+
+ // Values match
+ EXPECT_EQ(conc_vals_t(n * N + i), vals_t(i));
+ }
+ }
+
+ // Concat works if non-primary ix is out of order, but output order
+ // is not defined
+ SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO
+ SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo});
+ std::vector<int64> expected_ooo{-1, -1, -1};
+ EXPECT_EQ(conc_ooo.order(), expected_ooo);
+ EXPECT_EQ(conc_ooo.shape(), expected_shape);
+ EXPECT_EQ(conc_ooo.num_entries(), 4 * N);
+}
+
+// TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output)
+// reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and
+// slices of resorted indices on generator.
+
+} // namespace
+} // namespace sparse
+} // namespace tensorflow