aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/array_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r--tensorflow/python/ops/array_ops.py38
1 files changed, 24 insertions, 14 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c3c7ecd080..43238757c7 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1132,7 +1132,7 @@ def concat(values, axis, name="concat"):
return gen_array_ops._concat_v2(values=values, axis=axis, name=name)
-def boolean_mask(tensor, mask, name="boolean_mask"):
+def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
"""Apply boolean mask to tensor. Numpy equivalent is `tensor[mask]`.
```python
@@ -1146,11 +1146,17 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
the first K dimensions of `tensor`'s shape. We then have:
`boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
+ The `axis` could be used with `mask` to indicate the axis to mask from.
+ In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
+ the first `axis + dim(mask)` dimensions of `tensor`'s shape.
Args:
tensor: N-D tensor.
mask: K-D boolean tensor, K <= N and K must be known statically.
name: A name for this operation (optional).
+ axis: A 0-D int Tensor representing the axis in `tensor` to mask from.
+ By default, axis is 0 which will mask from the first dimension. Otherwise
+ K + axis <= N.
Returns:
(N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
@@ -1169,10 +1175,10 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
```
"""
- def _apply_mask_1d(reshaped_tensor, mask):
+ def _apply_mask_1d(reshaped_tensor, mask, axis=None):
"""Mask tensor along dimension 0 with a 1-D mask."""
indices = squeeze(where(mask), squeeze_dims=[1])
- return gather(reshaped_tensor, indices)
+ return gather(reshaped_tensor, indices, axis=axis)
with ops.name_scope(name, values=[tensor, mask]):
tensor = ops.convert_to_tensor(tensor, name="tensor")
@@ -1187,19 +1193,22 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
raise ValueError(
"Number of mask dimensions must be specified, even if some dimensions"
" are None. E.g. shape=[None] is ok, but shape=None is not.")
- shape_tensor[:ndims_mask].assert_is_compatible_with(shape_mask)
+ axis = 0 if axis is None else axis
+ shape_tensor[axis:axis+ndims_mask].assert_is_compatible_with(shape_mask)
- leading_size = gen_math_ops._prod(shape(tensor)[:ndims_mask], [0])
+ leading_size = gen_math_ops._prod(shape(tensor)[axis:axis+ndims_mask], [0])
tensor = reshape(tensor,
- concat([[leading_size],
- shape(tensor)[ndims_mask:]], 0))
- first_dim = shape_tensor[:ndims_mask].num_elements()
+ concat([shape(tensor)[:axis],
+ [leading_size],
+ shape(tensor)[axis+ndims_mask:]], 0))
+ first_dim = shape_tensor[axis:axis+ndims_mask].num_elements()
tensor.set_shape(
- tensor_shape.as_shape([first_dim])
- .concatenate(shape_tensor[ndims_mask:]))
+ tensor_shape.as_shape(shape_tensor[:axis])
+ .concatenate([first_dim])
+ .concatenate(shape_tensor[axis+ndims_mask:]))
mask = reshape(mask, [-1])
- return _apply_mask_1d(tensor, mask)
+ return _apply_mask_1d(tensor, mask, axis)
def sparse_mask(a, mask_indices, name=None):
@@ -1521,7 +1530,8 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
+ `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
+ `complex64`, `complex128` or `bool`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.
@@ -1572,8 +1582,8 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, `complex128` or
- `bool`.
+ `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
+ `complex64`, `complex128` or `bool`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.