aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/sparse_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/sparse_ops.py')
-rw-r--r--tensorflow/python/ops/sparse_ops.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 93a9656950..d05fabce86 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -25,6 +25,7 @@
@@sparse_concat
@@sparse_reorder
@@sparse_reshape
+@@sparse_slice
@@sparse_split
@@sparse_retain
@@sparse_reset_shape
@@ -657,6 +658,50 @@ def sparse_split(keyword_required=KeywordRequired(),
return sparse_tensors
+def sparse_slice(sp_input, start, size, name=None):
+ """Slice a `SparseTensor` based on the `start` and `size.
+
+ For example, if the input is
+
+ input_tensor = shape = [2, 7]
+ [ a d e ]
+ [b c ]
+
+ Graphically the output tensors are:
+
+ sparse_slice([0, 0], [2, 4]) = shape = [2, 4]
+ [ a ]
+ [b c ]
+
+ sparse_slice([0, 4], [2, 3]) = shape = [2, 3]
+ [ d e ]
+ [ ]
+
+ Args:
+ sp_input: The `SparseTensor` to split.
+ start: 1-D. tensor represents the start of the slice.
+ size: 1-D. tensor represents the size of the slice.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `SparseTensor` objects resulting from splicing.
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
+ sp_input = _convert_to_sparse_tensor(sp_input)
+ start = ops.convert_to_tensor(start, dtypes.int64)
+ size = ops.convert_to_tensor(size, dtypes.int64)
+
+ with ops.name_scope(name, "SparseSlice", [sp_input]) as name:
+ output_indices, output_values, output_shape = gen_sparse_ops.sparse_slice(
+ sp_input.indices, sp_input.values, sp_input.dense_shape, start, size, name=name)
+
+ return sparse_tensor.SparseTensor(
+ output_indices,
+ output_values,
+ output_shape)
+
def sparse_to_dense(sparse_indices,
output_shape,
sparse_values,