aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/tensor_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r--tensorflow/python/framework/tensor_util.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index b1b39f0651..7a9add319a 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -76,11 +76,16 @@ else:
def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
tensor_proto.int64_val.extend([np.asscalar(x) for x in proto_values])
- def SlowAppendComplexArrayToTensorProto(tensor_proto, proto_values):
+ def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
tensor_proto.scomplex_val.extend([np.asscalar(v)
for x in proto_values
for v in [x.real, x.imag]])
+ def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
+ tensor_proto.dcomplex_val.extend([np.asscalar(v)
+ for x in proto_values
+ for v in [x.real, x.imag]])
+
def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])
@@ -96,8 +101,8 @@ else:
np.uint16: SlowAppendIntArrayToTensorProto,
np.int16: SlowAppendIntArrayToTensorProto,
np.int8: SlowAppendIntArrayToTensorProto,
- np.complex64: SlowAppendComplexArrayToTensorProto,
- np.complex128: SlowAppendComplexArrayToTensorProto,
+ np.complex64: SlowAppendComplex64ArrayToTensorProto,
+ np.complex128: SlowAppendComplex128ArrayToTensorProto,
np.object: SlowAppendObjectArrayToTensorProto,
np.bool: SlowAppendBoolArrayToTensorProto,
dtypes.qint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto,
@@ -240,6 +245,7 @@ _TF_TO_IS_OK = {
dtypes.int8: _FilterInt,
dtypes.string: _FilterStr,
dtypes.complex64: _FilterComplex,
+ dtypes.complex128: _FilterComplex,
dtypes.int64: _FilterInt,
dtypes.bool: _FilterBool,
dtypes.qint32: _FilterInt,
@@ -453,6 +459,15 @@ def MakeNdarray(tensor):
else:
return np.array([complex(x[0], x[1]) for x in zip(it, it)],
dtype=dtype).reshape(shape)
+ elif tensor_dtype == dtypes.complex128:
+ it = iter(tensor.dcomplex_val)
+ if len(tensor.dcomplex_val) == 2:
+ return np.repeat(np.array(complex(tensor.dcomplex_val[0],
+ tensor.dcomplex_val[1]), dtype=dtype),
+ num_elements).reshape(shape)
+ else:
+ return np.array([complex(x[0], x[1]) for x in zip(it, it)],
+ dtype=dtype).reshape(shape)
elif tensor_dtype == dtypes.bool:
if len(tensor.bool_val) == 1:
return np.repeat(np.array(tensor.bool_val[0], dtype=dtype),