-rw-r--r--  tensorflow/contrib/lite/context.h                                            2
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc                                       7
-rw-r--r--  tensorflow/contrib/lite/interpreter.h                                        5
-rw-r--r--  tensorflow/contrib/lite/kernels/cast.cc                                     23
-rw-r--r--  tensorflow/contrib/lite/kernels/cast_test.cc                                67
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor.h                           15
-rw-r--r--  tensorflow/contrib/lite/model.cc                                             3
-rw-r--r--  tensorflow/contrib/lite/optional_debug_tools.cc                              2
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc    4
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs                                    1
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h                            9
-rw-r--r--  tensorflow/contrib/lite/toco/model.h                                        19
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/types.cc                                 8
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/types_test.cc                           13
14 files changed, 166 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 6434e265b1..1265c4cba9 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -139,6 +139,7 @@ typedef enum {
kTfLiteString = 5,
kTfLiteBool = 6,
kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
} TfLiteType;
// Parameters for asymmetric quantization. Quantized values can be converted
@@ -159,6 +160,7 @@ typedef union {
uint8_t* uint8;
bool* b;
int16_t* i16;
+ _Complex float* c64;
} TfLitePtrUnion;
// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
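Note: context.h is a C header, so the new union member is declared as C's _Complex float*, while the C++ kernels later in this patch reinterpret_cast it to std::complex<float>*. A minimal standalone sketch of the layout assumption behind that cast (illustration only, not part of the patch):

    // Sketch: the cast works because std::complex<float> is specified to be
    // stored as two packed floats (real, imag), matching _Complex float.
    #include <complex>

    static_assert(sizeof(std::complex<float>) == 2 * sizeof(float),
                  "complex64 occupies exactly two 32-bit floats");

    int main() { return 0; }
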
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 57b2c0f32b..62a0b1ff08 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -359,10 +359,13 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
case kTfLiteBool:
*bytes = sizeof(bool) * count;
break;
+ case kTfLiteComplex64:
+ *bytes = sizeof(std::complex<float>) * count;
+ break;
default:
ReportError(&context_,
- "Only float32, int16, int32, int64, uint8, bool supported "
- "currently.");
+ "Only float32, int16, int32, int64, uint8, bool, complex64 "
+ "supported currently.");
return kTfLiteError;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index e67543671b..033b8ee5fa 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -17,6 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
#define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#include <complex>
#include <cstdio>
#include <cstdlib>
#include <vector>
@@ -58,6 +59,10 @@ template <>
constexpr TfLiteType typeToTfLiteType<bool>() {
return kTfLiteBool;
}
+template <>
+constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
+ return kTfLiteComplex64;
+}
// Forward declare since NNAPIDelegate uses Interpreter.
class NNAPIDelegate;
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 60770ca0aa..8dd48af57f 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <algorithm>
+#include <complex>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
@@ -53,6 +54,20 @@ void copyCast(const FromT* in, ToT* out, int num_elements) {
[](FromT a) { return static_cast<ToT>(a); });
}
+template <typename ToT>
+void copyCast(const std::complex<float>* in, ToT* out, int num_elements) {
+ std::transform(in, in + num_elements, out, [](std::complex<float> a) {
+ return static_cast<ToT>(std::real(a));
+ });
+}
+
+template <>
+void copyCast(const std::complex<float>* in, std::complex<float>* out,
+ int num_elements) {
+ std::transform(in, in + num_elements, out,
+ [](std::complex<float> a) { return a; });
+}
+
template <typename FromT>
TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
int num_elements) {
@@ -72,6 +87,10 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
case kTfLiteBool:
copyCast(in, out->data.b, num_elements);
break;
+ case kTfLiteComplex64:
+ copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
+ num_elements);
+ break;
default:
// Unsupported type.
return kTfLiteError;
@@ -95,6 +114,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return copyToTensor(input->data.f, output, num_elements);
case kTfLiteBool:
return copyToTensor(input->data.b, output, num_elements);
+ case kTfLiteComplex64:
+ return copyToTensor(
+ reinterpret_cast<std::complex<float>*>(input->data.c64), output,
+ num_elements);
default:
// Unsupported type.
return kTfLiteError;
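The copyCast overloads added above define the cast semantics for complex64: complex-to-real keeps only the real part, real-to-complex produces a zero imaginary part, and complex-to-complex copies values unchanged. A small self-contained sketch of that behavior in plain C++ (illustration only, not the TF Lite API):

    #include <algorithm>
    #include <cassert>
    #include <complex>
    #include <vector>

    int main() {
      // complex64 -> float: only the real part survives.
      std::vector<std::complex<float>> c = {{1.0f, 11.0f}, {2.0f, 12.0f}};
      std::vector<float> f(c.size());
      std::transform(c.begin(), c.end(), f.begin(),
                     [](std::complex<float> a) { return std::real(a); });
      assert(f[0] == 1.0f && f[1] == 2.0f);

      // float -> complex64: static_cast yields (value, 0).
      std::vector<std::complex<float>> back(f.size());
      std::transform(f.begin(), f.end(), back.begin(),
                     [](float a) { return static_cast<std::complex<float>>(a); });
      assert(back[0] == std::complex<float>(1.0f, 0.0f));
      return 0;
    }
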
diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc
index 53e2000737..954f998206 100644
--- a/tensorflow/contrib/lite/kernels/cast_test.cc
+++ b/tensorflow/contrib/lite/kernels/cast_test.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <complex>
+
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
@@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) {
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
}
+TEST(CastOpModel, CastComplex64ToFloat) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+}
+
+TEST(CastOpModel, CastFloatToComplex64) {
+ CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToInt) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int>(m.output()),
+ ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(CastOpModel, CastIntToComplex64) {
+ CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToComplex64) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f),
+ std::complex<float>(6.0f, 16.0f)}));
+}
+
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 518bee1c63..ee2af5b460 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#include <complex>
#include <vector>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -54,6 +55,13 @@ inline bool* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
+template <>
+inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr
+ ? reinterpret_cast<std::complex<float>*>(tensor->data.c64)
+ : nullptr;
+}
+
template <typename T>
inline const T* GetTensorData(const TfLiteTensor* tensor);
@@ -87,6 +95,13 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
+template <>
+inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr
+ ? reinterpret_cast<const std::complex<float>*>(tensor->data.c64)
+ : nullptr;
+}
+
inline int RemapDim(int max_dimensions, int d) {
return max_dimensions - d - 1;
}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 793a72272d..f54db3af87 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -63,6 +63,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_BOOL:
*type = kTfLiteBool;
break;
+ case TensorType_COMPLEX64:
+ *type = kTfLiteComplex64;
+ break;
default:
error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
EnumNameTensorType(tensor_type), tensor_type);
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
index 99c35b9caf..f1f025f777 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.cc
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -52,6 +52,8 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteBool";
case kTfLiteInt16:
return "kTfLiteInt16";
+ case kTfLiteComplex64:
+ return "kTfLiteComplex64";
}
return "(invalid)";
}
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index b283551c45..5554d08fa0 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -92,6 +92,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_OBJECT;
case kTfLiteBool:
return NPY_BOOL;
+ case kTfLiteComplex64:
+ return NPY_COMPLEX64;
case kTfLiteNoType:
return -1;
}
@@ -118,6 +120,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
case NPY_STRING:
case NPY_UNICODE:
return kTfLiteString;
+ case NPY_COMPLEX64:
+ return kTfLiteComplex64;
}
LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type;
return kTfLiteNoType;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 76ad3ef893..15fb8bbdb8 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -35,6 +35,7 @@ enum TensorType : byte {
STRING = 5,
BOOL = 6,
INT16 = 7,
+ COMPLEX64 = 8,
}
// Parameters for converting a quantized tensor back to float. Given a
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index e3ce90aa55..fe0ff9a7a5 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -223,11 +223,12 @@ enum TensorType {
TensorType_STRING = 5,
TensorType_BOOL = 6,
TensorType_INT16 = 7,
+ TensorType_COMPLEX64 = 8,
TensorType_MIN = TensorType_FLOAT32,
- TensorType_MAX = TensorType_INT16
+ TensorType_MAX = TensorType_COMPLEX64
};
-inline TensorType (&EnumValuesTensorType())[8] {
+inline TensorType (&EnumValuesTensorType())[9] {
static TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
@@ -236,7 +237,8 @@ inline TensorType (&EnumValuesTensorType())[8] {
TensorType_INT64,
TensorType_STRING,
TensorType_BOOL,
- TensorType_INT16
+ TensorType_INT16,
+ TensorType_COMPLEX64
};
return values;
}
@@ -251,6 +253,7 @@ inline const char **EnumNamesTensorType() {
"STRING",
"BOOL",
"INT16",
+ "COMPLEX64",
nullptr
};
return names;
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index aa05a9bd0e..abe0bf3c54 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#include <complex>
#include <functional>
#include <initializer_list>
#include <memory>
@@ -161,15 +162,16 @@ enum class AxesOrder {
// The type of the scalars in an array.
// Note that the type does not by itself tell whether the values in the array
-// are real (are literally interpreted as real numbers) or quantized (only
-// acquire a meaning as real numbers in conjunction with QuantizationParams).
+// are non-quantized (can be accessed directly) or quantized (must be
+// interpreted in conjunction with QuantizationParams).
//
// In practice though:
-// float values are always real
+// float values are never quantized
// uint8 values are always quantized
-// int32 values are either real or quantized (depending on whether
+// int32 values are sometimes quantized (depending on whether
// QuantizationParams are present).
-// other types are unused at the moment.
+// complex values are never quantized
+// other types are never quantized at the moment.
//
// kNone means that we don't know the data type yet, or that we don't care
// because we'll be dropping the array anyway (e.g. some exotic array types
@@ -187,7 +189,8 @@ enum class ArrayDataType : uint8 {
kUint32,
kInt64,
kUint64, // 10
- kString
+ kString,
+ kComplex64,
};
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
@@ -241,6 +244,10 @@ template <>
struct DataTypeImpl<ArrayDataType::kString> {
typedef string Type;
};
+template <>
+struct DataTypeImpl<ArrayDataType::kComplex64> {
+ typedef std::complex<float> Type;
+};
template <ArrayDataType A>
using DataType = typename DataTypeImpl<A>::Type;
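DataTypeImpl is toco's compile-time map from ArrayDataType tags to C++ scalar types; the new specialization lets templated buffer code resolve kComplex64 to std::complex<float>. A standalone illustration of the same trait pattern (hypothetical names, not the toco headers):

    #include <complex>
    #include <cstdint>
    #include <type_traits>

    // Hypothetical mirror of the toco pattern: enum tag -> C++ scalar type.
    enum class Dtype : std::uint8_t { kFloat, kComplex64 };

    template <Dtype D> struct DtypeImpl;
    template <> struct DtypeImpl<Dtype::kFloat> { typedef float Type; };
    template <> struct DtypeImpl<Dtype::kComplex64> {
      typedef std::complex<float> Type;
    };

    template <Dtype D> using CType = typename DtypeImpl<D>::Type;

    static_assert(
        std::is_same<CType<Dtype::kComplex64>, std::complex<float>>::value,
        "kComplex64 resolves to std::complex<float>");

    int main() { return 0; }
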
diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc
index 42c5d7e8eb..754f0b4b8c 100644
--- a/tensorflow/contrib/lite/toco/tflite/types.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types.cc
@@ -100,6 +100,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
return ::tflite::TensorType_STRING;
case ArrayDataType::kBool:
return ::tflite::TensorType_BOOL;
+ case ArrayDataType::kComplex64:
+ return ::tflite::TensorType_COMPLEX64;
default:
// FLOAT32 is filled for unknown data types.
// TODO(ycling): Implement type inference in TF Lite interpreter.
@@ -123,6 +125,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
return ArrayDataType::kUint8;
case ::tflite::TensorType_BOOL:
return ArrayDataType::kBool;
+ case ::tflite::TensorType_COMPLEX64:
+ return ArrayDataType::kComplex64;
default:
LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
}
@@ -147,6 +151,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
return CopyBuffer<ArrayDataType::kUint8>(array, builder);
case ArrayDataType::kBool:
return CopyBoolToBuffer(array, builder);
+ case ArrayDataType::kComplex64:
+ return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
default:
LOG(FATAL) << "Unhandled array data type.";
}
@@ -172,6 +178,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
case ::tflite::TensorType_BOOL:
return CopyBuffer<ArrayDataType::kBool>(buffer, array);
+ case ::tflite::TensorType_COMPLEX64:
+ return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
default:
LOG(FATAL) << "Unhandled tensor type.";
}
diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc
index 8c6ef95bfa..8e9f30ba3a 100644
--- a/tensorflow/contrib/lite/toco/tflite/types_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include <complex>
+
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -71,7 +73,8 @@ TEST(DataType, SupportedTypes) {
{ArrayDataType::kInt32, ::tflite::TensorType_INT32},
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
- {ArrayDataType::kBool, ::tflite::TensorType_BOOL}};
+ {ArrayDataType::kBool, ::tflite::TensorType_BOOL},
+ {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}};
for (auto x : testdata) {
EXPECT_EQ(x.second, DataType::Serialize(x.first));
EXPECT_EQ(x.first, DataType::Deserialize(x.second));
@@ -171,6 +174,14 @@ TEST(DataBuffer, Bool) {
::testing::ElementsAre(true, false, true));
}
+TEST(DataBuffer, Complex64) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kComplex64>(
+ {std::complex<float>(1.0f, 2.0f), std::complex<float>(3.0f, 4.0f)});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kComplex64>().data,
+ ::testing::ElementsAre(std::complex<float>(1.0f, 2.0f),
+ std::complex<float>(3.0f, 4.0f)));
+}
+
TEST(Padding, All) {
EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));