diff options
Diffstat (limited to 'tensorflow/python/framework/tensor_util.py')
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index b1b39f0651..7a9add319a 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -76,11 +76,16 @@ else: def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values): tensor_proto.int64_val.extend([np.asscalar(x) for x in proto_values]) - def SlowAppendComplexArrayToTensorProto(tensor_proto, proto_values): + def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values): tensor_proto.scomplex_val.extend([np.asscalar(v) for x in proto_values for v in [x.real, x.imag]]) + def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values): + tensor_proto.dcomplex_val.extend([np.asscalar(v) + for x in proto_values + for v in [x.real, x.imag]]) + def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values): tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values]) @@ -96,8 +101,8 @@ else: np.uint16: SlowAppendIntArrayToTensorProto, np.int16: SlowAppendIntArrayToTensorProto, np.int8: SlowAppendIntArrayToTensorProto, - np.complex64: SlowAppendComplexArrayToTensorProto, - np.complex128: SlowAppendComplexArrayToTensorProto, + np.complex64: SlowAppendComplex64ArrayToTensorProto, + np.complex128: SlowAppendComplex128ArrayToTensorProto, np.object: SlowAppendObjectArrayToTensorProto, np.bool: SlowAppendBoolArrayToTensorProto, dtypes.qint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto, @@ -240,6 +245,7 @@ _TF_TO_IS_OK = { dtypes.int8: _FilterInt, dtypes.string: _FilterStr, dtypes.complex64: _FilterComplex, + dtypes.complex128: _FilterComplex, dtypes.int64: _FilterInt, dtypes.bool: _FilterBool, dtypes.qint32: _FilterInt, @@ -453,6 +459,15 @@ def MakeNdarray(tensor): else: return np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype).reshape(shape) + elif tensor_dtype == dtypes.complex128: + it = iter(tensor.dcomplex_val) + if len(tensor.dcomplex_val) == 2: + return np.repeat(np.array(complex(tensor.dcomplex_val[0], + tensor.dcomplex_val[1]), dtype=dtype), + num_elements).reshape(shape) + else: + return np.array([complex(x[0], x[1]) for x in zip(it, it)], + dtype=dtype).reshape(shape) elif tensor_dtype == dtypes.bool: if len(tensor.bool_val) == 1: return np.repeat(np.array(tensor.bool_val[0], dtype=dtype), |