diff options
Diffstat (limited to 'tensorflow/core/util/sparse/sparse_tensor.h')
-rw-r--r-- | tensorflow/core/util/sparse/sparse_tensor.h | 196 |
1 files changed, 151 insertions, 45 deletions
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index 258ee418c1..0f04b65f60 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ -#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ +#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_ +#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_ #include <limits> #include <numeric> @@ -26,8 +26,10 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/sparse/dim_comparator.h" @@ -41,32 +43,88 @@ class SparseTensor { typedef typename gtl::ArraySlice<int64> VarDimArray; typedef typename gtl::InlinedVector<int64, 8> ShapeArray; + static Status Create(Tensor ix, Tensor vals, const VarDimArray shape, + const VarDimArray order, SparseTensor* result) { + if (ix.dtype() != DT_INT64) { + return Status( + error::INVALID_ARGUMENT, + strings::StrCat("indices must be type int64 but got: ", ix.dtype())); + } + if (!TensorShapeUtils::IsVector(vals.shape())) { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("vals must be a vec, but got: ", + vals.shape().DebugString())); + } + if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("indices and values rows (indexing " + "dimension) must match. (indices = ", + ix.shape().dim_size(0), ", values = ", + vals.shape().dim_size(0), ")")); + } + int dims; + TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims)); + if (order.size() != dims) { + return Status(error::INVALID_ARGUMENT, + "Order length must be SparseTensor rank."); + } + if (shape.size() != dims) { + return Status(error::INVALID_ARGUMENT, + "Shape rank must be SparseTensor rank."); + } + + *result = SparseTensor(ix, vals, shape, order); + return Status(); + } + + static Status Create(Tensor ix, Tensor vals, const TensorShape& shape, + SparseTensor* result) { + return Create(ix, vals, TensorShapeToVector(shape), + UndefinedOrder(TensorShapeToVector(shape)), result); + } + + static Status Create(Tensor ix, Tensor vals, const VarDimArray shape, + SparseTensor* result) { + return Create(ix, vals, shape, UndefinedOrder(shape), result); + } + + static Status Create(Tensor ix, Tensor vals, const TensorShape& shape, + const VarDimArray order, SparseTensor* result) { + return Create(ix, vals, TensorShapeToVector(shape), order, result); + } + + SparseTensor() : dims_(0) {} + + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape) : SparseTensor(ix, vals, TensorShapeToVector(shape), UndefinedOrder(TensorShapeToVector(shape))) {} + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape) : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {} + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape, const VarDimArray order) : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {} + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape, const VarDimArray order) : ix_(ix), vals_(vals), shape_(shape.begin(), shape.end()), order_(order.begin(), order.end()), - dims_(GetDimsFromIx(ix)) { - CHECK_EQ(ix.dtype(), DT_INT64) + dims_(UnsafeGetDimsFromIx(ix)) { + DCHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: " << ix.dtype(); - CHECK(TensorShapeUtils::IsVector(vals.shape())) + DCHECK(TensorShapeUtils::IsVector(vals.shape())) << "vals must be a vec, but got: " << vals.shape().DebugString(); - CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) + DCHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) << "indices and values rows (indexing dimension) must match."; - CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank."; - CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank."; + DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank."; + DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank."; } SparseTensor(const SparseTensor& other) @@ -81,6 +139,16 @@ class SparseTensor { vals_ = other.vals_; shape_ = other.shape_; order_ = other.order_; + dims_ = other.dims_; + return *this; + } + + SparseTensor& operator=(SparseTensor&& other) { + ix_ = std::move(other.ix_); + vals_ = std::move(other.vals_); + shape_ = std::move(other.shape_); + order_ = std::move(other.order_); + dims_ = std::move(other.dims_); return *this; } @@ -126,11 +194,11 @@ class SparseTensor { // // See the README.md in this directory for more usage information. GroupIterable group(const VarDimArray& group_ix) const { - CHECK_LE(group_ix.size(), dims_); + DCHECK_LE(group_ix.size(), dims_); for (std::size_t di = 0; di < group_ix.size(); ++di) { - CHECK_GE(group_ix[di], 0) << "Group dimension out of range"; - CHECK_LT(group_ix[di], dims_) << "Group dimension out of range"; - CHECK_EQ(group_ix[di], order_[di]) + DCHECK_GE(group_ix[di], 0) << "Group dimension out of range"; + DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range"; + DCHECK_EQ(group_ix[di], order_[di]) << "Group dimension does not match sorted order"; } return GroupIterable(ix_, vals_, dims_, group_ix); @@ -166,9 +234,16 @@ class SparseTensor { // isn't an integer multiple of split_dim, we add one extra dimension for // each slice. template <typename T> + static Status Split(const SparseTensor& tensor, const int split_dim, + const int num_split, std::vector<SparseTensor>* result); + + // DEPRECATED: use the form of Split() that takes an output pointer and + // returns a status instead. + template <typename T> static std::vector<SparseTensor> Split(const SparseTensor& tensor, const int split_dim, - const int num_split); + const int num_split, + Status* status = nullptr); // Slice() will slice the input SparseTensor into a SparseTensor based on // specified start and size. Both start and size are 1-D array with each @@ -189,9 +264,18 @@ class SparseTensor { } private: - static int GetDimsFromIx(const Tensor& ix) { - CHECK(TensorShapeUtils::IsMatrix(ix.shape())) - << "indices must be a matrix, but got: " << ix.shape().DebugString(); + static Status GetDimsFromIx(const Tensor& ix, int* result) { + if (!TensorShapeUtils::IsMatrix(ix.shape())) { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("indices must be a matrix, but got: ", + ix.shape().DebugString())); + } + *result = UnsafeGetDimsFromIx(ix); + return Status(); + } + + static int UnsafeGetDimsFromIx(const Tensor& ix) { + DCHECK(TensorShapeUtils::IsMatrix(ix.shape())); return ix.dim_size(1); } @@ -251,8 +335,8 @@ class SparseTensor { // Helper for Split() that returns the slice index. static inline int GetSliceIndex(const int dim, const int split_size, const int residual) { - CHECK_GT(split_size, 0); - CHECK_GE(dim, 0); + DCHECK_GT(split_size, 0); + DCHECK_GE(dim, 0); if (residual == 0) return dim / split_size; const int offset = residual * (split_size + 1); if (dim < offset) { @@ -265,8 +349,8 @@ class SparseTensor { // Helper for Split() that returns the dimension in the slice. static inline int GetDimensionInSlice(const int dim, const int split_size, const int residual) { - CHECK_GT(split_size, 0); - CHECK_GE(dim, 0); + DCHECK_GT(split_size, 0); + DCHECK_GE(dim, 0); if (residual == 0) return dim % split_size; const int offset = residual * (split_size + 1); if (dim < offset) { @@ -279,8 +363,8 @@ class SparseTensor { // Helper for Split() that returns the shape given a slice index. static inline int GetSliceShape(const int slice_index, const int split_size, const int residual) { - CHECK_GT(split_size, 0); - CHECK_GE(slice_index, 0); + DCHECK_GT(split_size, 0); + DCHECK_GE(slice_index, 0); if (residual == 0) return split_size; if (slice_index < residual) { return split_size + 1; @@ -293,7 +377,7 @@ class SparseTensor { Tensor vals_; ShapeArray shape_; ShapeArray order_; - const int dims_; + int dims_; }; // This operation updates the indices and values Tensor rows, so it is @@ -301,9 +385,9 @@ class SparseTensor { // temporary space. template <typename T> void SparseTensor::Reorder(const VarDimArray& order) { - CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) + DCHECK_EQ(DataTypeToEnum<T>::v(), dtype()) << "Reorder requested with the wrong datatype"; - CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank"; + DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank"; auto ix_t = ix_.matrix<int64>(); auto vals_t = vals_.vec<T>(); @@ -360,13 +444,13 @@ void SparseTensor::Reorder(const VarDimArray& order) { template <typename T> bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) { - CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) + DCHECK_EQ(DataTypeToEnum<T>::v(), dtype()) << "ToDense requested with the wrong datatype"; - CHECK_EQ(out->shape().dims(), dims_) + DCHECK_EQ(out->shape().dims(), dims_) << "Incompatible dimensions between SparseTensor and output"; - CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v()) + DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v()) << "Output must be type: " << DataTypeToEnum<T>::v() << " but got: " << out->dtype(); @@ -422,9 +506,9 @@ bool SparseTensor::ToDense(Tensor* out, bool initialize) { template <typename T> SparseTensor SparseTensor::Concat( const gtl::ArraySlice<SparseTensor>& tensors) { - CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors"; + DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors"; const int dims = tensors[0].dims_; - CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors"; + DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors"; auto order_0 = tensors[0].order(); const int primary_dim = order_0[0]; ShapeArray final_order(order_0.begin(), order_0.end()); @@ -434,17 +518,17 @@ SparseTensor SparseTensor::Concat( bool fully_ordered = true; for (const SparseTensor& st : tensors) { - CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank."; - CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype()) + DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank."; + DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype()) << "Concat requested with the wrong data type"; - CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered"; - CHECK_EQ(st.order()[0], primary_dim) + DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered"; + DCHECK_EQ(st.order()[0], primary_dim) << "All SparseTensors' order[0] must match. This is the concat dim."; if (st.order() != final_order) fully_ordered = false; const VarDimArray& st_shape = st.shape(); for (int d = 0; d < dims - 1; ++d) { const int cdim = (d < primary_dim) ? d : d + 1; - CHECK_EQ(final_shape[cdim], st_shape[cdim]) + DCHECK_EQ(final_shape[cdim], st_shape[cdim]) << "All SparseTensors' shapes must match except on the concat dim. " << "Concat dim: " << primary_dim << ", mismatched shape at dim: " << cdim @@ -494,7 +578,8 @@ SparseTensor SparseTensor::Concat( template <typename T> std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor, const int split_dim, - const int num_split) { + const int num_split, + Status* status /* = nullptr */) { std::vector<Tensor> output_indices; std::vector<Tensor> output_values; std::vector<TensorShape> output_shapes; @@ -514,12 +599,18 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor, const int split_dim_size = input_tensor.shape()[split_dim]; const int split_size = split_dim_size / num_split; - CHECK(num_split > 0 && num_split <= split_dim_size) << "num_split must be in " - "the interval (0, " - << split_dim_size << "]"; - CHECK(split_dim >= 0 && split_dim < num_dim) << "num_dim must be in " - "the interval [0, " - << num_dim << ")"; + if (!(num_split > 0 && num_split <= split_dim_size) && status != nullptr) { + *status = Status(error::INVALID_ARGUMENT, + strings::StrCat("num_split must be in the interval (0, ", + split_dim_size, "]")); + return {}; + } + if (!(split_dim >= 0 && split_dim < num_dim)) { + *status = Status( + error::INVALID_ARGUMENT, + strings::StrCat("num_dim must be in the interval [0, ", num_dim, ")")); + return {}; + } const int residual = split_dim_size % num_split; for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) { @@ -559,13 +650,28 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor, std::vector<SparseTensor> output_tensors; output_tensors.reserve(num_split); for (int i = 0; i < num_split; ++i) { - output_tensors.emplace_back(output_indices[i], output_values[i], - output_shapes[i]); + SparseTensor tensor; + Status create_status = + Create(output_indices[i], output_values[i], output_shapes[i], &tensor); + if (!create_status.ok() && status != nullptr) { + *status = create_status; + return {}; + } + output_tensors.push_back(std::move(tensor)); } return output_tensors; } template <typename T> +Status SparseTensor::Split(const SparseTensor& input_tensor, + const int split_dim, const int num_split, + std::vector<SparseTensor>* result) { + Status status; + *result = Split<T>(input_tensor, split_dim, num_split, &status); + return status; +} + +template <typename T> SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor, const gtl::ArraySlice<int64>& start, const gtl::ArraySlice<int64>& size) { @@ -643,4 +749,4 @@ SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor, } // namespace sparse } // namespace tensorflow -#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ +#endif // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_ |