aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r--tensorflow/python/ops/math_ops.py220
1 files changed, 158 insertions, 62 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 886b2048f9..81b3c21808 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1265,16 +1265,19 @@ def _ReductionDims(x, axis, reduction_indices):
return range(0, array_ops.rank(x))
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_sum(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the sum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1287,7 +1290,7 @@ def reduce_sum(input_tensor,
tf.reduce_sum(x) # 6
tf.reduce_sum(x, 0) # [2, 2, 2]
tf.reduce_sum(x, 1) # [3, 3]
- tf.reduce_sum(x, 1, keep_dims=True) # [[3], [3]]
+ tf.reduce_sum(x, 1, keepdims=True) # [[3], [3]]
tf.reduce_sum(x, [0, 1]) # 6
```
@@ -1296,7 +1299,7 @@ def reduce_sum(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1307,24 +1310,35 @@ def reduce_sum(input_tensor,
Equivalent to np.sum
@end_compatibility
"""
+
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
+
return gen_math_ops._sum(
input_tensor,
_ReductionDims(input_tensor, axis, reduction_indices),
- keep_dims,
+ keepdims,
name=name)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def count_nonzero(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
dtype=dtypes.int64,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes number of nonzero elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1341,7 +1355,7 @@ def count_nonzero(input_tensor,
tf.count_nonzero(x) # 3
tf.count_nonzero(x, 0) # [1, 2, 0]
tf.count_nonzero(x, 1) # [1, 2]
- tf.count_nonzero(x, 1, keep_dims=True) # [[1], [2]]
+ tf.count_nonzero(x, 1, keepdims=True) # [[1], [2]]
tf.count_nonzero(x, [0, 1]) # 3
```
@@ -1350,7 +1364,7 @@ def count_nonzero(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
dtype: The output dtype; defaults to `tf.int64`.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1358,6 +1372,13 @@ def count_nonzero(input_tensor,
Returns:
The reduced tensor (number of nonzero values).
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
+
with ops.name_scope(name, "count_nonzero", [input_tensor]):
input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
zero = input_tensor.dtype.as_numpy_dtype()
@@ -1366,21 +1387,24 @@ def count_nonzero(input_tensor,
# int64 reduction happens on GPU
to_int64(gen_math_ops.not_equal(input_tensor, zero)),
axis=axis,
- keep_dims=keep_dims,
+ keepdims=keepdims,
reduction_indices=reduction_indices),
dtype=dtype)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_mean(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the mean of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1400,7 +1424,7 @@ def reduce_mean(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1409,25 +1433,44 @@ def reduce_mean(input_tensor,
@compatibility(numpy)
Equivalent to np.mean
+
+ Please note that `np.mean` has a `dtype` parameter that could be used to specify the output type. By default this is `dtype=float64`. On the other hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`, for example:
+
+ ```python
+ x = tf.constant([1, 0, 1, 0])
+ tf.reduce_mean(x) # 0
+ y = tf.constant([1., 0., 1., 0.])
+ tf.reduce_mean(y) # 0.5
+ ```
+
@end_compatibility
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
return gen_math_ops._mean(
input_tensor,
_ReductionDims(input_tensor, axis, reduction_indices),
- keep_dims,
+ keepdims,
name=name)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_prod(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the product of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1438,7 +1481,7 @@ def reduce_prod(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1449,23 +1492,32 @@ def reduce_prod(input_tensor,
Equivalent to np.prod
@end_compatibility
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
return gen_math_ops._prod(
input_tensor,
_ReductionDims(input_tensor, axis, reduction_indices),
- keep_dims,
+ keepdims,
name=name)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_min(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the minimum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1476,7 +1528,7 @@ def reduce_min(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1487,23 +1539,32 @@ def reduce_min(input_tensor,
Equivalent to np.min
@end_compatibility
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
return gen_math_ops._min(
input_tensor,
_ReductionDims(input_tensor, axis, reduction_indices),
- keep_dims,
+ keepdims,
name=name)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_max(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the maximum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1514,7 +1575,7 @@ def reduce_max(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1525,23 +1586,32 @@ def reduce_max(input_tensor,
Equivalent to np.max
@end_compatibility
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
return gen_math_ops._max(
input_tensor,
_ReductionDims(input_tensor, axis, reduction_indices),
- keep_dims,
+ keepdims,
name=name)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_all(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the "logical and" of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1561,7 +1631,7 @@ def reduce_all(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1572,23 +1642,32 @@ def reduce_all(input_tensor,
Equivalent to np.all
@end_compatibility
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
return gen_math_ops._all(
input_tensor,
_ReductionDims(input_tensor, axis, reduction_indices),
- keep_dims,
+ keepdims,
name=name)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_any(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the "logical or" of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1608,7 +1687,7 @@ def reduce_any(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1619,23 +1698,32 @@ def reduce_any(input_tensor,
Equivalent to np.any
@end_compatibility
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
return gen_math_ops._any(
input_tensor,
_ReductionDims(input_tensor, axis, reduction_indices),
- keep_dims,
+ keepdims,
name=name)
+@deprecated_args(None, "keep_dims is deprecated, use keepdims instead",
+ "keep_dims")
def reduce_logsumexp(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1652,7 +1740,7 @@ def reduce_logsumexp(input_tensor,
tf.reduce_logsumexp(x) # log(6)
tf.reduce_logsumexp(x, 0) # [log(2), log(2), log(2)]
tf.reduce_logsumexp(x, 1) # [log(3), log(3)]
- tf.reduce_logsumexp(x, 1, keep_dims=True) # [[log(3)], [log(3)]]
+ tf.reduce_logsumexp(x, 1, keepdims=True) # [[log(3)], [log(3)]]
tf.reduce_logsumexp(x, [0, 1]) # log(6)
```
@@ -1661,19 +1749,25 @@ def reduce_logsumexp(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
Returns:
The reduced tensor.
"""
+ if keep_dims is not None:
+ if keepdims is not None:
+ raise ValueError("Cannot specify both 'keep_dims' and 'keepdims'")
+ keepdims = keep_dims
+ if keepdims is None:
+ keepdims = False
with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
raw_max = reduce_max(
input_tensor,
axis=axis,
reduction_indices=reduction_indices,
- keep_dims=True)
+ keepdims=True)
my_max = array_ops.stop_gradient(
array_ops.where(
gen_math_ops.is_finite(raw_max),
@@ -1683,9 +1777,9 @@ def reduce_logsumexp(input_tensor,
reduce_sum(
gen_math_ops.exp(input_tensor - my_max),
axis,
- keep_dims=True,
+ keepdims=True,
reduction_indices=reduction_indices)) + my_max
- if not keep_dims:
+ if not keepdims:
if isinstance(axis, int):
axis = [axis]
result = array_ops.squeeze(result, axis)
@@ -2191,8 +2285,10 @@ def bincount(arr,
maxlength = ops.convert_to_tensor(
maxlength, name="maxlength", dtype=dtypes.int32)
output_size = gen_math_ops.minimum(maxlength, output_size)
- weights = (ops.convert_to_tensor(weights, name="weights")
- if weights is not None else constant_op.constant([], dtype))
+ if weights is not None:
+ weights = ops.convert_to_tensor(weights, name="weights")
+ return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
+ weights = constant_op.constant([], dtype)
return gen_math_ops.bincount(arr, output_size, weights)
@@ -2355,7 +2451,7 @@ def reduced_shape(input_shape, axes):
input_shape: 1-D Tensor, the shape of the Tensor being reduced.
axes: 1-D Tensor, the reduction axes.
Returns:
- A 1-D Tensor, the output shape as if keep_dims were set to True.
+ A 1-D Tensor, the output shape as if keepdims were set to True.
"""
# Example:
# cast needed for SparseTensor reductions