aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/sparse
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/sparse')
-rw-r--r--tensorflow/core/util/sparse/dim_comparator.h16
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h6
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor.h196
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor_test.cc91
4 files changed, 225 insertions, 84 deletions
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h
index b773b33008..0782e7e1a8 100644
--- a/tensorflow/core/util/sparse/dim_comparator.h
+++ b/tensorflow/core/util/sparse/dim_comparator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
-#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+#ifndef TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_
+#define TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -49,11 +49,11 @@ class DimComparator {
DimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order,
const VarDimArray& shape)
: ix_(ix), order_(order), dims_(shape.size()) {
- CHECK_GT(order.size(), size_t{0}) << "Must order using at least one index";
- CHECK_LE(order.size(), shape.size()) << "Can only sort up to dims";
+ DCHECK_GT(order.size(), size_t{0}) << "Must order using at least one index";
+ DCHECK_LE(order.size(), shape.size()) << "Can only sort up to dims";
for (size_t d = 0; d < order.size(); ++d) {
- CHECK_GE(order[d], 0);
- CHECK_LT(order[d], shape.size());
+ DCHECK_GE(order[d], 0);
+ DCHECK_LT(order[d], shape.size());
}
}
@@ -97,7 +97,7 @@ class FixedDimComparator : DimComparator {
FixedDimComparator(const TTypes<int64>::Matrix& ix, const VarDimArray& order,
const VarDimArray& shape)
: DimComparator(ix, order, shape) {
- CHECK_EQ(order.size(), ORDER_DIM);
+ DCHECK_EQ(order.size(), ORDER_DIM);
}
inline bool operator()(const int64 i, const int64 j) const {
bool value = false;
@@ -116,4 +116,4 @@ class FixedDimComparator : DimComparator {
} // namespace sparse
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+#endif // TENSORFLOW_CORE_UTIL_SPARSE_DIM_COMPARATOR_H_
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index fb70318078..3fa8cb6116 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
-#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+#ifndef TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
+#define TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -143,4 +143,4 @@ typename TTypes<T>::UnalignedVec Group::values() const {
} // namespace sparse
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+#endif // TENSORFLOW_CORE_UTIL_SPARSE_GROUP_ITERATOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 258ee418c1..0f04b65f60 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
-#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
+#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
#include <limits>
#include <numeric>
@@ -26,8 +26,10 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
@@ -41,32 +43,88 @@ class SparseTensor {
typedef typename gtl::ArraySlice<int64> VarDimArray;
typedef typename gtl::InlinedVector<int64, 8> ShapeArray;
+ static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
+ const VarDimArray order, SparseTensor* result) {
+ if (ix.dtype() != DT_INT64) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ strings::StrCat("indices must be type int64 but got: ", ix.dtype()));
+ }
+ if (!TensorShapeUtils::IsVector(vals.shape())) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("vals must be a vec, but got: ",
+ vals.shape().DebugString()));
+ }
+ if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("indices and values rows (indexing "
+ "dimension) must match. (indices = ",
+ ix.shape().dim_size(0), ", values = ",
+ vals.shape().dim_size(0), ")"));
+ }
+ int dims;
+ TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims));
+ if (order.size() != dims) {
+ return Status(error::INVALID_ARGUMENT,
+ "Order length must be SparseTensor rank.");
+ }
+ if (shape.size() != dims) {
+ return Status(error::INVALID_ARGUMENT,
+ "Shape rank must be SparseTensor rank.");
+ }
+
+ *result = SparseTensor(ix, vals, shape, order);
+ return Status();
+ }
+
+ static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
+ SparseTensor* result) {
+ return Create(ix, vals, TensorShapeToVector(shape),
+ UndefinedOrder(TensorShapeToVector(shape)), result);
+ }
+
+ static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
+ SparseTensor* result) {
+ return Create(ix, vals, shape, UndefinedOrder(shape), result);
+ }
+
+ static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
+ const VarDimArray order, SparseTensor* result) {
+ return Create(ix, vals, TensorShapeToVector(shape), order, result);
+ }
+
+ SparseTensor() : dims_(0) {}
+
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
+ // DEPRECATED: use Create() functions instead of constructors directly.
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
vals_(vals),
shape_(shape.begin(), shape.end()),
order_(order.begin(), order.end()),
- dims_(GetDimsFromIx(ix)) {
- CHECK_EQ(ix.dtype(), DT_INT64)
+ dims_(UnsafeGetDimsFromIx(ix)) {
+ DCHECK_EQ(ix.dtype(), DT_INT64)
<< "indices must be type int64 but got: " << ix.dtype();
- CHECK(TensorShapeUtils::IsVector(vals.shape()))
+ DCHECK(TensorShapeUtils::IsVector(vals.shape()))
<< "vals must be a vec, but got: " << vals.shape().DebugString();
- CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
+ DCHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
<< "indices and values rows (indexing dimension) must match.";
- CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
- CHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
+ DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
+ DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
}
SparseTensor(const SparseTensor& other)
@@ -81,6 +139,16 @@ class SparseTensor {
vals_ = other.vals_;
shape_ = other.shape_;
order_ = other.order_;
+ dims_ = other.dims_;
+ return *this;
+ }
+
+ SparseTensor& operator=(SparseTensor&& other) {
+ ix_ = std::move(other.ix_);
+ vals_ = std::move(other.vals_);
+ shape_ = std::move(other.shape_);
+ order_ = std::move(other.order_);
+ dims_ = std::move(other.dims_);
return *this;
}
@@ -126,11 +194,11 @@ class SparseTensor {
//
// See the README.md in this directory for more usage information.
GroupIterable group(const VarDimArray& group_ix) const {
- CHECK_LE(group_ix.size(), dims_);
+ DCHECK_LE(group_ix.size(), dims_);
for (std::size_t di = 0; di < group_ix.size(); ++di) {
- CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
- CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
- CHECK_EQ(group_ix[di], order_[di])
+ DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
+ DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
+ DCHECK_EQ(group_ix[di], order_[di])
<< "Group dimension does not match sorted order";
}
return GroupIterable(ix_, vals_, dims_, group_ix);
@@ -166,9 +234,16 @@ class SparseTensor {
// isn't an integer multiple of split_dim, we add one extra dimension for
// each slice.
template <typename T>
+ static Status Split(const SparseTensor& tensor, const int split_dim,
+ const int num_split, std::vector<SparseTensor>* result);
+
+ // DEPRECATED: use the form of Split() that takes an output pointer and
+ // returns a status instead.
+ template <typename T>
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
- const int num_split);
+ const int num_split,
+ Status* status = nullptr);
// Slice() will slice the input SparseTensor into a SparseTensor based on
// specified start and size. Both start and size are 1-D array with each
@@ -189,9 +264,18 @@ class SparseTensor {
}
private:
- static int GetDimsFromIx(const Tensor& ix) {
- CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
- << "indices must be a matrix, but got: " << ix.shape().DebugString();
+ static Status GetDimsFromIx(const Tensor& ix, int* result) {
+ if (!TensorShapeUtils::IsMatrix(ix.shape())) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("indices must be a matrix, but got: ",
+ ix.shape().DebugString()));
+ }
+ *result = UnsafeGetDimsFromIx(ix);
+ return Status();
+ }
+
+ static int UnsafeGetDimsFromIx(const Tensor& ix) {
+ DCHECK(TensorShapeUtils::IsMatrix(ix.shape()));
return ix.dim_size(1);
}
@@ -251,8 +335,8 @@ class SparseTensor {
// Helper for Split() that returns the slice index.
static inline int GetSliceIndex(const int dim, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(dim, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(dim, 0);
if (residual == 0) return dim / split_size;
const int offset = residual * (split_size + 1);
if (dim < offset) {
@@ -265,8 +349,8 @@ class SparseTensor {
// Helper for Split() that returns the dimension in the slice.
static inline int GetDimensionInSlice(const int dim, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(dim, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(dim, 0);
if (residual == 0) return dim % split_size;
const int offset = residual * (split_size + 1);
if (dim < offset) {
@@ -279,8 +363,8 @@ class SparseTensor {
// Helper for Split() that returns the shape given a slice index.
static inline int GetSliceShape(const int slice_index, const int split_size,
const int residual) {
- CHECK_GT(split_size, 0);
- CHECK_GE(slice_index, 0);
+ DCHECK_GT(split_size, 0);
+ DCHECK_GE(slice_index, 0);
if (residual == 0) return split_size;
if (slice_index < residual) {
return split_size + 1;
@@ -293,7 +377,7 @@ class SparseTensor {
Tensor vals_;
ShapeArray shape_;
ShapeArray order_;
- const int dims_;
+ int dims_;
};
// This operation updates the indices and values Tensor rows, so it is
@@ -301,9 +385,9 @@ class SparseTensor {
// temporary space.
template <typename T>
void SparseTensor::Reorder(const VarDimArray& order) {
- CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
<< "Reorder requested with the wrong datatype";
- CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
+ DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
auto ix_t = ix_.matrix<int64>();
auto vals_t = vals_.vec<T>();
@@ -360,13 +444,13 @@ void SparseTensor::Reorder(const VarDimArray& order) {
template <typename T>
bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
- CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
<< "ToDense requested with the wrong datatype";
- CHECK_EQ(out->shape().dims(), dims_)
+ DCHECK_EQ(out->shape().dims(), dims_)
<< "Incompatible dimensions between SparseTensor and output";
- CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
+ DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
<< "Output must be type: " << DataTypeToEnum<T>::v()
<< " but got: " << out->dtype();
@@ -422,9 +506,9 @@ bool SparseTensor::ToDense(Tensor* out, bool initialize) {
template <typename T>
SparseTensor SparseTensor::Concat(
const gtl::ArraySlice<SparseTensor>& tensors) {
- CHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
+ DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
const int dims = tensors[0].dims_;
- CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
+ DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
auto order_0 = tensors[0].order();
const int primary_dim = order_0[0];
ShapeArray final_order(order_0.begin(), order_0.end());
@@ -434,17 +518,17 @@ SparseTensor SparseTensor::Concat(
bool fully_ordered = true;
for (const SparseTensor& st : tensors) {
- CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
- CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
+ DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
+ DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
<< "Concat requested with the wrong data type";
- CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
- CHECK_EQ(st.order()[0], primary_dim)
+ DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
+ DCHECK_EQ(st.order()[0], primary_dim)
<< "All SparseTensors' order[0] must match. This is the concat dim.";
if (st.order() != final_order) fully_ordered = false;
const VarDimArray& st_shape = st.shape();
for (int d = 0; d < dims - 1; ++d) {
const int cdim = (d < primary_dim) ? d : d + 1;
- CHECK_EQ(final_shape[cdim], st_shape[cdim])
+ DCHECK_EQ(final_shape[cdim], st_shape[cdim])
<< "All SparseTensors' shapes must match except on the concat dim. "
<< "Concat dim: " << primary_dim
<< ", mismatched shape at dim: " << cdim
@@ -494,7 +578,8 @@ SparseTensor SparseTensor::Concat(
template <typename T>
std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
const int split_dim,
- const int num_split) {
+ const int num_split,
+ Status* status /* = nullptr */) {
std::vector<Tensor> output_indices;
std::vector<Tensor> output_values;
std::vector<TensorShape> output_shapes;
@@ -514,12 +599,18 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
const int split_dim_size = input_tensor.shape()[split_dim];
const int split_size = split_dim_size / num_split;
- CHECK(num_split > 0 && num_split <= split_dim_size) << "num_split must be in "
- "the interval (0, "
- << split_dim_size << "]";
- CHECK(split_dim >= 0 && split_dim < num_dim) << "num_dim must be in "
- "the interval [0, "
- << num_dim << ")";
+ if (!(num_split > 0 && num_split <= split_dim_size) && status != nullptr) {
+ *status = Status(error::INVALID_ARGUMENT,
+ strings::StrCat("num_split must be in the interval (0, ",
+ split_dim_size, "]"));
+ return {};
+ }
+ if (!(split_dim >= 0 && split_dim < num_dim)) {
+ *status = Status(
+ error::INVALID_ARGUMENT,
+ strings::StrCat("num_dim must be in the interval [0, ", num_dim, ")"));
+ return {};
+ }
const int residual = split_dim_size % num_split;
for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
@@ -559,13 +650,28 @@ std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
std::vector<SparseTensor> output_tensors;
output_tensors.reserve(num_split);
for (int i = 0; i < num_split; ++i) {
- output_tensors.emplace_back(output_indices[i], output_values[i],
- output_shapes[i]);
+ SparseTensor tensor;
+ Status create_status =
+ Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
+ if (!create_status.ok() && status != nullptr) {
+ *status = create_status;
+ return {};
+ }
+ output_tensors.push_back(std::move(tensor));
}
return output_tensors;
}
template <typename T>
+Status SparseTensor::Split(const SparseTensor& input_tensor,
+ const int split_dim, const int num_split,
+ std::vector<SparseTensor>* result) {
+ Status status;
+ *result = Split<T>(input_tensor, split_dim, num_split, &status);
+ return status;
+}
+
+template <typename T>
SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
const gtl::ArraySlice<int64>& start,
const gtl::ArraySlice<int64>& size) {
@@ -643,4 +749,4 @@ SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
} // namespace sparse
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#endif // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc
index 85de032085..5578e42625 100644
--- a/tensorflow/core/util/sparse/sparse_tensor_test.cc
+++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc
@@ -94,9 +94,12 @@ TEST(SparseTensorTest, SparseTensorInvalidIndicesType) {
const int NDIM = 3;
Tensor ix(DT_INT32, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "indices must be type int64");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidIndicesShape) {
@@ -104,9 +107,12 @@ TEST(SparseTensorTest, SparseTensorInvalidIndicesShape) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM, 1}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "indices must be a matrix");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidValues) {
@@ -114,9 +120,12 @@ TEST(SparseTensorTest, SparseTensorInvalidValues) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N, 1}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "vals must be a vec");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidN) {
@@ -124,9 +133,12 @@ TEST(SparseTensorTest, SparseTensorInvalidN) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N - 1}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2}),
- "indices and values rows .* must match");
+ EXPECT_EQ(SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1, 2},
+ &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidOrder) {
@@ -134,18 +146,24 @@ TEST(SparseTensorTest, SparseTensorInvalidOrder) {
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10, 10}), {0, 1}),
- "Order length must be SparseTensor rank");
+ EXPECT_EQ(
+ SparseTensor::Create(ix, vals, TensorShape({10, 10, 10}), {0, 1}, &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorInvalidShape) {
int N = 5;
const int NDIM = 3;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
+ SparseTensor result;
- EXPECT_DEATH(SparseTensor(ix, vals, TensorShape({10, 10}), {0, 1, 2}),
- "Shape rank must be SparseTensor rank");
+ EXPECT_EQ(
+ SparseTensor::Create(ix, vals, TensorShape({10, 10}), {0, 1, 2}, &result)
+ .code(),
+ error::INVALID_ARGUMENT);
}
TEST(SparseTensorTest, SparseTensorConstruction) {
@@ -169,7 +187,8 @@ TEST(SparseTensorTest, SparseTensorConstruction) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Status st_indices_valid = st.IndicesValid();
EXPECT_FALSE(st_indices_valid.ok());
EXPECT_EQ("indices[2] = [2,0,0] is out of order",
@@ -210,7 +229,8 @@ TEST(SparseTensorTest, EmptySparseTensorAllowed) {
std::vector<int64> shape{10, 10, 10};
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
TF_EXPECT_OK(st.IndicesValid());
EXPECT_EQ(st.order(), order);
@@ -227,7 +247,8 @@ TEST(SparseTensorTest, SortingWorksCorrectly) {
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
TensorShape shape({1000, 1000, 1000, 1000});
- SparseTensor st(ix, vals, shape);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, &st));
auto ix_t = ix.matrix<int64>();
@@ -266,7 +287,8 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
st.Reorder<string>(order);
Status st_indices_valid = st.IndicesValid();
@@ -302,7 +324,8 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order);
@@ -351,7 +374,8 @@ TEST(SparseTensorTest, SparseTensorToDenseTensor) {
TensorShape shape({4, 4, 5});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
st.ToDense<string>(&dense);
@@ -390,7 +414,8 @@ TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
TensorShape shape({4, 4, 5});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
st.ToDense<string>(&dense);
@@ -433,7 +458,8 @@ TEST(SparseTensorTest, SparseTensorGroup) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
st.Reorder<int32>(order);
std::vector<std::vector<int64> > groups;
@@ -521,7 +547,8 @@ TEST(SparseTensorTest, Concat) {
TensorShape shape({10, 10, 10});
std::vector<int64> order{0, 1, 2};
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
EXPECT_FALSE(st.IndicesValid().ok());
st.Reorder<string>(order);
TF_EXPECT_OK(st.IndicesValid());
@@ -551,7 +578,9 @@ TEST(SparseTensorTest, Concat) {
// Concat works if non-primary ix is out of order, but output order
// is not defined
- SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO
+ SparseTensor st_ooo;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 2, 1},
+ &st_ooo)); // non-primary ix OOO
SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo});
std::vector<int64> expected_ooo{-1, -1, -1};
EXPECT_EQ(conc_ooo.order(), expected_ooo);
@@ -584,9 +613,11 @@ TEST(SparseTensorTest, Split) {
vals.vec<int64>()(2) = 3;
vals.vec<int64>()(3) = 4;
- SparseTensor st(ids, vals, TensorShape({4, 3}));
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ids, vals, TensorShape({4, 3}), &st));
- std::vector<SparseTensor> st_list = SparseTensor::Split<int64>(st, 0, 2);
+ std::vector<SparseTensor> st_list;
+ TF_ASSERT_OK(SparseTensor::Split<int64>(st, 0, 2, &st_list));
EXPECT_EQ(st_list.size(), 2);
auto expected_shape = gtl::InlinedVector<int64, 8>{2, 3};
@@ -633,7 +664,8 @@ TEST(SparseTensorTest, Slice) {
vals.vec<int64>()(2) = 3;
vals.vec<int64>()(3) = 4;
- SparseTensor st(ids, vals, TensorShape({4, 3}));
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ids, vals, TensorShape({4, 3}), &st));
std::vector<int64> start(2, 0);
std::vector<int64> size(2);
@@ -662,7 +694,8 @@ TEST(SparseTensorTest, Dim0SparseTensorToDenseTensor) {
vals.scalar<int32>()() = 5;
TensorShape shape({});
- SparseTensor st(ix, vals, shape);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, &st));
Tensor dense(DT_INT32, TensorShape({}));
st.ToDense<int32>(&dense);
@@ -699,7 +732,8 @@ static void BM_SparseReorderFloat(int iters, int N32, int NDIM32) {
ix_t(i, d) = rnd.Rand64() % 1000;
}
}
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
testing::StartTiming();
st.Reorder<float>(reorder);
@@ -740,7 +774,8 @@ static void BM_SparseReorderString(int iters, int N32, int NDIM32) {
ix_t(i, d) = rnd.Rand64() % 1000;
}
}
- SparseTensor st(ix, vals, shape, order);
+ SparseTensor st;
+ TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st));
testing::StartTiming();
st.Reorder<string>(reorder);