diff options
Diffstat (limited to 'tensorflow/python/ops/math_ops.py')
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index cdb6dc8f22..fbe6b62302 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -37,11 +37,11 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gen_spectral_ops -from tensorflow.python.platform import tf_logging as logging # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest @@ -628,16 +628,17 @@ def cast(x, dtype, name=None): ``` The operation supports data types (for `x` and `dtype`) of - `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, `float16`, `float32`, - `float64`, `complex64`, `complex128`, `bfloat16`. In case of casting from - complex types (`complex64`, `complex128`) to real types, only the real part - of `x` is returned. In case of casting from real types to complex types - (`complex64`, `complex128`), the imaginary part of the returned value is set - to `0`. The handling of complex types here matches the behavior of numpy. + `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, + `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. + In case of casting from complex types (`complex64`, `complex128`) to real + types, only the real part of `x` is returned. In case of casting from real + types to complex types (`complex64`, `complex128`), the imaginary part of the + returned value is set to `0`. The handling of complex types here matches the + behavior of numpy. Args: x: A `Tensor` or `SparseTensor` of numeric type. It could be - `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, + `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. dtype: The destination type. The list of supported dtypes is the same as `x`. @@ -651,6 +652,9 @@ def cast(x, dtype, name=None): TypeError: If `x` cannot be cast to the `dtype`. """ base_type = dtypes.as_dtype(dtype).base_dtype + if isinstance(x, + (ops.Tensor, _resource_variable_type)) and base_type == x.dtype: + return x with ops.name_scope(name, "Cast", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): values_cast = cast(x.values, base_type, name=name) @@ -1222,8 +1226,9 @@ def _ReductionDims(x, axis, reduction_indices): return axis else: # Fast path: avoid creating Rank and Range ops if ndims is known. - if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access - return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access + rank = common_shapes.rank(x) + if rank is not None: + return constant_op.constant(np.arange(rank), dtype=dtypes.int32) if (isinstance(x, sparse_tensor.SparseTensor) and x.dense_shape.get_shape().is_fully_defined()): rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D. @@ -1234,8 +1239,8 @@ def _ReductionDims(x, axis, reduction_indices): def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output): - """Set a reduction's output's shape to be a scalar if we are certain.""" - if (not output.shape.is_fully_defined()) and (not keepdims) and ( + """Set a reduction's output shape to be a scalar if we are certain.""" + if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and ( axis is None) and (reduction_indices is None): output.set_shape(()) return output |