aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r--tensorflow/python/ops/math_ops.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 81499bee56..c9da1a0bba 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2130,7 +2130,8 @@ def add_n(inputs, name=None):
"""Adds all input tensors element-wise.
Args:
- inputs: A list of `Tensor` objects, each with same shape and type.
+ inputs: A list of `Tensor` or `IndexedSlices` objects, each with same shape
+ and type.
name: A name for the operation (optional).
Returns:
@@ -2141,17 +2142,21 @@ def add_n(inputs, name=None):
cannot be inferred.
"""
if not inputs or not isinstance(inputs, (list, tuple)):
- raise ValueError("inputs must be a list of at least one Tensor with the "
- "same dtype and shape")
+ raise ValueError("inputs must be a list of at least one"
+ "Tensor/IndexedSlices with the same dtype and shape")
inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
- if not all(isinstance(x, ops.Tensor) for x in inputs):
- raise ValueError("inputs must be a list of at least one Tensor with the "
- "same dtype and shape")
+ if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
+ raise ValueError("inputs must be a list of at least one"
+ "Tensor/IndexedSlices with the same dtype and shape")
if len(inputs) == 1:
+ if isinstance(inputs[0], ops.IndexedSlices):
+ values = inputs[0].values
+ else:
+ values = inputs[0]
if name:
- return array_ops.identity(inputs[0], name=name)
- return inputs[0]
+ return array_ops.identity(values, name=name)
+ return values
return gen_math_ops.add_n(inputs, name=name)