diff options
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 103 |
1 files changed, 102 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index cac14f286e..fc7b299978 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -13,7 +13,10 @@ # limitations under the License. # ============================================================================== -"""## Arithmetic Operators +"""Note: Elementwise binary operations in TensorFlow follow [numpy-style +broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + +## Arithmetic Operators TensorFlow provides several operations that you can use to add basic arithmetic operators to your graph. @@ -145,6 +148,14 @@ common math computations that reduce various dimensions of a tensor. @@accumulate_n +## Scan + +TensorFlow provides several operations that you can use to perform scans +(running totals) across one axis of a tensor. + +@@cumsum +@@cumprod + ## Segmentation TensorFlow provides several operations that you can use to perform common @@ -1585,6 +1596,94 @@ def tanh(x, name=None): return gen_math_ops._tanh(x, name=name) +def cumsum(x, axis=0, exclusive=False, reverse=False, name=None): + """Compute the cumulative sum of the tensor `x` along `axis`. + + By default, this op performs an inclusive cumsum, which means that the first + element of the input is identical to the first element of the output: + ```prettyprint + tf.cumsum([a, b, c]) ==> [a, a + b, a + b + c] + ``` + + By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed + instead: + ```prettyprint + tf.cumsum([a, b, c], exclusive=True) ==> [0, a, a + b] + ``` + + By setting the `reverse` kwarg to `True`, the cumsum is performed in the + opposite direction: + ```prettyprint + tf.cumsum([a, b, c], reverse=True) ==> [a + b + c, b + c, c] + ``` + This is more efficient than using separate `tf.reverse` ops. + + The `reverse` and `exclusive` kwargs can also be combined: + ```prettyprint + tf.cumsum([a, b, c], exclusive=True, reverse=True) ==> [b + c, c, 0] + ``` + + Args: + x: A `Tensor`. Must be one of the following types: `float32`, `float64`, + `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, + `complex128`, `qint8`, `quint8`, `qint32`, `half`. + axis: A `Tensor` of type `int32` (default: 0). + reverse: A `bool` (default: False). + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `x`. + """ + with ops.op_scope([x], name, "Cumsum") as name: + x = ops.convert_to_tensor(x, name="x") + return gen_math_ops.cumsum(x, axis, exclusive=exclusive, + reverse=reverse, name=name) + + +def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): + """Compute the cumulative product of the tensor `x` along `axis`. + + By default, this op performs an inclusive cumprod, which means that the first + element of the input is identical to the first element of the output: + ```prettyprint + tf.cumprod([a, b, c]) ==> [a, a * b, a * b * c] + ``` + + By setting the `exclusive` kwarg to `True`, an exclusive cumprod is performed + instead: + ```prettyprint + tf.cumprod([a, b, c], exclusive=True) ==> [0, a, a * b] + ``` + + By setting the `reverse` kwarg to `True`, the cumprod is performed in the + opposite direction: + ```prettyprint + tf.cumprod([a, b, c], reverse=True) ==> [a * b * c, b * c, c] + ``` + This is more efficient than using separate `tf.reverse` ops. + + The `reverse` and `exclusive` kwargs can also be combined: + ```prettyprint + tf.cumprod([a, b, c], exclusive=True, reverse=True) ==> [b * c, c, 0] + ``` + + Args: + x: A `Tensor`. Must be one of the following types: `float32`, `float64`, + `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`, + `complex128`, `qint8`, `quint8`, `qint32`, `half`. + axis: A `Tensor` of type `int32` (default: 0). + reverse: A `bool` (default: False). + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `x`. + """ + with ops.op_scope([x], name, "Cumprod") as name: + x = ops.convert_to_tensor(x, name="x") + return gen_math_ops.cumprod(x, axis, exclusive=exclusive, + reverse=reverse, name=name) + + ops.RegisterShape("Abs")(common_shapes.unchanged_shape) ops.RegisterShape("Acos")(common_shapes.unchanged_shape) ops.RegisterShape("Asin")(common_shapes.unchanged_shape) @@ -1632,6 +1731,8 @@ ops.RegisterShape("BatchFFT3D")(common_shapes.unchanged_shape) ops.RegisterShape("BatchIFFT3D")(common_shapes.unchanged_shape) ops.RegisterShape("TanhGrad")(common_shapes.unchanged_shape) ops.RegisterShape("SigmoidGrad")(common_shapes.unchanged_shape) +ops.RegisterShape("Cumsum")(common_shapes.unchanged_shape) +ops.RegisterShape("Cumprod")(common_shapes.unchanged_shape) @ops.RegisterShape("Add") |