aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/dtypes.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/dtypes.py')
-rw-r--r--tensorflow/python/framework/dtypes.py22
1 files changed, 17 insertions, 5 deletions
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 9c1e05f8bc..d964a7f29b 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -32,6 +32,7 @@ class DType(object):
* `tf.float64`: 64-bit double-precision floating-point.
* `tf.bfloat16`: 16-bit truncated floating-point.
* `tf.complex64`: 64-bit single-precision complex.
+ * `tf.complex128`: 128-bit double-precision complex.
* `tf.int8`: 8-bit signed integer.
* `tf.uint8`: 8-bit unsigned integer.
@@ -122,6 +123,8 @@ class DType(object):
base = self.base_dtype
if base == complex64:
return float32
+ elif base == complex128:
+ return float64
else:
return self
@@ -149,7 +152,7 @@ class DType(object):
@property
def is_complex(self):
"""Returns whether this is a complex floating point type."""
- return self.base_dtype == complex64
+ return self.base_dtype in (complex64, complex128)
@property
def is_quantized(self):
@@ -179,8 +182,8 @@ class DType(object):
TypeError: if this is a non-numeric, unordered, or quantized type.
"""
- if (self.is_quantized or self.base_dtype == bool or
- self.base_dtype == string or self.base_dtype == complex64):
+ if (self.is_quantized or self.base_dtype in
+ (bool, string, complex64, complex128)):
raise TypeError("Cannot find minimum value of %s." % self)
# there is no simple way to get the min value of a dtype, we have to check
@@ -201,8 +204,8 @@ class DType(object):
TypeError: if this is a non-numeric, unordered, or quantized type.
"""
- if (self.is_quantized or self.base_dtype == bool or
- self.base_dtype == string or self.base_dtype == complex64):
+ if (self.is_quantized or self.base_dtype in
+ (bool, string, complex64, complex128)):
raise TypeError("Cannot find maximum value of %s." % self)
# there is no simple way to get the min value of a dtype, we have to check
@@ -277,6 +280,7 @@ int16 = DType(types_pb2.DT_INT16)
int8 = DType(types_pb2.DT_INT8)
string = DType(types_pb2.DT_STRING)
complex64 = DType(types_pb2.DT_COMPLEX64)
+complex128 = DType(types_pb2.DT_COMPLEX128)
int64 = DType(types_pb2.DT_INT64)
bool = DType(types_pb2.DT_BOOL)
qint8 = DType(types_pb2.DT_QINT8)
@@ -295,6 +299,7 @@ int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
+complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
@@ -317,6 +322,7 @@ _INTERN_TABLE = {
types_pb2.DT_INT8: int8,
types_pb2.DT_STRING: string,
types_pb2.DT_COMPLEX64: complex64,
+ types_pb2.DT_COMPLEX128: complex128,
types_pb2.DT_INT64: int64,
types_pb2.DT_BOOL: bool,
types_pb2.DT_QINT8: qint8,
@@ -334,6 +340,7 @@ _INTERN_TABLE = {
types_pb2.DT_INT8_REF: int8_ref,
types_pb2.DT_STRING_REF: string_ref,
types_pb2.DT_COMPLEX64_REF: complex64_ref,
+ types_pb2.DT_COMPLEX128_REF: complex128_ref,
types_pb2.DT_INT64_REF: int64_ref,
types_pb2.DT_BOOL_REF: bool_ref,
types_pb2.DT_QINT8_REF: qint8_ref,
@@ -356,6 +363,7 @@ _TYPE_TO_STRING = {
types_pb2.DT_INT8: "int8",
types_pb2.DT_STRING: "string",
types_pb2.DT_COMPLEX64: "complex64",
+ types_pb2.DT_COMPLEX128: "complex128",
types_pb2.DT_INT64: "int64",
types_pb2.DT_BOOL: "bool",
types_pb2.DT_QINT8: "qint8",
@@ -373,6 +381,7 @@ _TYPE_TO_STRING = {
types_pb2.DT_INT8_REF: "int8_ref",
types_pb2.DT_STRING_REF: "string_ref",
types_pb2.DT_COMPLEX64_REF: "complex64_ref",
+ types_pb2.DT_COMPLEX128_REF: "complex128_ref",
types_pb2.DT_INT64_REF: "int64_ref",
types_pb2.DT_BOOL_REF: "bool_ref",
types_pb2.DT_QINT8_REF: "qint8_ref",
@@ -414,6 +423,7 @@ _NP_TO_TF = frozenset([
(np.int16, int16),
(np.int8, int8),
(np.complex64, complex64),
+ (np.complex128, complex128),
(np.object, string),
(np.bool, bool),
(_np_qint8, qint8),
@@ -435,6 +445,7 @@ _TF_TO_NP = {
# strings.
types_pb2.DT_STRING: np.object,
types_pb2.DT_COMPLEX64: np.complex64,
+ types_pb2.DT_COMPLEX128: np.complex128,
types_pb2.DT_INT64: np.int64,
types_pb2.DT_BOOL: np.bool,
types_pb2.DT_QINT8: _np_qint8,
@@ -454,6 +465,7 @@ _TF_TO_NP = {
types_pb2.DT_INT8_REF: np.int8,
types_pb2.DT_STRING_REF: np.object,
types_pb2.DT_COMPLEX64_REF: np.complex64,
+ types_pb2.DT_COMPLEX128_REF: np.complex128,
types_pb2.DT_INT64_REF: np.int64,
types_pb2.DT_BOOL_REF: np.bool,
types_pb2.DT_QINT8_REF: _np_qint8,