### `tf.sparse_split(split_dim, num_split, sp_input, name=None)` {#sparse_split}
Split a `SparseTensor` into `num_split` tensors along `split_dim`.
If `sp_input.shape[split_dim]` is not an integer multiple of `num_split`,
each of the first `shape[split_dim] % num_split` slices (slices `0` through
`shape[split_dim] % num_split - 1`) gets one extra element along `split_dim`.
For example, if `split_dim = 1` and `num_split = 2` and the input is:

    input_tensor = shape = [2, 7]
    [    a   d e   ]
    [b c           ]

Graphically the output tensors are:

    output_tensor[0] = shape = [2, 4]
    [    a   ]
    [b c     ]

    output_tensor[1] = shape = [2, 3]
    [d e   ]
    [      ]
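The slice sizes follow from integer division of `shape[split_dim]` by `num_split`. A minimal plain-Python sketch of that rule (the helper name `split_sizes` is hypothetical and not part of TensorFlow):

```python
def split_sizes(dim_size, num_split):
    """Return the extent of each slice along the split dimension.

    Hypothetical helper illustrating the documented rule: the first
    dim_size % num_split slices each get one extra element.
    """
    base, extra = divmod(dim_size, num_split)
    return [base + 1 if i < extra else base for i in range(num_split)]


# The documented example: shape[split_dim] = 7, num_split = 2.
print(split_sizes(7, 2))  # [4, 3]
```

For the example above, this sends columns 0 through 3 to `output_tensor[0]` and columns 4 through 6 to `output_tensor[1]`.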
##### Args:
* <b>`split_dim`</b>: A 0-D `int32` `Tensor`. The dimension along which to split.
* <b>`num_split`</b>: A Python integer. The number of ways to split.
* <b>`sp_input`</b>: The `SparseTensor` to split.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
`num_split` `SparseTensor` objects resulting from splitting `sp_input`.
##### Raises:
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
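To show the full semantics (slice sizing, re-based indices, per-slice dense shapes, and the `TypeError` check) in one place, here is a minimal pure-Python sketch. It does not use TensorFlow and is not its implementation: `SparseCOO` is a hypothetical stand-in for `SparseTensor` in COO form, `sparse_split_sketch` is a hypothetical name, and the entry coordinates in the usage example are one assumed reading of the diagram above.

```python
from typing import List, NamedTuple, Tuple


class SparseCOO(NamedTuple):
    """Hypothetical stand-in for a SparseTensor (not TensorFlow's class)."""
    indices: List[Tuple[int, ...]]  # one coordinate tuple per stored entry
    values: List[object]            # value of each entry, aligned with indices
    dense_shape: Tuple[int, ...]    # shape of the equivalent dense tensor


def sparse_split_sketch(split_dim: int, num_split: int,
                        sp_input: SparseCOO) -> List[SparseCOO]:
    """Split sp_input into num_split pieces along split_dim (illustrative only)."""
    if not isinstance(sp_input, SparseCOO):
        raise TypeError("sp_input must be a SparseCOO")
    dim_size = sp_input.dense_shape[split_dim]
    base, extra = divmod(dim_size, num_split)
    # The first `extra` slices get one additional element along split_dim.
    sizes = [base + 1 if i < extra else base for i in range(num_split)]
    # Starting offset of each slice along split_dim.
    starts = [sum(sizes[:i]) for i in range(num_split)]

    pieces = []
    for start, size in zip(starts, sizes):
        idx, vals = [], []
        for coord, val in zip(sp_input.indices, sp_input.values):
            if start <= coord[split_dim] < start + size:
                # Re-base the coordinate so it is relative to this slice.
                shifted = list(coord)
                shifted[split_dim] -= start
                idx.append(tuple(shifted))
                vals.append(val)
        shape = list(sp_input.dense_shape)
        shape[split_dim] = size
        pieces.append(SparseCOO(idx, vals, tuple(shape)))
    return pieces


# The documented example, with coordinates read off the diagram (an assumption).
example = SparseCOO(
    indices=[(0, 2), (0, 4), (0, 5), (1, 0), (1, 1)],
    values=["a", "d", "e", "b", "c"],
    dense_shape=(2, 7),
)
left, right = sparse_split_sketch(1, 2, example)
print(left)   # SparseCOO(indices=[(0, 2), (1, 0), (1, 1)], values=['a', 'b', 'c'], dense_shape=(2, 4))
print(right)  # SparseCOO(indices=[(0, 0), (0, 1)], values=['d', 'e'], dense_shape=(2, 3))
```

The sketch scans every entry once per slice, which keeps it short; a real implementation would more likely partition the entries in a single pass.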