about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/util/sparse/sparse_tensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/sparse/sparse_tensor.h')
-rw-r--r-- tensorflow/core/util/sparse/sparse_tensor.h 196
1 files changed, 151 insertions, 45 deletions
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 258ee418c1..0f04b65f60 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
-#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
+#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
#include <limits>
#include <numeric>
@@ -26,8 +26,10 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
@@ -41,32 +43,88 @@ class SparseTensor {
typedef typename gtl::ArraySlice<int64> VarDimArray;
typedef typename gtl::InlinedVector<int64, 8> ShapeArray;
+  // Validating factory: builds a SparseTensor from an indices matrix `ix`, a
+  // values vector `vals`, a dense `shape` and a dimension `order`, and writes
+  // it to `*result`.
+  //
+  // Returns INVALID_ARGUMENT, leaving `*result` unassigned, when:
+  //   * `ix` is not of type int64,
+  //   * `vals` is not a vector,
+  //   * `ix` is not a matrix,
+  //   * `ix` and `vals` disagree on their number of rows, or
+  //   * `order` or `shape` does not have one entry per column of `ix`.
+  // Returns an OK Status otherwise.
+  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
+                       const VarDimArray order, SparseTensor* result) {
+    if (ix.dtype() != DT_INT64) {
+      return Status(
+          error::INVALID_ARGUMENT,
+          strings::StrCat("indices must be type int64 but got: ", ix.dtype()));
+    }
+    if (!TensorShapeUtils::IsVector(vals.shape())) {
+      return Status(error::INVALID_ARGUMENT,
+                    strings::StrCat("vals must be a vec, but got: ",
+                                    vals.shape().DebugString()));
+    }
+    // Establish that `ix` is a matrix before dim_size(0) is read below, so a
+    // scalar or otherwise mis-ranked `ix` is reported as an error instead of
+    // being indexed out of range.
+    int dims = 0;
+    TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims));
+    if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) {
+      return Status(error::INVALID_ARGUMENT,
+                    strings::StrCat("indices and values rows (indexing "
+                                    "dimension) must match. (indices = ",
+                                    ix.shape().dim_size(0), ", values = ",
+                                    vals.shape().dim_size(0), ")"));
+    }
+    if (order.size() != dims) {
+      return Status(error::INVALID_ARGUMENT,
+                    "Order length must be SparseTensor rank.");
+    }
+    if (shape.size() != dims) {
+      return Status(error::INVALID_ARGUMENT,
+                    "Shape rank must be SparseTensor rank.");
+    }
+
+    // Every precondition the constructor only DCHECKs has been verified above.
+    *result = SparseTensor(ix, vals, shape, order);
+    return Status();
+  }
+
+  // Convenience overload: accepts the dense shape as a TensorShape and uses
+  // UndefinedOrder() of that shape as the order. Validation is performed by
+  // the overload taking both a shape array and an order array.
+  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
+                       SparseTensor* result) {
+    const auto shape_vec = TensorShapeToVector(shape);
+    const auto undefined_order = UndefinedOrder(shape_vec);
+    return Create(ix, vals, shape_vec, undefined_order, result);
+  }
+
+  // Convenience overload for callers without ordering information: the order
+  // handed on is UndefinedOrder(shape).
+  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
+                       SparseTensor* result) {
+    const auto undefined_order = UndefinedOrder(shape);
+    return Create(ix, vals, shape, undefined_order, result);
+  }
+
+  // Convenience overload: converts the TensorShape to a dimension vector with
+  // TensorShapeToVector() and forwards it, together with `order`, to the
+  // overload that takes the shape as an array.
+  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
+                       const VarDimArray order, SparseTensor* result) {
+    const auto shape_vec = TensorShapeToVector(shape);
+    return Create(ix, vals, shape_vec, order, result);
+  }
+
+ // Default constructor: sets the rank (dims_) to 0 and default-constructs the
+ // other members. Split() default-constructs a SparseTensor this way and then
+ // passes it to Create() as the output object.
+ SparseTensor() : dims_(0) {}
+
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
vals_(vals),
shape_(shape.begin(), shape.end()),
order_(order.begin(), order.end()),
- dims_(GetDimsFromIx(ix)) {
- CHECK_EQ(ix.dtype(), DT_INT64)
+ dims_(UnsafeGetDimsFromIx(ix)) {
+ DCHECK_EQ(ix.dtype(), DT_INT64)
<< "indices must be type int64 but got: " << ix.dtype();
- CHECK(TensorShapeUtils::IsVector(vals.shape()))
+ DCHECK(TensorShapeUtils::IsVector(vals.shape()))
<< "vals must be a vec, but got: " << vals.shape().DebugString();
- CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
+ DCHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
<< "indices and values rows (indexing dimension) must match.";
- CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
- CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
+ DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
+ DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
}
SparseTensor(const SparseTensor& other)
@@ -81,6 +139,16 @@ class SparseTensor {
vals_ = other.vals_;
shape_ = other.shape_;
order_ = other.order_;
+ dims_ = other.dims_;
+ return *this;
+ }
+
+  // Move assignment: takes over the indices, values, shape and order of
+  // `other` and copies its rank. Self-assignment is a no-op, so
+  // `x = std::move(x)` leaves `x` intact instead of self-moving each member.
+  SparseTensor& operator=(SparseTensor&& other) {
+    if (this != &other) {
+      ix_ = std::move(other.ix_);
+      vals_ = std::move(other.vals_);
+      shape_ = std::move(other.shape_);
+      order_ = std::move(other.order_);
+      dims_ = other.dims_;  // plain int: there is nothing to move
+    }
return *this;
}
@@ -126,11 +194,11 @@ class SparseTensor {
//
// See the README.md in this directory for more usage information.
GroupIterable group(const VarDimArray& group_ix) const {
- CHECK_LE(group_ix.size(), dims_);
+ DCHECK_LE(group_ix.size(), dims_);
for (std::size_t di = 0; di < group_ix.size(); ++di) {
- CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
- CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
- CHECK_EQ(group_ix[di], order_[di])
+ DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
+ DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
+ DCHECK_EQ(group_ix[di], order_[di])
<< "Group dimension does not match sorted order";
}
return GroupIterable(ix_, vals_, dims_, group_ix);
@@ -166,9 +234,16 @@ class SparseTensor {
// isn't an integer multiple of split_dim, we add one extra dimension for
// each slice.
template <typename T>
+ static Status Split(const SparseTensor& tensor, const int split_dim,
+ const int num_split, std::vector<SparseTensor>* result);
+
+ // DEPRECATED: use the form of Split() that takes an output pointer and
+ // returns a status instead.
+ template <typename T>
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
- const int num_split);
+ const int num_split,
+ Status* status = nullptr);
// Slice() will slice the input SparseTensor into a SparseTensor based on
// specified start and size. Both start and size are 1-D array with each
@@ -189,9 +264,18 @@ class SparseTensor {
}
private:
- static int GetDimsFromIx(const Tensor& ix) {
- CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
- << "indices must be a matrix, but got: " << ix.shape().DebugString();
+  // Stores the size of dimension 1 of `ix` (its column count) in `*result`.
+  // Returns INVALID_ARGUMENT, without writing `*result`, when `ix` is not a
+  // matrix; returns an OK Status otherwise.
+  static Status GetDimsFromIx(const Tensor& ix, int* result) {
+    const bool is_matrix = TensorShapeUtils::IsMatrix(ix.shape());
+    if (is_matrix) {
+      *result = UnsafeGetDimsFromIx(ix);
+      return Status();
+    }
+    return Status(error::INVALID_ARGUMENT,
+                  strings::StrCat("indices must be a matrix, but got: ",
+                                  ix.shape().DebugString()));
+  }
+
+ // Returns the size of dimension 1 of `ix`. No error is reported for a
+ // non-matrix `ix`; the only guard is the DCHECK below. Callers that need a
+ // Status for malformed input should use GetDimsFromIx() instead.
+ static int UnsafeGetDimsFromIx(const Tensor& ix) {
+ DCHECK(TensorShapeUtils::IsMatrix(ix.shape()));
return ix.dim_size(1);
}
@@ -251,8 +335,8 @@ class SparseTensor {
// Helper for Split() that returns the slice index.
static inline int GetSliceIndex(const int dim, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(dim, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(dim, 0);
if (residual == 0) return dim / split_size;
const int offset = residual * (split_size + 1);
if (dim < offset) {
@@ -265,8 +349,8 @@ class SparseTensor {
// Helper for Split() that returns the dimension in the slice.
static inline int GetDimensionInSlice(const int dim, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(dim, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(dim, 0);
if (residual == 0) return dim % split_size;
const int offset = residual * (split_size + 1);
if (dim < offset) {
@@ -279,8 +363,8 @@ class SparseTensor {
// Helper for Split() that returns the shape given a slice index.
static inline int GetSliceShape(const int slice_index, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(slice_index, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(slice_index, 0);
if (residual == 0) return split_size;
if (slice_index < residual) {
return split_size + 1;
@@ -293,7 +377,7 @@ class SparseTensor {
Tensor vals_;
ShapeArray shape_;
ShapeArray order_;
- const int dims_;
+ int dims_;
};
// This operation updates the indices and values Tensor rows, so it is
@@ -301,9 +385,9 @@ class SparseTensor {
// temporary space.
template <typename T>
void SparseTensor::Reorder(const VarDimArray& order) {
- CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
<< "Reorder requested with the wrong datatype";
- CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
+ DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
auto ix_t = ix_.matrix<int64>();
auto vals_t = vals_.vec<T>();
@@ -360,13 +444,13 @@ void SparseTensor::Reorder(const VarDimArray& order) {
template <typename T>
bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
- CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
<< "ToDense requested with the wrong datatype";
- CHECK_EQ(out->shape().dims(), dims_)
+ DCHECK_EQ(out->shape().dims(), dims_)
<< "Incompatible dimensions between SparseTensor and output";
- CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
+ DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
<< "Output must be type: " << DataTypeToEnum<T>::v()
<< " but got: " << out->dtype();
@@ -422,9 +506,9 @@ bool SparseTensor::ToDense(Tensor* out, bool initialize) {
template <typename T>
SparseTensor SparseTensor::Concat(
const gtl::ArraySlice<SparseTensor>& tensors) {
- CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
+ DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
const int dims = tensors[0].dims_;
- CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
+ DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
auto order_0 = tensors[0].order();
const int primary_dim = order_0[0];
ShapeArray final_order(order_0.begin(), order_0.end());
@@ -434,17 +518,17 @@ SparseTensor SparseTensor::Concat(
bool fully_ordered = true;
for (const SparseTensor& st : tensors) {
- CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
- CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
+ DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
+ DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
<< "Concat requested with the wrong data type";
- CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
- CHECK_EQ(st.order()[0], primary_dim)
+ DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
+ DCHECK_EQ(st.order()[0], primary_dim)
<< "All SparseTensors' order[0] must match. This is the concat dim.";
if (st.order() != final_order) fully_ordered = false;
const VarDimArray& st_shape = st.shape();
for (int d = 0; d < dims - 1; ++d) {
const int cdim = (d < primary_dim) ? d : d + 1;
- CHECK_EQ(final_shape[cdim], st_shape[cdim])
+ DCHECK_EQ(final_shape[cdim], st_shape[cdim])
<< "All SparseTensors' shapes must match except on the concat dim. "
<< "Concat dim: " << primary_dim
<< ", mismatched shape at dim: " << cdim
@@ -494,7 +578,8 @@ SparseTensor SparseTensor::Concat(
template <typename T>
std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
const int split_dim,
- const int num_split) {
+ const int num_split,
+ Status* status /* = nullptr */) {
std::vector<Tensor> output_indices;
std::vector<Tensor> output_values;
std::vector<TensorShape> output_shapes;
@@ -514,12 +599,18 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
const int split_dim_size = input_tensor.shape()[split_dim];
const int split_size = split_dim_size / num_split;
- CHECK(num_split > 0 && num_split <= split_dim_size) << "num_split must be in "
- "the interval (0, "
- << split_dim_size << "]";
- CHECK(split_dim >= 0 && split_dim < num_dim) << "num_dim must be in "
- "the interval [0, "
- << num_dim << ")";
+  // Argument validation. On failure the error is written to `*status` when
+  // the caller supplied one, and an empty vector is returned either way.
+  // `status` defaults to nullptr, so it is never dereferenced unguarded.
+  // NOTE(review): split_dim_size and split_size are computed above, before
+  // these checks run, so an out-of-range split_dim or num_split == 0 has
+  // already been used by then -- consider validating before computing them.
+  if (!(num_split > 0 && num_split <= split_dim_size)) {
+    if (status != nullptr) {
+      *status = Status(error::INVALID_ARGUMENT,
+                       strings::StrCat("num_split must be in the interval (0, ",
+                                       split_dim_size, "]"));
+    }
+    return {};
+  }
+  if (!(split_dim >= 0 && split_dim < num_dim)) {
+    if (status != nullptr) {
+      // The value being range-checked here is split_dim, so name it in the
+      // message.
+      *status = Status(error::INVALID_ARGUMENT,
+                       strings::StrCat("split_dim must be in the interval [0, ",
+                                       num_dim, ")"));
+    }
+    return {};
+  }
const int residual = split_dim_size % num_split;
for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
@@ -559,13 +650,28 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
std::vector<SparseTensor> output_tensors;
output_tensors.reserve(num_split);
for (int i = 0; i < num_split; ++i) {
- output_tensors.emplace_back(output_indices[i], output_values[i],
- output_shapes[i]);
+ SparseTensor tensor;
+ Status create_status =
+ Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
+ if (!create_status.ok() && status != nullptr) {
+ *status = create_status;
+ return {};
+ }
+ output_tensors.push_back(std::move(tensor));
}
return output_tensors;
}
+// Status-returning form of Split(): runs the overload that takes a Status
+// pointer with a local Status, stores the tensors it returns in `*result`,
+// and returns that Status to the caller.
template <typename T>
+Status SparseTensor::Split(const SparseTensor& input_tensor,
+                           const int split_dim, const int num_split,
+                           std::vector<SparseTensor>* result) {
+  Status split_status;
+  std::vector<SparseTensor> pieces =
+      Split<T>(input_tensor, split_dim, num_split, &split_status);
+  *result = std::move(pieces);
+  return split_status;
+}
+
+template <typename T>
SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
const gtl::ArraySlice<int64>& start,
const gtl::ArraySlice<int64>& size) {
@@ -643,4 +749,4 @@ SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
} // namespace sparse
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#endif // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_