Diffstat (limited to 'tensorflow/python/ops/array_ops.py')
-rw-r--r--  tensorflow/python/ops/array_ops.py  |  56
1 file changed, 24 insertions, 32 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 61bd41e7de..f5f1278bfd 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1136,7 +1136,7 @@ def concat(values, axis, name="concat"):
return gen_array_ops._concat_v2(values=values, axis=axis, name=name)
-def boolean_mask(tensor, mask, name="boolean_mask"):
+def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
"""Apply boolean mask to tensor. Numpy equivalent is `tensor[mask]`.
```python
@@ -1150,11 +1150,17 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
the first K dimensions of `tensor`'s shape. We then have:
`boolean_mask(tensor, mask)[i, j1,...,jd] = tensor[i1,...,iK,j1,...,jd]`
where `(i1,...,iK)` is the ith `True` entry of `mask` (row-major order).
+ The `axis` argument can be used with `mask` to indicate the axis to mask from.
+ In that case, `axis + dim(mask) <= dim(tensor)` and `mask`'s shape must match
+ dimensions `axis` through `axis + dim(mask) - 1` of `tensor`'s shape.
Args:
tensor: N-D tensor.
mask: K-D boolean tensor, K <= N and K must be known statically.
name: A name for this operation (optional).
+ axis: A 0-D int Tensor representing the axis in `tensor` to mask from.
+ By default, `axis` is 0, which masks from the first dimension. Otherwise,
+ `K + axis <= N` must hold.
Returns:
(N-K+1)-dimensional tensor populated by entries in `tensor` corresponding
@@ -1173,10 +1179,10 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
```
"""
- def _apply_mask_1d(reshaped_tensor, mask):
+ def _apply_mask_1d(reshaped_tensor, mask, axis=None):
"""Mask tensor along dimension 0 with a 1-D mask."""
indices = squeeze(where(mask), squeeze_dims=[1])
- return gather(reshaped_tensor, indices)
+ return gather(reshaped_tensor, indices, axis=axis)
with ops.name_scope(name, values=[tensor, mask]):
tensor = ops.convert_to_tensor(tensor, name="tensor")
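A minimal usage sketch of the new `axis` argument described in the docstring above (not part of the diff; the input values are illustrative):

```python
# Masking along axis 1 keeps rows intact and selects the columns where the
# 1-D mask is True; with the values below the result is [[1, 3], [4, 6]].
import tensorflow as tf

tensor = tf.constant([[1, 2, 3],
                      [4, 5, 6]])
mask = tf.constant([True, False, True])
masked = tf.boolean_mask(tensor, mask, axis=1)
```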
@@ -1191,19 +1197,22 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
raise ValueError(
"Number of mask dimensions must be specified, even if some dimensions"
" are None. E.g. shape=[None] is ok, but shape=None is not.")
- shape_tensor[:ndims_mask].assert_is_compatible_with(shape_mask)
+ axis = 0 if axis is None else axis
+ shape_tensor[axis:axis+ndims_mask].assert_is_compatible_with(shape_mask)
- leading_size = gen_math_ops._prod(shape(tensor)[:ndims_mask], [0])
+ leading_size = gen_math_ops._prod(shape(tensor)[axis:axis+ndims_mask], [0])
tensor = reshape(tensor,
- concat([[leading_size],
- shape(tensor)[ndims_mask:]], 0))
- first_dim = shape_tensor[:ndims_mask].num_elements()
+ concat([shape(tensor)[:axis],
+ [leading_size],
+ shape(tensor)[axis+ndims_mask:]], 0))
+ first_dim = shape_tensor[axis:axis+ndims_mask].num_elements()
tensor.set_shape(
- tensor_shape.as_shape([first_dim])
- .concatenate(shape_tensor[ndims_mask:]))
+ tensor_shape.as_shape(shape_tensor[:axis])
+ .concatenate([first_dim])
+ .concatenate(shape_tensor[axis+ndims_mask:]))
mask = reshape(mask, [-1])
- return _apply_mask_1d(tensor, mask)
+ return _apply_mask_1d(tensor, mask, axis)
def sparse_mask(a, mask_indices, name=None):
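Stepping back from the hunks above: the rewritten `boolean_mask` body collapses the `K` masked dimensions into a single dimension at position `axis` and then gathers along that axis. A rough NumPy sketch of that computation, assuming `mask` covers dimensions `[axis, axis + K)` of `tensor` (the helper name and values here are illustrative, not from the diff):

```python
import numpy as np

def boolean_mask_np(tensor, mask, axis=0):
    tensor = np.asarray(tensor)
    mask = np.asarray(mask)
    k = mask.ndim
    # Collapse the K masked dimensions into one dimension at position `axis`.
    leading = int(np.prod(tensor.shape[axis:axis + k]))
    reshaped = tensor.reshape(
        tensor.shape[:axis] + (leading,) + tensor.shape[axis + k:])
    # Analogue of squeeze(where(mask)) followed by gather(..., axis=axis).
    indices = np.flatnonzero(mask.reshape(-1))
    return np.take(reshaped, indices, axis=axis)

# Example: mask the middle axis of a [2, 3, 2] array.
x = np.arange(12).reshape(2, 3, 2)
print(boolean_mask_np(x, [True, False, True], axis=1).shape)  # (2, 2, 2)
```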
@@ -1525,7 +1534,8 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
+ `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
+ `complex64`, `complex128` or `bool`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.
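A small sketch of the dtype override that the updated list above documents for `zeros_like` (illustrative values, not part of the diff):

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
z = tf.zeros_like(x, dtype=tf.bool)  # all-False tensor with the same shape as x
```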
@@ -1576,8 +1586,8 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, `complex128` or
- `bool`.
+ `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
+ `complex64`, `complex128` or `bool`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.
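The analogous override for `ones_like` (again illustrative, not part of the diff):

```python
import tensorflow as tf

x = tf.constant([[1.5, 2.5]])
o = tf.ones_like(x, dtype=tf.int32)  # [[1, 1]] with dtype int32
```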
@@ -1653,8 +1663,6 @@ def placeholder(dtype, shape=None, name=None):
print(sess.run(y, feed_dict={x: rand_array})) # Will succeed.
```
- @compatibility{eager} Placeholders are not compatible with eager execution.
-
Args:
dtype: The type of elements in the tensor to be fed.
shape: The shape of the tensor to be fed (optional). If the shape is not
@@ -1664,14 +1672,7 @@ def placeholder(dtype, shape=None, name=None):
Returns:
A `Tensor` that may be used as a handle for feeding a value, but not
evaluated directly.
-
- Raises:
- RuntimeError: if eager execution is enabled
"""
- if context.in_eager_mode():
- raise RuntimeError("tf.placeholder() is not compatible with "
- "eager execution.")
-
return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
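The removed guard only affected eager mode; graph-mode feeding is unchanged by this commit. A self-contained sketch of that unchanged path (values are illustrative):

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3])
y = tf.reduce_sum(x)
with tf.Session() as sess:
    # Feeding an array of a matching shape succeeds exactly as before.
    print(sess.run(y, feed_dict={x: np.ones([2, 3], dtype=np.float32)}))  # 6.0
```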
@@ -1715,8 +1716,6 @@ def sparse_placeholder(dtype, shape=None, name=None):
print(sess.run(y, feed_dict={x: sp_value})) # Will succeed.
```
- @compatibility{eager} Placeholders are not compatible with eager execution.
-
Args:
dtype: The type of `values` elements in the tensor to be fed.
shape: The shape of the tensor to be fed (optional). If the shape is not
@@ -1726,14 +1725,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
Returns:
A `SparseTensor` that may be used as a handle for feeding a value, but not
evaluated directly.
-
- Raises:
- RuntimeError: if eager execution is enabled
"""
- if context.in_eager_mode():
- raise RuntimeError("tf.placeholder() is not compatible with "
- "eager execution.")
-
shape_name = (name + "/shape") if name is not None else None
shape, rank = _normalize_sparse_shape(shape, shape_name)
if shape is None:
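For the `sparse_placeholder` change above, a similar graph-mode feeding sketch using a `SparseTensorValue` (illustrative values, not part of the diff):

```python
import numpy as np
import tensorflow as tf

x = tf.sparse_placeholder(tf.float32)
y = tf.sparse_reduce_sum(x)

# A SparseTensorValue can be fed directly to the sparse placeholder.
value = tf.SparseTensorValue(indices=[[0, 0], [1, 2]],
                             values=np.array([1.0, 2.0], dtype=np.float32),
                             dense_shape=[3, 4])
with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: value}))  # 3.0
```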