aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r--tensorflow/core/framework/tensor.cc17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index e701b66319..e56db2af8c 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -216,6 +216,22 @@ struct ProtoHelper<complex64> {
};
template <>
+struct ProtoHelper<complex128> {
+ typedef Helper<double>::RepeatedFieldType FieldType;
+ static const complex128* Begin(const TensorProto& proto) {
+ return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.dcomplex_val().size() / 2;
+ }
+ static void Fill(const complex128* data, size_t n, TensorProto* proto) {
+ const double* p = reinterpret_cast<const double*>(data);
+ FieldType copy(p, p + n * 2);
+ proto->mutable_dcomplex_val()->Swap(&copy);
+ }
+};
+
+template <>
struct ProtoHelper<qint32> {
typedef Helper<int32>::RepeatedFieldType FieldType;
static const qint32* Begin(const TensorProto& proto) {
@@ -385,6 +401,7 @@ void Tensor::UnsafeCopyFromInternal(const Tensor& other,
CASE(int8, SINGLE_ARG(STMTS)) \
CASE(string, SINGLE_ARG(STMTS)) \
CASE(complex64, SINGLE_ARG(STMTS)) \
+ CASE(complex128, SINGLE_ARG(STMTS)) \
CASE(int64, SINGLE_ARG(STMTS)) \
CASE(bool, SINGLE_ARG(STMTS)) \
CASE(qint32, SINGLE_ARG(STMTS)) \