aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r--tensorflow/python/ops/math_ops.py29
1 files changed, 17 insertions, 12 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index cdb6dc8f22..fbe6b62302 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -37,11 +37,11 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_spectral_ops
-from tensorflow.python.platform import tf_logging as logging
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
@@ -628,16 +628,17 @@ def cast(x, dtype, name=None):
```
The operation supports data types (for `x` and `dtype`) of
- `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, `float16`, `float32`,
- `float64`, `complex64`, `complex128`, `bfloat16`. In case of casting from
- complex types (`complex64`, `complex128`) to real types, only the real part
- of `x` is returned. In case of casting from real types to complex types
- (`complex64`, `complex128`), the imaginary part of the returned value is set
- to `0`. The handling of complex types here matches the behavior of numpy.
+ `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
+ `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
+ In case of casting from complex types (`complex64`, `complex128`) to real
+ types, only the real part of `x` is returned. In case of casting from real
+ types to complex types (`complex64`, `complex128`), the imaginary part of the
+ returned value is set to `0`. The handling of complex types here matches the
+ behavior of numpy.
Args:
x: A `Tensor` or `SparseTensor` of numeric type. It could be
- `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`,
+ `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
`float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
dtype: The destination type. The list of supported dtypes is the same
as `x`.
@@ -651,6 +652,9 @@ def cast(x, dtype, name=None):
TypeError: If `x` cannot be cast to the `dtype`.
"""
base_type = dtypes.as_dtype(dtype).base_dtype
+ if isinstance(x,
+ (ops.Tensor, _resource_variable_type)) and base_type == x.dtype:
+ return x
with ops.name_scope(name, "Cast", [x]) as name:
if isinstance(x, sparse_tensor.SparseTensor):
values_cast = cast(x.values, base_type, name=name)
@@ -1222,8 +1226,9 @@ def _ReductionDims(x, axis, reduction_indices):
return axis
else:
# Fast path: avoid creating Rank and Range ops if ndims is known.
- if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access
- return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access
+ rank = common_shapes.rank(x)
+ if rank is not None:
+ return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
if (isinstance(x, sparse_tensor.SparseTensor) and
x.dense_shape.get_shape().is_fully_defined()):
rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D.
@@ -1234,8 +1239,8 @@ def _ReductionDims(x, axis, reduction_indices):
def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
- """Set a reduction's output's shape to be a scalar if we are certain."""
- if (not output.shape.is_fully_defined()) and (not keepdims) and (
+ """Set a reduction's output shape to be a scalar if we are certain."""
+ if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and (
axis is None) and (reduction_indices is None):
output.set_shape(())
return output