diff options
Diffstat (limited to 'tensorflow/python/framework/dtypes.py')
-rw-r--r-- | tensorflow/python/framework/dtypes.py | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 9c1e05f8bc..d964a7f29b 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -32,6 +32,7 @@ class DType(object): * `tf.float64`: 64-bit double-precision floating-point. * `tf.bfloat16`: 16-bit truncated floating-point. * `tf.complex64`: 64-bit single-precision complex. + * `tf.complex128`: 128-bit double-precision complex. * `tf.int8`: 8-bit signed integer. * `tf.uint8`: 8-bit unsigned integer. @@ -122,6 +123,8 @@ class DType(object): base = self.base_dtype if base == complex64: return float32 + elif base == complex128: + return float64 else: return self @@ -149,7 +152,7 @@ class DType(object): @property def is_complex(self): """Returns whether this is a complex floating point type.""" - return self.base_dtype == complex64 + return self.base_dtype in (complex64, complex128) @property def is_quantized(self): @@ -179,8 +182,8 @@ class DType(object): TypeError: if this is a non-numeric, unordered, or quantized type. """ - if (self.is_quantized or self.base_dtype == bool or - self.base_dtype == string or self.base_dtype == complex64): + if (self.is_quantized or self.base_dtype in + (bool, string, complex64, complex128)): raise TypeError("Cannot find minimum value of %s." % self) # there is no simple way to get the min value of a dtype, we have to check @@ -201,8 +204,8 @@ class DType(object): TypeError: if this is a non-numeric, unordered, or quantized type. """ - if (self.is_quantized or self.base_dtype == bool or - self.base_dtype == string or self.base_dtype == complex64): + if (self.is_quantized or self.base_dtype in + (bool, string, complex64, complex128)): raise TypeError("Cannot find maximum value of %s." % self) # there is no simple way to get the min value of a dtype, we have to check @@ -277,6 +280,7 @@ int16 = DType(types_pb2.DT_INT16) int8 = DType(types_pb2.DT_INT8) string = DType(types_pb2.DT_STRING) complex64 = DType(types_pb2.DT_COMPLEX64) +complex128 = DType(types_pb2.DT_COMPLEX128) int64 = DType(types_pb2.DT_INT64) bool = DType(types_pb2.DT_BOOL) qint8 = DType(types_pb2.DT_QINT8) @@ -295,6 +299,7 @@ int16_ref = DType(types_pb2.DT_INT16_REF) int8_ref = DType(types_pb2.DT_INT8_REF) string_ref = DType(types_pb2.DT_STRING_REF) complex64_ref = DType(types_pb2.DT_COMPLEX64_REF) +complex128_ref = DType(types_pb2.DT_COMPLEX128_REF) int64_ref = DType(types_pb2.DT_INT64_REF) bool_ref = DType(types_pb2.DT_BOOL_REF) qint8_ref = DType(types_pb2.DT_QINT8_REF) @@ -317,6 +322,7 @@ _INTERN_TABLE = { types_pb2.DT_INT8: int8, types_pb2.DT_STRING: string, types_pb2.DT_COMPLEX64: complex64, + types_pb2.DT_COMPLEX128: complex128, types_pb2.DT_INT64: int64, types_pb2.DT_BOOL: bool, types_pb2.DT_QINT8: qint8, @@ -334,6 +340,7 @@ _INTERN_TABLE = { types_pb2.DT_INT8_REF: int8_ref, types_pb2.DT_STRING_REF: string_ref, types_pb2.DT_COMPLEX64_REF: complex64_ref, + types_pb2.DT_COMPLEX128_REF: complex128_ref, types_pb2.DT_INT64_REF: int64_ref, types_pb2.DT_BOOL_REF: bool_ref, types_pb2.DT_QINT8_REF: qint8_ref, @@ -356,6 +363,7 @@ _TYPE_TO_STRING = { types_pb2.DT_INT8: "int8", types_pb2.DT_STRING: "string", types_pb2.DT_COMPLEX64: "complex64", + types_pb2.DT_COMPLEX128: "complex128", types_pb2.DT_INT64: "int64", types_pb2.DT_BOOL: "bool", types_pb2.DT_QINT8: "qint8", @@ -373,6 +381,7 @@ _TYPE_TO_STRING = { types_pb2.DT_INT8_REF: "int8_ref", types_pb2.DT_STRING_REF: "string_ref", types_pb2.DT_COMPLEX64_REF: "complex64_ref", + types_pb2.DT_COMPLEX128_REF: "complex128_ref", types_pb2.DT_INT64_REF: "int64_ref", types_pb2.DT_BOOL_REF: "bool_ref", types_pb2.DT_QINT8_REF: "qint8_ref", @@ -414,6 +423,7 @@ _NP_TO_TF = frozenset([ (np.int16, int16), (np.int8, int8), (np.complex64, complex64), + (np.complex128, complex128), (np.object, string), (np.bool, bool), (_np_qint8, qint8), @@ -435,6 +445,7 @@ _TF_TO_NP = { # strings. types_pb2.DT_STRING: np.object, types_pb2.DT_COMPLEX64: np.complex64, + types_pb2.DT_COMPLEX128: np.complex128, types_pb2.DT_INT64: np.int64, types_pb2.DT_BOOL: np.bool, types_pb2.DT_QINT8: _np_qint8, @@ -454,6 +465,7 @@ _TF_TO_NP = { types_pb2.DT_INT8_REF: np.int8, types_pb2.DT_STRING_REF: np.object, types_pb2.DT_COMPLEX64_REF: np.complex64, + types_pb2.DT_COMPLEX128_REF: np.complex128, types_pb2.DT_INT64_REF: np.int64, types_pb2.DT_BOOL_REF: np.bool, types_pb2.DT_QINT8_REF: _np_qint8, |