diff options
Diffstat (limited to 'tensorflow/core/kernels/sparse_split_op.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_split_op.cc | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc index 67dcf05a6c..3d02be47cb 100644 --- a/tensorflow/core/kernels/sparse_split_op.cc +++ b/tensorflow/core/kernels/sparse_split_op.cc @@ -63,10 +63,16 @@ class SparseSplitOp : public OpKernel { input_shape.vec<int64>()(split_dim), "), got ", num_split_)); - sparse::SparseTensor sparse_tensor(input_indices, input_values, - TensorShape(input_shape.vec<int64>())); - const std::vector<sparse::SparseTensor> outputs = - sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, num_split_); + sparse::SparseTensor sparse_tensor; + OP_REQUIRES_OK(context, + sparse::SparseTensor::Create( + input_indices, input_values, + TensorShape(input_shape.vec<int64>()), &sparse_tensor)); + + std::vector<sparse::SparseTensor> outputs; + OP_REQUIRES_OK(context, + sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, + num_split_, &outputs)); for (int slice_index = 0; slice_index < num_split_; ++slice_index) { context->set_output(slice_index, outputs[slice_index].indices()); |