about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/set_kernels.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/set_kernels.cc')
-rw-r--r-- tensorflow/core/kernels/set_kernels.cc | 44
1 file changed, 24 insertions(+), 20 deletions(-)
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc
index e836c764ac..f893d4e945 100644
--- a/tensorflow/core/kernels/set_kernels.cc
+++ b/tensorflow/core/kernels/set_kernels.cc
@@ -63,9 +63,9 @@ Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
-sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx,
- const int32 base_index,
- bool validate_indices) {
+Status SparseTensorFromContext(OpKernelContext* ctx, const int32 base_index,
+ bool validate_indices,
+ sparse::SparseTensor* tensor) {
// Assume row-major order.
const TensorShape shape =
TensorShape(ctx->input(base_index + 2).vec<int64>());
@@ -73,13 +73,8 @@ sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx,
std::vector<int64> order(shape.dims());
std::iota(order.begin(), order.end(), 0);
- const sparse::SparseTensor st(ctx->input(base_index),
- ctx->input(base_index + 1), shape, order);
- if (validate_indices) {
- Status s = st.IndicesValid();
- if (!s.ok()) ctx->SetStatus(s);
- }
- return st;
+ return sparse::SparseTensor::Create(
+ ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);
}
// TODO(ptucker): CheckGroup is just a sanity check on the result of
@@ -253,11 +248,13 @@ class SetSizeOp : public OpKernel {
template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
- const sparse::SparseTensor set_st =
- SparseTensorFromContext(ctx, 0, validate_indices_);
+ sparse::SparseTensor set_st;
+ OP_REQUIRES_OK(ctx,
+ SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));
+ OP_REQUIRES_OK(ctx, set_st.IndicesValid());
- // Output shape is same as input except for last dimension, which reduces to
- // the set size of values along that dimension.
+ // Output shape is same as input except for last dimension, which reduces
+ // to the set size of values along that dimension.
ShapeArray output_shape;
OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
const auto output_strides = Strides(output_shape);
@@ -484,8 +481,10 @@ void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
const Tensor& set1_t = ctx->input(0);
- const sparse::SparseTensor set2_st =
- SparseTensorFromContext(ctx, 1, validate_indices_);
+ sparse::SparseTensor set2_st;
+ OP_REQUIRES_OK(ctx,
+ SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
+ OP_REQUIRES_OK(ctx, set2_st.IndicesValid());
// The following should stay in sync with `_dense_to_sparse_shape` shape
// assertions in python/ops/set_ops.py, and `SetShapeFn` for
// `DenseToSparseSetOperation` in ops/set_ops.cc.
@@ -597,10 +596,15 @@ const std::vector<int64> GROUP_ITER_END;
// with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
- const sparse::SparseTensor set1_st =
- SparseTensorFromContext(ctx, 0, validate_indices_);
- const sparse::SparseTensor set2_st =
- SparseTensorFromContext(ctx, 3, validate_indices_);
+ sparse::SparseTensor set1_st;
+ OP_REQUIRES_OK(ctx,
+ SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));
+ OP_REQUIRES_OK(ctx, set1_st.IndicesValid());
+
+ sparse::SparseTensor set2_st;
+ OP_REQUIRES_OK(ctx,
+ SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));
+
// The following should stay in sync with `_sparse_to_sparse_shape` shape
// assertions in python/ops/set_ops.py, and `SetShapeFn` for
// `SparseToSparseSetOperation` in ops/set_ops.cc.