diff options
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 38 |
1 files changed, 24 insertions, 14 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index c3c7ecd080..43238757c7 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1132,7 +1132,7 @@ def concat(values, axis, name="concat"): return gen_array_ops._concat_v2(values=values, axis=axis, name=name) -def boolean_mask(tensor, mask, name="boolean_mask"): +def boolean_mask(tensor, mask, name="boolean_mask", axis=None): """Apply boolean mask to tensor. Numpy equivalent is `tensor[mask]`. ```python @@ -1146,11 +1146,17 @@ def boolean_mask(tensor, mask, name="boolean_mask"): the first K dimensions of `tensor`'s shape. We then have: `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]` where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order). + The `axis` could be used with `mask` to indicate the axis to mask from. + In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match + the first `axis + dim(mask)` dimensions of `tensor`'s shape. Args: tensor: N-D tensor. mask: K-D boolean tensor, K <= N and K must be known statically. name: A name for this operation (optional). + axis: A 0-D int Tensor representing the axis in `tensor` to mask from. + By default, axis is 0 which will mask from the first dimension. Otherwise + K + axis <= N. Returns: (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding @@ -1169,10 +1175,10 @@ def boolean_mask(tensor, mask, name="boolean_mask"): ``` """ - def _apply_mask_1d(reshaped_tensor, mask): + def _apply_mask_1d(reshaped_tensor, mask, axis=None): """Mask tensor along dimension 0 with a 1-D mask.""" indices = squeeze(where(mask), squeeze_dims=[1]) - return gather(reshaped_tensor, indices) + return gather(reshaped_tensor, indices, axis=axis) with ops.name_scope(name, values=[tensor, mask]): tensor = ops.convert_to_tensor(tensor, name="tensor") @@ -1187,19 +1193,22 @@ def boolean_mask(tensor, mask, name="boolean_mask"): raise ValueError( "Number of mask dimensions must be specified, even if some dimensions" " are None. E.g. shape=[None] is ok, but shape=None is not.") - shape_tensor[:ndims_mask].assert_is_compatible_with(shape_mask) + axis = 0 if axis is None else axis + shape_tensor[axis:axis+ndims_mask].assert_is_compatible_with(shape_mask) - leading_size = gen_math_ops._prod(shape(tensor)[:ndims_mask], [0]) + leading_size = gen_math_ops._prod(shape(tensor)[axis:axis+ndims_mask], [0]) tensor = reshape(tensor, - concat([[leading_size], - shape(tensor)[ndims_mask:]], 0)) - first_dim = shape_tensor[:ndims_mask].num_elements() + concat([shape(tensor)[:axis], + [leading_size], + shape(tensor)[axis+ndims_mask:]], 0)) + first_dim = shape_tensor[axis:axis+ndims_mask].num_elements() tensor.set_shape( - tensor_shape.as_shape([first_dim]) - .concatenate(shape_tensor[ndims_mask:])) + tensor_shape.as_shape(shape_tensor[:axis]) + .concatenate([first_dim]) + .concatenate(shape_tensor[axis+ndims_mask:])) mask = reshape(mask, [-1]) - return _apply_mask_1d(tensor, mask) + return _apply_mask_1d(tensor, mask, axis) def sparse_mask(a, mask_indices, name=None): @@ -1521,7 +1530,8 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`. + `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, + `complex64`, `complex128` or `bool`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' and encode it as a constant. @@ -1572,8 +1582,8 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, `complex128` or - `bool`. + `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, + `complex64`, `complex128` or `bool`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' and encode it as a constant. |