aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_split_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/sparse_split_op.cc')
-rw-r--r--tensorflow/core/kernels/sparse_split_op.cc14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc
index 67dcf05a6c..3d02be47cb 100644
--- a/tensorflow/core/kernels/sparse_split_op.cc
+++ b/tensorflow/core/kernels/sparse_split_op.cc
@@ -63,10 +63,16 @@ class SparseSplitOp : public OpKernel {
input_shape.vec<int64>()(split_dim), "), got ",
num_split_));
- sparse::SparseTensor sparse_tensor(input_indices, input_values,
- TensorShape(input_shape.vec<int64>()));
- const std::vector<sparse::SparseTensor> outputs =
- sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, num_split_);
+ sparse::SparseTensor sparse_tensor;
+ OP_REQUIRES_OK(context,
+ sparse::SparseTensor::Create(
+ input_indices, input_values,
+ TensorShape(input_shape.vec<int64>()), &sparse_tensor));
+
+ std::vector<sparse::SparseTensor> outputs;
+ OP_REQUIRES_OK(context,
+ sparse::SparseTensor::Split<T>(sparse_tensor, split_dim,
+ num_split_, &outputs));
for (int slice_index = 0; slice_index < num_split_; ++slice_index) {
context->set_output(slice_index, outputs[slice_index].indices());