diff options
Diffstat (limited to 'tensorflow/core/kernels/set_kernels.cc')
-rw-r--r-- | tensorflow/core/kernels/set_kernels.cc | 44 |
1 files changed, 24 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc index e836c764ac..f893d4e945 100644 --- a/tensorflow/core/kernels/set_kernels.cc +++ b/tensorflow/core/kernels/set_kernels.cc @@ -63,9 +63,9 @@ Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) { // Build `SparseTensor` from indices, values, and shape in inputs // [base_index, base_index + 3), and validate its rank and indices. -sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx, - const int32 base_index, - bool validate_indices) { +Status SparseTensorFromContext(OpKernelContext* ctx, const int32 base_index, + bool validate_indices, + sparse::SparseTensor* tensor) { // Assume row-major order. const TensorShape shape = TensorShape(ctx->input(base_index + 2).vec<int64>()); @@ -73,13 +73,8 @@ sparse::SparseTensor SparseTensorFromContext(OpKernelContext* ctx, std::vector<int64> order(shape.dims()); std::iota(order.begin(), order.end(), 0); - const sparse::SparseTensor st(ctx->input(base_index), - ctx->input(base_index + 1), shape, order); - if (validate_indices) { - Status s = st.IndicesValid(); - if (!s.ok()) ctx->SetStatus(s); - } - return st; + return sparse::SparseTensor::Create( + ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor); } // TODO(ptucker): CheckGroup is just a sanity check on the result of @@ -253,11 +248,13 @@ class SetSizeOp : public OpKernel { template <typename T> void SetSizeOp<T>::Compute(OpKernelContext* ctx) { - const sparse::SparseTensor set_st = - SparseTensorFromContext(ctx, 0, validate_indices_); + sparse::SparseTensor set_st; + OP_REQUIRES_OK(ctx, + SparseTensorFromContext(ctx, 0, validate_indices_, &set_st)); + OP_REQUIRES_OK(ctx, set_st.IndicesValid()); - // Output shape is same as input except for last dimension, which reduces to - // the set size of values along that dimension. + // Output shape is same as input except for last dimension, which reduces + // to the set size of values along that dimension. ShapeArray output_shape; OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape)); const auto output_strides = Strides(output_shape); @@ -484,8 +481,10 @@ void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const { template <typename T> void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const { const Tensor& set1_t = ctx->input(0); - const sparse::SparseTensor set2_st = - SparseTensorFromContext(ctx, 1, validate_indices_); + sparse::SparseTensor set2_st; + OP_REQUIRES_OK(ctx, + SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st)); + OP_REQUIRES_OK(ctx, set2_st.IndicesValid()); // The following should stay in sync with `_dense_to_sparse_shape` shape // assertions in python/ops/set_ops.py, and `SetShapeFn` for // `DenseToSparseSetOperation` in ops/set_ops.cc. @@ -597,10 +596,15 @@ const std::vector<int64> GROUP_ITER_END; // with the same first n-1 dimensions in set1 and set2. template <typename T> void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const { - const sparse::SparseTensor set1_st = - SparseTensorFromContext(ctx, 0, validate_indices_); - const sparse::SparseTensor set2_st = - SparseTensorFromContext(ctx, 3, validate_indices_); + sparse::SparseTensor set1_st; + OP_REQUIRES_OK(ctx, + SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st)); + OP_REQUIRES_OK(ctx, set1_st.IndicesValid()); + + sparse::SparseTensor set2_st; + OP_REQUIRES_OK(ctx, + SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st)); + // The following should stay in sync with `_sparse_to_sparse_shape` shape // assertions in python/ops/set_ops.py, and `SetShapeFn` for // `SparseToSparseSetOperation` in ops/set_ops.cc. |