Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 260
1 file changed, 136 insertions, 124 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 3e91ec0684..7ee095745a 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -233,9 +233,9 @@ def abs(x, name=None):
   `float32` or `float64` that is the absolute value of each element in `x`. All
   elements in `x` must be complex numbers of the form \\(a + bj\\). The
   absolute value is computed as \\( \sqrt{a^2 + b^2}\\). For example:
-  ```
-  # tensor 'x' is [[-2.25 + 4.75j], [-3.25 + 5.75j]]
-  tf.complex_abs(x) ==> [5.25594902, 6.60492229]
+  ```python
+  x = tf.constant([[-2.25 + 4.75j], [-3.25 + 5.75j]])
+  tf.abs(x)  # [5.25594902, 6.60492229]
   ```

   Args:
@@ -524,10 +524,10 @@ def pow(x, y, name=None):
   Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
   corresponding elements in `x` and `y`. For example:

-  ```
-  # tensor 'x' is [[2, 2], [3, 3]]
-  # tensor 'y' is [[8, 16], [2, 3]]
-  tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+  ```python
+  x = tf.constant([[2, 2], [3, 3]])
+  y = tf.constant([[8, 16], [2, 3]])
+  tf.pow(x, y)  # [[256, 65536], [9, 27]]
   ```

   Args:
@@ -557,10 +557,10 @@ def complex(real, imag, name=None):

   For example:

-  ```
-  # tensor 'real' is [2.25, 3.25]
-  # tensor `imag` is [4.75, 5.75]
-  tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
+  ```python
+  real = tf.constant([2.25, 3.25])
+  imag = tf.constant([4.75, 5.75])
+  tf.complex(real, imag)  # [[2.25 + 4.75j], [3.25 + 5.75j]]
   ```

   Args:
@@ -597,9 +597,9 @@ def real(input, name=None):

   For example:

-  ```
-  # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-  tf.real(input) ==> [-2.25, 3.25]
+  ```python
+  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
+  tf.real(x)  # [-2.25, 3.25]
   ```

   If `input` is already real, it is returned unchanged.
@@ -629,9 +629,9 @@ def imag(input, name=None):

   For example:

-  ```
-  # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-  tf.imag(input) ==> [4.75, 5.75]
+  ```python
+  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
+  tf.imag(x)  # [4.75, 5.75]
   ```

   Args:
@@ -657,8 +657,8 @@ def round(x, name=None):
   For example:

   ```python
-  # 'a' is [0.9, 2.5, 2.3, 1.5, -4.5]
-  tf.round(a) ==> [ 1.0, 2.0, 2.0, 2.0, -4.0 ]
+  x = tf.constant([0.9, 2.5, 2.3, 1.5, -4.5])
+  tf.round(x)  # [ 1.0, 2.0, 2.0, 2.0, -4.0 ]
   ```

   Args:
@@ -684,8 +684,8 @@ def cast(x, dtype, name=None):
   For example:

   ```python
-  # tensor `a` is [1.8, 2.2], dtype=tf.float
-  tf.cast(a, tf.int32) ==> [1, 2]  # dtype=tf.int32
+  x = tf.constant([1.8, 2.2], dtype=tf.float32)
+  tf.cast(x, tf.int32)  # [1, 2], dtype=tf.int32
   ```

   Args:
@@ -1147,18 +1147,18 @@ def range(start, limit=None, delta=1, dtype=None, name="range"):
   For example:

   ```python
-  # 'start' is 3
-  # 'limit' is 18
-  # 'delta' is 3
-  tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
-
-  # 'start' is 3
-  # 'limit' is 1
-  # 'delta' is -0.5
-  tf.range(start, limit, delta) ==> [3, 2.5, 2, 1.5]
-
-  # 'limit' is 5
-  tf.range(limit) ==> [0, 1, 2, 3, 4]
+  start = 3
+  limit = 18
+  delta = 3
+  tf.range(start, limit, delta)  # [3, 6, 9, 12, 15]
+
+  start = 3
+  limit = 1
+  delta = -0.5
+  tf.range(start, limit, delta)  # [3, 2.5, 2, 1.5]
+
+  limit = 5
+  tf.range(limit)  # [0, 1, 2, 3, 4]
   ```

   Args:
@@ -1247,13 +1247,12 @@ def reduce_sum(input_tensor,
   For example:

   ```python
-  # 'x' is [[1, 1, 1]
-  #         [1, 1, 1]]
-  tf.reduce_sum(x) ==> 6
-  tf.reduce_sum(x, 0) ==> [2, 2, 2]
-  tf.reduce_sum(x, 1) ==> [3, 3]
-  tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
-  tf.reduce_sum(x, [0, 1]) ==> 6
+  x = tf.constant([[1, 1, 1], [1, 1, 1]])
+  tf.reduce_sum(x)  # 6
+  tf.reduce_sum(x, 0)  # [2, 2, 2]
+  tf.reduce_sum(x, 1)  # [3, 3]
+  tf.reduce_sum(x, 1, keep_dims=True)  # [[3], [3]]
+  tf.reduce_sum(x, [0, 1])  # 6
   ```

   Args:
@@ -1302,13 +1301,12 @@ def count_nonzero(input_tensor,
   For example:

   ```python
-  # 'x' is [[0, 1, 0]
-  #         [1, 1, 0]]
-  tf.count_nonzero(x) ==> 3
-  tf.count_nonzero(x, 0) ==> [1, 2, 0]
-  tf.count_nonzero(x, 1) ==> [1, 2]
-  tf.count_nonzero(x, 1, keep_dims=True) ==> [[1], [2]]
-  tf.count_nonzero(x, [0, 1]) ==> 3
+  x = tf.constant([[0, 1, 0], [1, 1, 0]])
+  tf.count_nonzero(x)  # 3
+  tf.count_nonzero(x, 0)  # [1, 2, 0]
+  tf.count_nonzero(x, 1)  # [1, 2]
+  tf.count_nonzero(x, 1, keep_dims=True)  # [[1], [2]]
+  tf.count_nonzero(x, [0, 1])  # 3
   ```

   Args:
@@ -1355,11 +1353,10 @@ def reduce_mean(input_tensor,
   For example:

   ```python
-  # 'x' is [[1., 1.]
-  #         [2., 2.]]
-  tf.reduce_mean(x) ==> 1.5
-  tf.reduce_mean(x, 0) ==> [1.5, 1.5]
-  tf.reduce_mean(x, 1) ==> [1., 2.]
+  x = tf.constant([[1., 1.], [2., 2.]])
+  tf.reduce_mean(x)  # 1.5
+  tf.reduce_mean(x, 0)  # [1.5, 1.5]
+  tf.reduce_mean(x, 1)  # [1., 2.]
   ```

   Args:
@@ -1517,11 +1514,10 @@ def reduce_all(input_tensor,
   For example:

   ```python
-  # 'x' is [[True, True]
-  #         [False, False]]
-  tf.reduce_all(x) ==> False
-  tf.reduce_all(x, 0) ==> [False, False]
-  tf.reduce_all(x, 1) ==> [True, False]
+  x = tf.constant([[True, True], [False, False]])
+  tf.reduce_all(x)  # False
+  tf.reduce_all(x, 0)  # [False, False]
+  tf.reduce_all(x, 1)  # [True, False]
   ```

   Args:
@@ -1565,11 +1561,10 @@ def reduce_any(input_tensor,
   For example:

   ```python
-  # 'x' is [[True, True]
-  #         [False, False]]
-  tf.reduce_any(x) ==> True
-  tf.reduce_any(x, 0) ==> [True, True]
-  tf.reduce_any(x, 1) ==> [True, False]
+  x = tf.constant([[True, True], [False, False]])
+  tf.reduce_any(x)  # True
+  tf.reduce_any(x, 0)  # [True, True]
+  tf.reduce_any(x, 1)  # [True, False]
   ```

   Args:
@@ -1617,13 +1612,12 @@ def reduce_logsumexp(input_tensor,
   For example:

   ```python
-  # 'x' is [[0, 0, 0]]
-  #         [0, 0, 0]]
-  tf.reduce_logsumexp(x) ==> log(6)
-  tf.reduce_logsumexp(x, 0) ==> [log(2), log(2), log(2)]
-  tf.reduce_logsumexp(x, 1) ==> [log(3), log(3)]
-  tf.reduce_logsumexp(x, 1, keep_dims=True) ==> [[log(3)], [log(3)]]
-  tf.reduce_logsumexp(x, [0, 1]) ==> log(6)
+  x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
+  tf.reduce_logsumexp(x)  # log(6)
+  tf.reduce_logsumexp(x, 0)  # [log(2), log(2), log(2)]
+  tf.reduce_logsumexp(x, 1)  # [log(3), log(3)]
+  tf.reduce_logsumexp(x, 1, keep_dims=True)  # [[log(3)], [log(3)]]
+  tf.reduce_logsumexp(x, [0, 1])  # log(6)
   ```

   Args:
@@ -1639,12 +1633,16 @@
     The reduced tensor.
""" with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name: + raw_max = reduce_max( + input_tensor, + axis=axis, + reduction_indices=reduction_indices, + keep_dims=True) my_max = array_ops.stop_gradient( - reduce_max( - input_tensor, - axis=axis, - reduction_indices=reduction_indices, - keep_dims=True)) + array_ops.where( + gen_math_ops.is_finite(raw_max), + raw_max, + array_ops.zeros_like(raw_max))) result = gen_math_ops.log( reduce_sum( gen_math_ops.exp(input_tensor - my_max), @@ -1670,22 +1668,21 @@ def trace(x, name=None): For example: ```python - # 'x' is [[1, 2], - # [3, 4]] - tf.trace(x) ==> 5 - - # 'x' is [[1,2,3], - # [4,5,6], - # [7,8,9]] - tf.trace(x) ==> 15 - - # 'x' is [[[1,2,3], - # [4,5,6], - # [7,8,9]], - # [[-1,-2,-3], - # [-4,-5,-6], - # [-7,-8,-9]]] - tf.trace(x) ==> [15,-15] + x = tf.constant([[1, 2], [3, 4]]) + tf.trace(x) # 5 + + x = tf.constant([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]]) + tf.trace(x) # 15 + + x = tf.constant([[[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], + [[-1, -2, -3], + [-4, -5, -6], + [-7, -8, -9]]]) + tf.trace(x) # [15, -15] ``` Args: @@ -1732,35 +1729,46 @@ def matmul(a, ```python # 2-D tensor `a` - a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3]) => [[1. 2. 3.] - [4. 5. 6.]] + # [[1, 2, 3], + # [4, 5, 6]] + a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3]) + # 2-D tensor `b` - b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2]) => [[7. 8.] - [9. 10.] - [11. 12.]] - c = tf.matmul(a, b) => [[58 64] - [139 154]] + # [[ 7, 8], + # [ 9, 10], + # [11, 12]] + b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2]) + + # `a` * `b` + # [[ 58, 64], + # [139, 154]] + c = tf.matmul(a, b) # 3-D tensor `a` + # [[[ 1, 2, 3], + # [ 4, 5, 6]], + # [[ 7, 8, 9], + # [10, 11, 12]]] a = tf.constant(np.arange(1, 13, dtype=np.int32), - shape=[2, 2, 3]) => [[[ 1. 2. 3.] - [ 4. 5. 6.]], - [[ 7. 8. 9.] - [10. 11. 12.]]] + shape=[2, 2, 3]) # 3-D tensor `b` + # [[[13, 14], + # [15, 16], + # [17, 18]], + # [[19, 20], + # [21, 22], + # [23, 24]]] b = tf.constant(np.arange(13, 25, dtype=np.int32), - shape=[2, 3, 2]) => [[[13. 14.] - [15. 16.] - [17. 18.]], - [[19. 20.] - [21. 22.] - [23. 24.]]] - c = tf.matmul(a, b) => [[[ 94 100] - [229 244]], - [[508 532] - [697 730]]] + shape=[2, 3, 2]) + + # `a` * `b` + # [[[ 94, 100], + # [229, 244]], + # [[508, 532], + # [697, 730]]] + c = tf.matmul(a, b) # Since python >= 3.5 the @ operator is supported (see PEP 465). 
   # In TensorFlow, it simply calls the `tf.matmul()` function, so the
@@ -1980,13 +1988,13 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
   For example:

   ```python
-  # tensor 'a' is [[1, 2], [3, 4]]
-  # tensor `b` is [[5, 0], [0, 6]]
-  tf.accumulate_n([a, b, a]) ==> [[7, 4], [6, 14]]
+  a = tf.constant([[1, 2], [3, 4]])
+  b = tf.constant([[5, 0], [0, 6]])
+  tf.accumulate_n([a, b, a])  # [[7, 4], [6, 14]]

   # Explicitly pass shape and type
-  tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
-    ==> [[7, 4], [6, 14]]
+  tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)  # [[7, 4],
+                                                                   #  [6, 14]]
   ```

   Args:
@@ -2151,21 +2159,21 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
   element of the input is identical to the first element of the output:

   ```python
-  tf.cumsum([a, b, c])  # => [a, a + b, a + b + c]
+  tf.cumsum([a, b, c])  # [a, a + b, a + b + c]
   ```

   By setting the `exclusive` kwarg to `True`, an exclusive cumsum is
   performed instead:

   ```python
-  tf.cumsum([a, b, c], exclusive=True)  # => [0, a, a + b]
+  tf.cumsum([a, b, c], exclusive=True)  # [0, a, a + b]
   ```

   By setting the `reverse` kwarg to `True`, the cumsum is performed in the
   opposite direction:

   ```python
-  tf.cumsum([a, b, c], reverse=True)  # => [a + b + c, b + c, c]
+  tf.cumsum([a, b, c], reverse=True)  # [a + b + c, b + c, c]
   ```

   This is more efficient than using separate `tf.reverse` ops.
@@ -2173,7 +2181,7 @@
   The `reverse` and `exclusive` kwargs can also be combined:

   ```python
-  tf.cumsum([a, b, c], exclusive=True, reverse=True)  # => [b + c, c, 0]
+  tf.cumsum([a, b, c], exclusive=True, reverse=True)  # [b + c, c, 0]
   ```

   Args:
@@ -2202,7 +2210,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
   first element of the input is identical to the first element of the output:

   ```python
-  tf.cumprod([a, b, c])  # => [a, a * b, a * b * c]
+  tf.cumprod([a, b, c])  # [a, a * b, a * b * c]
   ```

   By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
@@ -2210,21 +2218,21 @@
   instead:

   ```python
-  tf.cumprod([a, b, c], exclusive=True)  # => [1, a, a * b]
+  tf.cumprod([a, b, c], exclusive=True)  # [1, a, a * b]
   ```

   By setting the `reverse` kwarg to `True`, the cumprod is performed in the
   opposite direction:

   ```python
-  tf.cumprod([a, b, c], reverse=True)  # => [a * b * c, b * c, c]
+  tf.cumprod([a, b, c], reverse=True)  # [a * b * c, b * c, c]
   ```

   This is more efficient than using separate `tf.reverse` ops.
   The `reverse` and `exclusive` kwargs can also be combined:

   ```python
-  tf.cumprod([a, b, c], exclusive=True, reverse=True)  # => [b * c, c, 1]
+  tf.cumprod([a, b, c], exclusive=True, reverse=True)  # [b * c, c, 1]
   ```

   Args:
@@ -2448,6 +2456,10 @@ def tensordot(a, b, axes, name=None):
         raise ValueError("'axes' must be an integer or have length 2.")
       a_axes = axes[0]
       b_axes = axes[1]
+      if isinstance(a_axes, compat.integral_types) and \
+          isinstance(b_axes, compat.integral_types):
+        a_axes = [a_axes]
+        b_axes = [b_axes]
       if len(a_axes) != len(b_axes):
         raise ValueError(
             "Different number of contraction axes 'a' and 'b', %s != %s.",
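Note on the `reduce_logsumexp` hunk: it is one of the two behavioral changes in this diff. The max used for the usual log-sum-exp shift is now routed through `is_finite`/`where`, so a slice that is entirely `-inf` no longer turns into `nan` via `(-inf) - (-inf)`. A minimal NumPy sketch of the same idea (not the TensorFlow code itself; `logsumexp_stable` is just an illustrative name):

```python
import numpy as np

def logsumexp_stable(x):
  # Max-shift trick: subtract the max before exponentiating to avoid overflow.
  raw_max = np.max(x)
  # If the max is not finite (e.g. every entry is -inf), shifting by it would
  # produce nan via (-inf) - (-inf); shift by 0 instead, mirroring the
  # is_finite/where guard added in the hunk above.
  safe_max = raw_max if np.isfinite(raw_max) else 0.0
  return np.log(np.sum(np.exp(x - safe_max))) + safe_max

logsumexp_stable(np.array([0., 0., 0.]))        # log(3)
logsumexp_stable(np.array([-np.inf, -np.inf]))  # -inf, not nan
```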
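Note on the `tensordot` hunk: the added `isinstance` check wraps a bare pair of integers passed as `axes` into one-element lists, where previously the `len()` comparison just below would have raised a `TypeError`. A short usage sketch of what this enables (made-up tensor values, TF 1.x-era graph construction assumed):

```python
import numpy as np
import tensorflow as tf

a = tf.constant(np.arange(6.0).reshape(2, 3))
b = tf.constant(np.arange(12.0).reshape(3, 4))

# With the change above, a plain pair of ints is accepted for `axes`:
# contract axis 1 of `a` with axis 0 of `b`. Equivalent to
# tf.tensordot(a, b, axes=([1], [0])) and, for these 2-D inputs, to
# tf.matmul(a, b).
c = tf.tensordot(a, b, axes=(1, 0))  # shape [2, 4]
```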