aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/array_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r--tensorflow/python/ops/array_ops.py38
1 files changed, 14 insertions, 24 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 43238757c7..c3c7ecd080 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1132,7 +1132,7 @@ def concat(values, axis, name="concat"):
return gen_array_ops._concat_v2(values=values, axis=axis, name=name)
-def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
+def boolean_mask(tensor, mask, name="boolean_mask"):
"""Apply boolean mask to tensor. Numpy equivalent is `tensor[mask]`.
```python
@@ -1146,17 +1146,11 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
the first K dimensions of `tensor`'s shape. We then have:
`boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
- The `axis` could be used with `mask` to indicate the axis to mask from.
- In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
- the first `axis + dim(mask)` dimensions of `tensor`'s shape.
Args:
tensor: N-D tensor.
mask: K-D boolean tensor, K <= N and K must be known statically.
name: A name for this operation (optional).
- axis: A 0-D int Tensor representing the axis in `tensor` to mask from.
- By default, axis is 0 which will mask from the first dimension. Otherwise
- K + axis <= N.
Returns:
(N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
@@ -1175,10 +1169,10 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
```
"""
- def _apply_mask_1d(reshaped_tensor, mask, axis=None):
+ def _apply_mask_1d(reshaped_tensor, mask):
"""Mask tensor along dimension 0 with a 1-D mask."""
indices = squeeze(where(mask), squeeze_dims=[1])
- return gather(reshaped_tensor, indices, axis=axis)
+ return gather(reshaped_tensor, indices)
with ops.name_scope(name, values=[tensor, mask]):
tensor = ops.convert_to_tensor(tensor, name="tensor")
@@ -1193,22 +1187,19 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
raise ValueError(
"Number of mask dimensions must be specified, even if some dimensions"
" are None. E.g. shape=[None] is ok, but shape=None is not.")
- axis = 0 if axis is None else axis
- shape_tensor[axis:axis+ndims_mask].assert_is_compatible_with(shape_mask)
+ shape_tensor[:ndims_mask].assert_is_compatible_with(shape_mask)
- leading_size = gen_math_ops._prod(shape(tensor)[axis:axis+ndims_mask], [0])
+ leading_size = gen_math_ops._prod(shape(tensor)[:ndims_mask], [0])
tensor = reshape(tensor,
- concat([shape(tensor)[:axis],
- [leading_size],
- shape(tensor)[axis+ndims_mask:]], 0))
- first_dim = shape_tensor[axis:axis+ndims_mask].num_elements()
+ concat([[leading_size],
+ shape(tensor)[ndims_mask:]], 0))
+ first_dim = shape_tensor[:ndims_mask].num_elements()
tensor.set_shape(
- tensor_shape.as_shape(shape_tensor[:axis])
- .concatenate([first_dim])
- .concatenate(shape_tensor[axis+ndims_mask:]))
+ tensor_shape.as_shape([first_dim])
+ .concatenate(shape_tensor[ndims_mask:]))
mask = reshape(mask, [-1])
- return _apply_mask_1d(tensor, mask, axis)
+ return _apply_mask_1d(tensor, mask)
def sparse_mask(a, mask_indices, name=None):
@@ -1530,8 +1521,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
- `complex64`, `complex128` or `bool`.
+ `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.
@@ -1582,8 +1572,8 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
- `complex64`, `complex128` or `bool`.
+ `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, `complex128` or
+ `bool`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.