diff options
Diffstat (limited to 'tensorflow/core/util/sparse')
-rw-r--r-- | tensorflow/core/util/sparse/dim_comparator.h | 16 | ||||
-rw-r--r-- | tensorflow/core/util/sparse/group_iterator.h | 6 | ||||
-rw-r--r-- | tensorflow/core/util/sparse/sparse_tensor.h | 196 | ||||
-rw-r--r-- | tensorflow/core/util/sparse/sparse_tensor_test.cc | 91 |
4 files changed, 225 insertions, 84 deletions
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h index b773b33008..0782e7e1a8 100644 --- a/tensorflow/core/util/sparse/dim_comparator.h +++ b/tensorflow/core/util/sparse/dim_comparator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ -#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ +#ifndef TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_ +#define TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/kernels/bounds_check.h" @@ -49,11 +49,11 @@ class DimComparator { DimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order, const VarDimArray& shape) : ix_(ix), order_(order), dims_(shape.size()) { - CHECK_GT(order.size(), size_t{0}) << "Must order using at least one index"; - CHECK_LE(order.size(), shape.size()) << "Can only sort up to dims"; + DCHECK_GT(order.size(), size_t{0}) << "Must order using at least one index"; + DCHECK_LE(order.size(), shape.size()) << "Can only sort up to dims"; for (size_t d = 0; d < order.size(); ++d) { - CHECK_GE(order[d], 0); - CHECK_LT(order[d], shape.size()); + DCHECK_GE(order[d], 0); + DCHECK_LT(order[d], shape.size()); } } @@ -97,7 +97,7 @@ class FixedDimComparator : DimComparator { FixedDimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order, const VarDimArray& shape) : DimComparator(ix, order, shape) { - CHECK_EQ(order.size(), ORDER_DIM); + DCHECK_EQ(order.size(), ORDER_DIM); } inline bool operator()(const int64 i, const int64 j) const { bool value = false; @@ -116,4 +116,4 @@ class FixedDimComparator : DimComparator { } // namespace sparse } // namespace tensorflow -#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_ +#endif // TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_ diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h index fb70318078..3fa8cb6116 100644 --- a/tensorflow/core/util/sparse/group_iterator.h +++ b/tensorflow/core/util/sparse/group_iterator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ -#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ +#ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_ +#define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_ #include <vector> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -143,4 +143,4 @@ typename TTypes<T>::UnalignedVec Group::values() const { } // namespace sparse } // namespace tensorflow -#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_ +#endif // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_ diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index 258ee418c1..0f04b65f60 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ -#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ +#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_ +#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_ #include <limits> #include <numeric> @@ -26,8 +26,10 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/sparse/dim_comparator.h" @@ -41,32 +43,88 @@ class SparseTensor { typedef typename gtl::ArraySlice<int64> VarDimArray; typedef typename gtl::InlinedVector<int64, 8> ShapeArray; + static Status Create(Tensor ix, Tensor vals, const VarDimArray shape, + const VarDimArray order, SparseTensor* result) { + if (ix.dtype() != DT_INT64) { + return Status( + error::INVALID_ARGUMENT, + strings::StrCat("indices must be type int64 but got: ", ix.dtype())); + } + if (!TensorShapeUtils::IsVector(vals.shape())) { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("vals must be a vec, but got: ", + vals.shape().DebugString())); + } + if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("indices and values rows (indexing " + "dimension) must match. (indices = ", + ix.shape().dim_size(0), ", values = ", + vals.shape().dim_size(0), ")")); + } + int dims; + TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims)); + if (order.size() != dims) { + return Status(error::INVALID_ARGUMENT, + "Order length must be SparseTensor rank."); + } + if (shape.size() != dims) { + return Status(error::INVALID_ARGUMENT, + "Shape rank must be SparseTensor rank."); + } + + *result = SparseTensor(ix, vals, shape, order); + return Status(); + } + + static Status Create(Tensor ix, Tensor vals, const TensorShape& shape, + SparseTensor* result) { + return Create(ix, vals, TensorShapeToVector(shape), + UndefinedOrder(TensorShapeToVector(shape)), result); + } + + static Status Create(Tensor ix, Tensor vals, const VarDimArray shape, + SparseTensor* result) { + return Create(ix, vals, shape, UndefinedOrder(shape), result); + } + + static Status Create(Tensor ix, Tensor vals, const TensorShape& shape, + const VarDimArray order, SparseTensor* result) { + return Create(ix, vals, TensorShapeToVector(shape), order, result); + } + + SparseTensor() : dims_(0) {} + + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape) : SparseTensor(ix, vals, TensorShapeToVector(shape), UndefinedOrder(TensorShapeToVector(shape))) {} + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape) : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {} + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape, const VarDimArray order) : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {} + // DEPRECATED: use Create() functions instead of constructors directly. SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape, const VarDimArray order) : ix_(ix), vals_(vals), shape_(shape.begin(), shape.end()), order_(order.begin(), order.end()), - dims_(GetDimsFromIx(ix)) { - CHECK_EQ(ix.dtype(), DT_INT64) + dims_(UnsafeGetDimsFromIx(ix)) { + DCHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: " << ix.dtype(); - CHECK(TensorShapeUtils::IsVector(vals.shape())) + DCHECK(TensorShapeUtils::IsVector(vals.shape())) << "vals must be a vec, but got: " << vals.shape().DebugString(); - CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) + DCHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0)) << "indices and values rows (indexing dimension) must match."; - CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank."; - CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank."; + DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank."; + DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank."; } SparseTensor(const SparseTensor& other) @@ -81,6 +139,16 @@ class SparseTensor { vals_ = other.vals_; shape_ = other.shape_; order_ = other.order_; + dims_ = other.dims_; + return *this; + } + + SparseTensor& operator=(SparseTensor&& other) { + ix_ = std::move(other.ix_); + vals_ = std::move(other.vals_); + shape_ = std::move(other.shape_); + order_ = std::move(other.order_); + dims_ = std::move(other.dims_); return *this; } @@ -126,11 +194,11 @@ class SparseTensor { // // See the README.md in this directory for more usage information. GroupIterable group(const VarDimArray& group_ix) const { - CHECK_LE(group_ix.size(), dims_); + DCHECK_LE(group_ix.size(), dims_); for (std::size_t di = 0; di < group_ix.size(); ++di) { - CHECK_GE(group_ix[di], 0) << "Group dimension out of range"; - CHECK_LT(group_ix[di], dims_) << "Group dimension out of range"; - CHECK_EQ(group_ix[di], order_[di]) + DCHECK_GE(group_ix[di], 0) << "Group dimension out of range"; + DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range"; + DCHECK_EQ(group_ix[di], order_[di]) << "Group dimension does not match sorted order"; } return GroupIterable(ix_, vals_, dims_, group_ix); @@ -166,9 +234,16 @@ class SparseTensor { // isn't an integer multiple of split_dim, we add one extra dimension for // each slice. template <typename T> + static Status Split(const SparseTensor& tensor, const int split_dim, + const int num_split, std::vector<SparseTensor>* result); + + // DEPRECATED: use the form of Split() that takes an output pointer and + // returns a status instead. + template <typename T> static std::vector<SparseTensor> Split(const SparseTensor& tensor, const int split_dim, - const int num_split); + const int num_split, + Status* status = nullptr); // Slice() will slice the input SparseTensor into a SparseTensor based on // specified start and size. Both start and size are 1-D array with each @@ -189,9 +264,18 @@ class SparseTensor { } private: - static int GetDimsFromIx(const Tensor& ix) { - CHECK(TensorShapeUtils::IsMatrix(ix.shape())) - << "indices must be a matrix, but got: " << ix.shape().DebugString(); + static Status GetDimsFromIx(const Tensor& ix, int* result) { + if (!TensorShapeUtils::IsMatrix(ix.shape())) { + return Status(error::INVALID_ARGUMENT, + strings::StrCat("indices must be a matrix, but got: ", + ix.shape().DebugString())); + } + *result = UnsafeGetDimsFromIx(ix); + return Status(); + } + + static int UnsafeGetDimsFromIx(const Tensor& ix) { + DCHECK(TensorShapeUtils::IsMatrix(ix.shape())); return ix.dim_size(1); } @@ -251,8 +335,8 @@ class SparseTensor { // Helper for Split() that returns the slice index. static inline int GetSliceIndex(const int dim, const int split_size, const int residual) { - CHECK_GT(split_size, 0); - CHECK_GE(dim, 0); + DCHECK_GT(split_size, 0); + DCHECK_GE(dim, 0); if (residual == 0) return dim / split_size; const int offset = residual * (split_size + 1); if (dim < offset) { @@ -265,8 +349,8 @@ class SparseTensor { // Helper for Split() that returns the dimension in the slice. static inline int GetDimensionInSlice(const int dim, const int split_size, const int residual) { - CHECK_GT(split_size, 0); - CHECK_GE(dim, 0); + DCHECK_GT(split_size, 0); + DCHECK_GE(dim, 0); if (residual == 0) return dim % split_size; const int offset = residual * (split_size + 1); if (dim < offset) { @@ -279,8 +363,8 @@ class SparseTensor { // Helper for Split() that returns the shape given a slice index. static inline int GetSliceShape(const int slice_index, const int split_size, const int residual) { - CHECK_GT(split_size, 0); - CHECK_GE(slice_index, 0); + DCHECK_GT(split_size, 0); + DCHECK_GE(slice_index, 0); if (residual == 0) return split_size; if (slice_index < residual) { return split_size + 1; @@ -293,7 +377,7 @@ class SparseTensor { Tensor vals_; ShapeArray shape_; ShapeArray order_; - const int dims_; + int dims_; }; // This operation updates the indices and values Tensor rows, so it is @@ -301,9 +385,9 @@ class SparseTensor { // temporary space. template <typename T> void SparseTensor::Reorder(const VarDimArray& order) { - CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) + DCHECK_EQ(DataTypeToEnum<T>::v(), dtype()) << "Reorder requested with the wrong datatype"; - CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank"; + DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank"; auto ix_t = ix_.matrix<int64>(); auto vals_t = vals_.vec<T>(); @@ -360,13 +444,13 @@ void SparseTensor::Reorder(const VarDimArray& order) { template <typename T> bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) { - CHECK_EQ(DataTypeToEnum<T>::v(), dtype()) + DCHECK_EQ(DataTypeToEnum<T>::v(), dtype()) << "ToDense requested with the wrong datatype"; - CHECK_EQ(out->shape().dims(), dims_) + DCHECK_EQ(out->shape().dims(), dims_) << "Incompatible dimensions between SparseTensor and output"; - CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v()) + DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v()) << "Output must be type: " << DataTypeToEnum<T>::v() << " but got: " << out->dtype(); @@ -422,9 +506,9 @@ bool SparseTensor::ToDense(Tensor* out, bool initialize) { template <typename T> SparseTensor SparseTensor::Concat( const gtl::ArraySlice<SparseTensor>& tensors) { - CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors"; + DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors"; const int dims = tensors[0].dims_; - CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors"; + DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors"; auto order_0 = tensors[0].order(); const int primary_dim = order_0[0]; ShapeArray final_order(order_0.begin(), order_0.end()); @@ -434,17 +518,17 @@ SparseTensor SparseTensor::Concat( bool fully_ordered = true; for (const SparseTensor& st : tensors) { - CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank."; - CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype()) + DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank."; + DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype()) << "Concat requested with the wrong data type"; - CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered"; - CHECK_EQ(st.order()[0], primary_dim) + DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered"; + DCHECK_EQ(st.order()[0], primary_dim) << "All SparseTensors' order[0] must match. This is the concat dim."; if (st.order() != final_order) fully_ordered = false; const VarDimArray& st_shape = st.shape(); for (int d = 0; d < dims - 1; ++d) { const int cdim = (d < primary_dim) ? d : d + 1; - CHECK_EQ(final_shape[cdim], st_shape[cdim]) + DCHECK_EQ(final_shape[cdim], st_shape[cdim]) << "All SparseTensors' shapes must match except on the concat dim. " << "Concat dim: " << primary_dim << ", mismatched shape at dim: " << cdim @@ -494,7 +578,8 @@ SparseTensor SparseTensor::Concat( template <typename T> std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor, const int split_dim, - const int num_split) { + const int num_split, + Status* status /* = nullptr */) { std::vector<Tensor> output_indices; std::vector<Tensor> output_values; std::vector<TensorShape> output_shapes; @@ -514,12 +599,18 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor, const int split_dim_size = input_tensor.shape()[split_dim]; const int split_size = split_dim_size / num_split; - CHECK(num_split > 0 && num_split <= split_dim_size) << "num_split must be in " - "the interval (0, " - << split_dim_size << "]"; - CHECK(split_dim >= 0 && split_dim < num_dim) << "num_dim must be in " - "the interval [0, " - << num_dim << ")"; + if (!(num_split > 0 && num_split <= split_dim_size) && status != nullptr) { + *status = Status(error::INVALID_ARGUMENT, + strings::StrCat("num_split must be in the interval (0, ", + split_dim_size, "]")); + return {}; + } + if (!(split_dim >= 0 && split_dim < num_dim)) { + *status = Status( + error::INVALID_ARGUMENT, + strings::StrCat("num_dim must be in the interval [0, ", num_dim, ")")); + return {}; + } const int residual = split_dim_size % num_split; for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) { @@ -559,13 +650,28 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor, std::vector<SparseTensor> output_tensors; output_tensors.reserve(num_split); for (int i = 0; i < num_split; ++i) { - output_tensors.emplace_back(output_indices[i], output_values[i], - output_shapes[i]); + SparseTensor tensor; + Status create_status = + Create(output_indices[i], output_values[i], output_shapes[i], &tensor); + if (!create_status.ok() && status != nullptr) { + *status = create_status; + return {}; + } + output_tensors.push_back(std::move(tensor)); } return output_tensors; } template <typename T> +Status SparseTensor::Split(const SparseTensor& input_tensor, + const int split_dim, const int num_split, + std::vector<SparseTensor>* result) { + Status status; + *result = Split<T>(input_tensor, split_dim, num_split, &status); + return status; +} + +template <typename T> SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor, const gtl::ArraySlice<int64>& start, const gtl::ArraySlice<int64>& size) { @@ -643,4 +749,4 @@ SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor, } // namespace sparse } // namespace tensorflow -#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_ +#endif // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_ diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc index 85de032085..5578e42625 100644 --- a/tensorflow/core/util/sparse/sparse_tensor_test.cc +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -94,9 +94,12 @@ TEST(SparseTensorTest, SparseTensorInvalidIndicesType) { const int NDIM = 3; Tensor ix(DT_INT32, TensorShape({N, NDIM})); Tensor vals(DT_STRING, TensorShape({N})); + SparseTensor result; - EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}), - "indices must be type int64"); + EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}, + &result) + .code(), + error::INVALID_ARGUMENT); } TEST(SparseTensorTest, SparseTensorInvalidIndicesShape) { @@ -104,9 +107,12 @@ TEST(SparseTensorTest, SparseTensorInvalidIndicesShape) { const int NDIM = 3; Tensor ix(DT_INT64, TensorShape({N, NDIM, 1})); Tensor vals(DT_STRING, TensorShape({N})); + SparseTensor result; - EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}), - "indices must be a matrix"); + EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}, + &result) + .code(), + error::INVALID_ARGUMENT); } TEST(SparseTensorTest, SparseTensorInvalidValues) { @@ -114,9 +120,12 @@ TEST(SparseTensorTest, SparseTensorInvalidValues) { const int NDIM = 3; Tensor ix(DT_INT64, TensorShape({N, NDIM})); Tensor vals(DT_STRING, TensorShape({N, 1})); + SparseTensor result; - EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}), - "vals must be a vec"); + EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}, + &result) + .code(), + error::INVALID_ARGUMENT); } TEST(SparseTensorTest, SparseTensorInvalidN) { @@ -124,9 +133,12 @@ TEST(SparseTensorTest, SparseTensorInvalidN) { const int NDIM = 3; Tensor ix(DT_INT64, TensorShape({N, NDIM})); Tensor vals(DT_STRING, TensorShape({N - 1})); + SparseTensor result; - EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}), - "indices and values rows .* must match"); + EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}, + &result) + .code(), + error::INVALID_ARGUMENT); } TEST(SparseTensorTest, SparseTensorInvalidOrder) { @@ -134,18 +146,24 @@ TEST(SparseTensorTest, SparseTensorInvalidOrder) { const int NDIM = 3; Tensor ix(DT_INT64, TensorShape({N, NDIM})); Tensor vals(DT_STRING, TensorShape({N})); + SparseTensor result; - EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1}), - "Order length must be SparseTensor rank"); + EXPECT_EQ( + SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1}, &result) + .code(), + error::INVALID_ARGUMENT); } TEST(SparseTensorTest, SparseTensorInvalidShape) { int N = 5; const int NDIM = 3; Tensor ix(DT_INT64, TensorShape({N, NDIM})); Tensor vals(DT_STRING, TensorShape({N})); + SparseTensor result; - EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10}), {0, 1, 2}), - "Shape rank must be SparseTensor rank"); + EXPECT_EQ( + SparseTensor::Create(ix, vals, TensorShape({10, 10}), {0, 1, 2}, &result) + .code(), + error::INVALID_ARGUMENT); } TEST(SparseTensorTest, SparseTensorConstruction) { @@ -169,7 +187,8 @@ TEST(SparseTensorTest, SparseTensorConstruction) { TensorShape shape({10, 10, 10}); std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); Status st_indices_valid = st.IndicesValid(); EXPECT_FALSE(st_indices_valid.ok()); EXPECT_EQ("indices[2] = [2,0,0] is out of order", @@ -210,7 +229,8 @@ TEST(SparseTensorTest, EmptySparseTensorAllowed) { std::vector<int64> shape{10, 10, 10}; std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); TF_EXPECT_OK(st.IndicesValid()); EXPECT_EQ(st.order(), order); @@ -227,7 +247,8 @@ TEST(SparseTensorTest, SortingWorksCorrectly) { Tensor ix(DT_INT64, TensorShape({N, NDIM})); Tensor vals(DT_STRING, TensorShape({N})); TensorShape shape({1000, 1000, 1000, 1000}); - SparseTensor st(ix, vals, shape); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, &st)); auto ix_t = ix.matrix<int64>(); @@ -266,7 +287,8 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) { TensorShape shape({10, 10, 10}); std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); st.Reorder<string>(order); Status st_indices_valid = st.IndicesValid(); @@ -302,7 +324,8 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) { TensorShape shape({10, 10, 10}); std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); EXPECT_FALSE(st.IndicesValid().ok()); st.Reorder<string>(order); @@ -351,7 +374,8 @@ TEST(SparseTensorTest, SparseTensorToDenseTensor) { TensorShape shape({4, 4, 5}); std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); Tensor dense(DT_STRING, TensorShape({4, 4, 5})); st.ToDense<string>(&dense); @@ -390,7 +414,8 @@ TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) { TensorShape shape({4, 4, 5}); std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); Tensor dense(DT_STRING, TensorShape({10, 10, 10})); st.ToDense<string>(&dense); @@ -433,7 +458,8 @@ TEST(SparseTensorTest, SparseTensorGroup) { TensorShape shape({10, 10, 10}); std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); st.Reorder<int32>(order); std::vector<std::vector<int64> > groups; @@ -521,7 +547,8 @@ TEST(SparseTensorTest, Concat) { TensorShape shape({10, 10, 10}); std::vector<int64> order{0, 1, 2}; - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); EXPECT_FALSE(st.IndicesValid().ok()); st.Reorder<string>(order); TF_EXPECT_OK(st.IndicesValid()); @@ -551,7 +578,9 @@ TEST(SparseTensorTest, Concat) { // Concat works if non-primary ix is out of order, but output order // is not defined - SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO + SparseTensor st_ooo; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 2, 1}, + &st_ooo)); // non-primary ix OOO SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo}); std::vector<int64> expected_ooo{-1, -1, -1}; EXPECT_EQ(conc_ooo.order(), expected_ooo); @@ -584,9 +613,11 @@ TEST(SparseTensorTest, Split) { vals.vec<int64>()(2) = 3; vals.vec<int64>()(3) = 4; - SparseTensor st(ids, vals, TensorShape({4, 3})); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ids, vals, TensorShape({4, 3}), &st)); - std::vector<SparseTensor> st_list = SparseTensor::Split<int64>(st, 0, 2); + std::vector<SparseTensor> st_list; + TF_ASSERT_OK(SparseTensor::Split<int64>(st, 0, 2, &st_list)); EXPECT_EQ(st_list.size(), 2); auto expected_shape = gtl::InlinedVector<int64, 8>{2, 3}; @@ -633,7 +664,8 @@ TEST(SparseTensorTest, Slice) { vals.vec<int64>()(2) = 3; vals.vec<int64>()(3) = 4; - SparseTensor st(ids, vals, TensorShape({4, 3})); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ids, vals, TensorShape({4, 3}), &st)); std::vector<int64> start(2, 0); std::vector<int64> size(2); @@ -662,7 +694,8 @@ TEST(SparseTensorTest, Dim0SparseTensorToDenseTensor) { vals.scalar<int32>()() = 5; TensorShape shape({}); - SparseTensor st(ix, vals, shape); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, &st)); Tensor dense(DT_INT32, TensorShape({})); st.ToDense<int32>(&dense); @@ -699,7 +732,8 @@ static void BM_SparseReorderFloat(int iters, int N32, int NDIM32) { ix_t(i, d) = rnd.Rand64() % 1000; } } - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); testing::StartTiming(); st.Reorder<float>(reorder); @@ -740,7 +774,8 @@ static void BM_SparseReorderString(int iters, int N32, int NDIM32) { ix_t(i, d) = rnd.Rand64() % 1000; } } - SparseTensor st(ix, vals, shape, order); + SparseTensor st; + TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); testing::StartTiming(); st.Reorder<string>(reorder); |