diff options
Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 38 |
1 files changed, 14 insertions, 24 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 43238757c7..c3c7ecd080 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1132,7 +1132,7 @@ def concat(values, axis, name="concat"): return gen_array_ops._concat_v2(values=values, axis=axis, name=name) -def boolean_mask(tensor, mask, name="boolean_mask", axis=None): +def boolean_mask(tensor, mask, name="boolean_mask"): """Apply boolean mask to tensor. Numpy equivalent is `tensor[mask]`. ```python @@ -1146,17 +1146,11 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None): the first K dimensions of `tensor`'s shape. We then have: `boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]` where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order). - The `axis` could be used with `mask` to indicate the axis to mask from. - In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match - the first `axis + dim(mask)` dimensions of `tensor`'s shape. Args: tensor: N-D tensor. mask: K-D boolean tensor, K <= N and K must be known statically. name: A name for this operation (optional). - axis: A 0-D int Tensor representing the axis in `tensor` to mask from. - By default, axis is 0 which will mask from the first dimension. Otherwise - K + axis <= N. Returns: (N-K+1)-dimensional tensor populated by entries in `tensor` corresponding @@ -1175,10 +1169,10 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None): ``` """ - def _apply_mask_1d(reshaped_tensor, mask, axis=None): + def _apply_mask_1d(reshaped_tensor, mask): """Mask tensor along dimension 0 with a 1-D mask.""" indices = squeeze(where(mask), squeeze_dims=[1]) - return gather(reshaped_tensor, indices, axis=axis) + return gather(reshaped_tensor, indices) with ops.name_scope(name, values=[tensor, mask]): tensor = ops.convert_to_tensor(tensor, name="tensor") @@ -1193,22 +1187,19 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None): raise ValueError( "Number of mask dimensions must be specified, even if some dimensions" " are None. E.g. shape=[None] is ok, but shape=None is not.") - axis = 0 if axis is None else axis - shape_tensor[axis:axis+ndims_mask].assert_is_compatible_with(shape_mask) + shape_tensor[:ndims_mask].assert_is_compatible_with(shape_mask) - leading_size = gen_math_ops._prod(shape(tensor)[axis:axis+ndims_mask], [0]) + leading_size = gen_math_ops._prod(shape(tensor)[:ndims_mask], [0]) tensor = reshape(tensor, - concat([shape(tensor)[:axis], - [leading_size], - shape(tensor)[axis+ndims_mask:]], 0)) - first_dim = shape_tensor[axis:axis+ndims_mask].num_elements() + concat([[leading_size], + shape(tensor)[ndims_mask:]], 0)) + first_dim = shape_tensor[:ndims_mask].num_elements() tensor.set_shape( - tensor_shape.as_shape(shape_tensor[:axis]) - .concatenate([first_dim]) - .concatenate(shape_tensor[axis+ndims_mask:])) + tensor_shape.as_shape([first_dim]) + .concatenate(shape_tensor[ndims_mask:])) mask = reshape(mask, [-1]) - return _apply_mask_1d(tensor, mask, axis) + return _apply_mask_1d(tensor, mask) def sparse_mask(a, mask_indices, name=None): @@ -1530,8 +1521,7 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, - `complex64`, `complex128` or `bool`. + `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' and encode it as a constant. @@ -1582,8 +1572,8 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, - `complex64`, `complex128` or `bool`. + `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, `complex128` or + `bool`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' and encode it as a constant. |