diff options
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 81499bee56..c9da1a0bba 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2130,7 +2130,8 @@ def add_n(inputs, name=None): """Adds all input tensors element-wise. Args: - inputs: A list of `Tensor` objects, each with same shape and type. + inputs: A list of `Tensor` or `IndexedSlices` objects, each with same shape + and type. name: A name for the operation (optional). Returns: @@ -2141,17 +2142,21 @@ def add_n(inputs, name=None): cannot be inferred. """ if not inputs or not isinstance(inputs, (list, tuple)): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise ValueError("inputs must be a list of at least one" + "Tensor/IndexedSlices with the same dtype and shape") inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs): + raise ValueError("inputs must be a list of at least one" + "Tensor/IndexedSlices with the same dtype and shape") if len(inputs) == 1: + if isinstance(inputs[0], ops.IndexedSlices): + values = inputs[0].values + else: + values = inputs[0] if name: - return array_ops.identity(inputs[0], name=name) - return inputs[0] + return array_ops.identity(values, name=name) + return values return gen_math_ops.add_n(inputs, name=name) |