aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r--tensorflow/python/ops/math_ops.py258
1 files changed, 175 insertions, 83 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 4c400423b6..e2e23dccef 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -170,14 +170,13 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import compat
-from tensorflow.python.util.deprecation import deprecated
-from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util import deprecation
# Aliases for some automatically-generated names.
linspace = gen_math_ops.lin_space
-arg_max = deprecated(None, "Use `argmax` instead")(arg_max) # pylint: disable=used-before-assignment
-arg_min = deprecated(None, "Use `argmin` instead")(arg_min) # pylint: disable=used-before-assignment
+arg_max = deprecation.deprecated(None, "Use `argmax` instead")(arg_max) # pylint: disable=used-before-assignment
+arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min) # pylint: disable=used-before-assignment
def _set_doc(doc):
@@ -190,7 +189,8 @@ def _set_doc(doc):
# pylint: disable=redefined-builtin
-@deprecated_args(None, "Use the `axis` argument instead", "dimension")
+@deprecation.deprecated_args(None, "Use the `axis` argument instead",
+ "dimension")
@_set_doc(
gen_math_ops.arg_max.__doc__.replace("dimensions", "axes").replace(
"dimension", "axis"))
@@ -208,7 +208,8 @@ def argmax(input,
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
-@deprecated_args(None, "Use the `axis` argument instead", "dimension")
+@deprecation.deprecated_args(None, "Use the `axis` argument instead",
+ "dimension")
@_set_doc(
gen_math_ops.arg_min.__doc__.replace("dimensions", "axes").replace(
"dimension", "axis"))
@@ -324,7 +325,7 @@ multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`")
# TODO(aselle): put deprecation in after another round of global code changes
-@deprecated(
+@deprecation.deprecated(
"2016-12-30",
"`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`")
def _mul(x, y, name=None):
@@ -343,7 +344,7 @@ subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`")
# TODO(aselle): put deprecation in after another round of global code changes
-@deprecated(
+@deprecation.deprecated(
"2016-12-30",
"`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
def _sub(x, y, name=None):
@@ -381,8 +382,9 @@ def negative(x, name=None):
# pylint: disable=g-docstring-has-escape
-@deprecated("2016-12-30",
- "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
+@deprecation.deprecated(
+ "2016-12-30",
+ "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
def _neg(x, name=None):
"""Computes numerical negative value element-wise.
@@ -1269,24 +1271,27 @@ def _ReductionDims(x, axis, reduction_indices):
return range(0, array_ops.rank(x))
-def _may_reduce_to_scalar(keep_dims, axis, reduction_indices, output):
+def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
"""Set a reduction's output's shape to be a scalar if we are certain."""
- if (not output.shape.is_fully_defined()) and (not keep_dims) and (
+ if (not output.shape.is_fully_defined()) and (not keepdims) and (
axis is None) and (reduction_indices is None):
output.set_shape(())
return output
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_sum(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the sum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1299,7 +1304,7 @@ def reduce_sum(input_tensor,
tf.reduce_sum(x) # 6
tf.reduce_sum(x, 0) # [2, 2, 2]
tf.reduce_sum(x, 1) # [3, 3]
- tf.reduce_sum(x, 1, keep_dims=True) # [[3], [3]]
+ tf.reduce_sum(x, 1, keepdims=True) # [[3], [3]]
tf.reduce_sum(x, [0, 1]) # 6
```
@@ -1308,9 +1313,10 @@ def reduce_sum(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
@@ -1319,26 +1325,34 @@ def reduce_sum(input_tensor,
Equivalent to np.sum
@end_compatibility
"""
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices,
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+ if keepdims is None:
+ keepdims = False
+
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
gen_math_ops._sum(
input_tensor,
_ReductionDims(input_tensor, axis,
reduction_indices),
- keep_dims,
+ keepdims,
name=name))
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def count_nonzero(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
dtype=dtypes.int64,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes number of nonzero elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1355,7 +1369,7 @@ def count_nonzero(input_tensor,
tf.count_nonzero(x) # 3
tf.count_nonzero(x, 0) # [1, 2, 0]
tf.count_nonzero(x, 1) # [1, 2]
- tf.count_nonzero(x, 1, keep_dims=True) # [[1], [2]]
+ tf.count_nonzero(x, 1, keepdims=True) # [[1], [2]]
tf.count_nonzero(x, [0, 1]) # 3
```
@@ -1364,14 +1378,20 @@ def count_nonzero(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
dtype: The output dtype; defaults to `tf.int64`.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor (number of nonzero values).
"""
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+ if keepdims is None:
+ keepdims = False
+
with ops.name_scope(name, "count_nonzero", [input_tensor]):
input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
zero = input_tensor.dtype.as_numpy_dtype()
@@ -1380,21 +1400,24 @@ def count_nonzero(input_tensor,
# int64 reduction happens on GPU
to_int64(gen_math_ops.not_equal(input_tensor, zero)),
axis=axis,
- keep_dims=keep_dims,
+ keepdims=keepdims,
reduction_indices=reduction_indices),
dtype=dtype)
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_mean(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the mean of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1414,36 +1437,58 @@ def reduce_mean(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
@compatibility(numpy)
Equivalent to np.mean
+
+ Please note that `np.mean` has a `dtype` parameter that could be used to
+ specify the output type. By default this is `dtype=float64`. On the other
+ hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
+ for example:
+
+ ```python
+ x = tf.constant([1, 0, 1, 0])
+ tf.reduce_mean(x) # 0
+ y = tf.constant([1., 0., 1., 0.])
+ tf.reduce_mean(y) # 0.5
+ ```
+
@end_compatibility
"""
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices,
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+
+ if keepdims is None:
+ keepdims = False
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
gen_math_ops._mean(
input_tensor,
_ReductionDims(input_tensor, axis,
reduction_indices),
- keep_dims,
+ keepdims,
name=name))
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_prod(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the product of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1454,9 +1499,10 @@ def reduce_prod(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
@@ -1465,25 +1511,33 @@ def reduce_prod(input_tensor,
Equivalent to np.prod
@end_compatibility
"""
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices,
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+
+ if keepdims is None:
+ keepdims = False
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
gen_math_ops._prod(
input_tensor,
_ReductionDims(input_tensor, axis,
reduction_indices),
- keep_dims,
+ keepdims,
name=name))
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_min(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the minimum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1494,9 +1548,10 @@ def reduce_min(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
@@ -1505,25 +1560,32 @@ def reduce_min(input_tensor,
Equivalent to np.min
@end_compatibility
"""
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices,
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+ if keepdims is None:
+ keepdims = False
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
gen_math_ops._min(
input_tensor,
_ReductionDims(input_tensor, axis,
reduction_indices),
- keep_dims,
+ keepdims,
name=name))
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_max(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the maximum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1534,9 +1596,10 @@ def reduce_max(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
@@ -1545,25 +1608,32 @@ def reduce_max(input_tensor,
Equivalent to np.max
@end_compatibility
"""
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices,
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+ if keepdims is None:
+ keepdims = False
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
gen_math_ops._max(
input_tensor,
_ReductionDims(input_tensor, axis,
reduction_indices),
- keep_dims,
+ keepdims,
name=name))
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_all(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the "logical and" of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1583,9 +1653,10 @@ def reduce_all(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
@@ -1594,25 +1665,32 @@ def reduce_all(input_tensor,
Equivalent to np.all
@end_compatibility
"""
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices,
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+ if keepdims is None:
+ keepdims = False
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
gen_math_ops._all(
input_tensor,
_ReductionDims(input_tensor, axis,
reduction_indices),
- keep_dims,
+ keepdims,
name=name))
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_any(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the "logical or" of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1632,9 +1710,10 @@ def reduce_any(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
@@ -1643,25 +1722,32 @@ def reduce_any(input_tensor,
Equivalent to np.any
@end_compatibility
"""
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices,
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+ if keepdims is None:
+ keepdims = False
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
gen_math_ops._any(
input_tensor,
_ReductionDims(input_tensor, axis,
reduction_indices),
- keep_dims,
+ keepdims,
name=name))
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_logsumexp(input_tensor,
axis=None,
- keep_dims=False,
+ keepdims=None,
name=None,
- reduction_indices=None):
+ reduction_indices=None,
+ keep_dims=None):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
Reduces `input_tensor` along the dimensions given in `axis`.
- Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
- entry in `axis`. If `keep_dims` is true, the reduced dimensions
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
If `axis` has no entries, all dimensions are reduced, and a
@@ -1678,7 +1764,7 @@ def reduce_logsumexp(input_tensor,
tf.reduce_logsumexp(x) # log(6)
tf.reduce_logsumexp(x, 0) # [log(2), log(2), log(2)]
tf.reduce_logsumexp(x, 1) # [log(3), log(3)]
- tf.reduce_logsumexp(x, 1, keep_dims=True) # [[log(3)], [log(3)]]
+ tf.reduce_logsumexp(x, 1, keepdims=True) # [[log(3)], [log(3)]]
tf.reduce_logsumexp(x, [0, 1]) # log(6)
```
@@ -1687,19 +1773,24 @@ def reduce_logsumexp(input_tensor,
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
- keep_dims: If true, retains reduced dimensions with length 1.
+ keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
+ keep_dims: Deprecated alias for `keepdims`.
Returns:
The reduced tensor.
"""
+ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
+ "keep_dims", keep_dims)
+ if keepdims is None:
+ keepdims = False
with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
raw_max = reduce_max(
input_tensor,
axis=axis,
reduction_indices=reduction_indices,
- keep_dims=True)
+ keepdims=True)
my_max = array_ops.stop_gradient(
array_ops.where(
gen_math_ops.is_finite(raw_max), raw_max,
@@ -1708,13 +1799,13 @@ def reduce_logsumexp(input_tensor,
reduce_sum(
gen_math_ops.exp(input_tensor - my_max),
axis,
- keep_dims=True,
+ keepdims=True,
reduction_indices=reduction_indices)) + my_max
- if not keep_dims:
+ if not keepdims:
if isinstance(axis, int):
axis = [axis]
result = array_ops.squeeze(result, axis)
- return _may_reduce_to_scalar(keep_dims, axis, reduction_indices, result)
+ return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
def trace(x, name=None):
@@ -2216,9 +2307,10 @@ def bincount(arr,
maxlength = ops.convert_to_tensor(
maxlength, name="maxlength", dtype=dtypes.int32)
output_size = gen_math_ops.minimum(maxlength, output_size)
- weights = (
- ops.convert_to_tensor(weights, name="weights")
- if weights is not None else constant_op.constant([], dtype))
+ if weights is not None:
+ weights = ops.convert_to_tensor(weights, name="weights")
+ return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
+ weights = constant_op.constant([], dtype)
return gen_math_ops.bincount(arr, output_size, weights)
@@ -2381,7 +2473,7 @@ def reduced_shape(input_shape, axes):
input_shape: 1-D Tensor, the shape of the Tensor being reduced.
axes: 1-D Tensor, the reduction axes.
Returns:
- A 1-D Tensor, the output shape as if keep_dims were set to True.
+ A 1-D Tensor, the output shape as if keepdims were set to True.
"""
# Example:
# cast needed for SparseTensor reductions