diff options
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index e701b66319..e56db2af8c 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -216,6 +216,22 @@ struct ProtoHelper<complex64> { }; template <> +struct ProtoHelper<complex128> { + typedef Helper<double>::RepeatedFieldType FieldType; + static const complex128* Begin(const TensorProto& proto) { + return reinterpret_cast<const complex128*>(proto.dcomplex_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.dcomplex_val().size() / 2; + } + static void Fill(const complex128* data, size_t n, TensorProto* proto) { + const double* p = reinterpret_cast<const double*>(data); + FieldType copy(p, p + n * 2); + proto->mutable_dcomplex_val()->Swap(©); + } +}; + +template <> struct ProtoHelper<qint32> { typedef Helper<int32>::RepeatedFieldType FieldType; static const qint32* Begin(const TensorProto& proto) { @@ -385,6 +401,7 @@ void Tensor::UnsafeCopyFromInternal(const Tensor& other, CASE(int8, SINGLE_ARG(STMTS)) \ CASE(string, SINGLE_ARG(STMTS)) \ CASE(complex64, SINGLE_ARG(STMTS)) \ + CASE(complex128, SINGLE_ARG(STMTS)) \ CASE(int64, SINGLE_ARG(STMTS)) \ CASE(bool, SINGLE_ARG(STMTS)) \ CASE(qint32, SINGLE_ARG(STMTS)) \ |