diff options
Diffstat (limited to 'tensorflow/python/ops/sparse_ops.py')
-rw-r--r-- | tensorflow/python/ops/sparse_ops.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 93a9656950..d05fabce86 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -25,6 +25,7 @@ @@sparse_concat @@sparse_reorder @@sparse_reshape +@@sparse_slice @@sparse_split @@sparse_retain @@sparse_reset_shape @@ -657,6 +658,50 @@ def sparse_split(keyword_required=KeywordRequired(), return sparse_tensors +def sparse_slice(sp_input, start, size, name=None): + """Slice a `SparseTensor` based on the `start` and `size. + + For example, if the input is + + input_tensor = shape = [2, 7] + [ a d e ] + [b c ] + + Graphically the output tensors are: + + sparse_slice([0, 0], [2, 4]) = shape = [2, 4] + [ a ] + [b c ] + + sparse_slice([0, 4], [2, 3]) = shape = [2, 3] + [ d e ] + [ ] + + Args: + sp_input: The `SparseTensor` to split. + start: 1-D. tensor represents the start of the slice. + size: 1-D. tensor represents the size of the slice. + name: A name for the operation (optional). + + Returns: + A `SparseTensor` objects resulting from splicing. + + Raises: + TypeError: If `sp_input` is not a `SparseTensor`. + """ + sp_input = _convert_to_sparse_tensor(sp_input) + start = ops.convert_to_tensor(start, dtypes.int64) + size = ops.convert_to_tensor(size, dtypes.int64) + + with ops.name_scope(name, "SparseSlice", [sp_input]) as name: + output_indices, output_values, output_shape = gen_sparse_ops.sparse_slice( + sp_input.indices, sp_input.values, sp_input.dense_shape, start, size, name=name) + + return sparse_tensor.SparseTensor( + output_indices, + output_values, + output_shape) + def sparse_to_dense(sparse_indices, output_shape, sparse_values, |