diff options
-rw-r--r-- | tensorflow/contrib/lite/context.h | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 7 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.h | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/cast.cc | 23 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/cast_test.cc | 67 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/tensor.h | 15 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/optional_debug_tools.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/schema/schema.fbs | 1 | ||||
-rwxr-xr-x | tensorflow/contrib/lite/schema/schema_generated.h | 9 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/model.h | 19 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/types.cc | 8 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/types_test.cc | 13 |
14 files changed, 166 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 6434e265b1..1265c4cba9 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -139,6 +139,7 @@ typedef enum { kTfLiteString = 5, kTfLiteBool = 6, kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -159,6 +160,7 @@ typedef union { uint8_t* uint8; bool* b; int16_t* i16; + _Complex float* c64; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 57b2c0f32b..62a0b1ff08 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -359,10 +359,13 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteBool: *bytes = sizeof(bool) * count; break; + case kTfLiteComplex64: + *bytes = sizeof(std::complex<float>) * count; + break; default: ReportError(&context_, - "Only float32, int16, int32, int64, uint8, bool supported " - "currently."); + "Only float32, int16, int32, int64, uint8, bool, complex64 " + "supported currently."); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index e67543671b..033b8ee5fa 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -17,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ #define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#include <complex> #include <cstdio> #include <cstdlib> #include <vector> @@ -58,6 +59,10 @@ template <> constexpr TfLiteType typeToTfLiteType<bool>() { return kTfLiteBool; } +template <> +constexpr TfLiteType typeToTfLiteType<std::complex<float>>() { + return kTfLiteComplex64; +} // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 60770ca0aa..8dd48af57f 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <algorithm> +#include <complex> #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" @@ -53,6 +54,20 @@ void copyCast(const FromT* in, ToT* out, int num_elements) { [](FromT a) { return static_cast<ToT>(a); }); } +template <typename ToT> +void copyCast(const std::complex<float>* in, ToT* out, int num_elements) { + std::transform(in, in + num_elements, out, [](std::complex<float> a) { + return static_cast<ToT>(std::real(a)); + }); +} + +template <> +void copyCast(const std::complex<float>* in, std::complex<float>* out, + int num_elements) { + std::transform(in, in + num_elements, out, + [](std::complex<float> a) { return a; }); +} + template <typename FromT> TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, int num_elements) { @@ -72,6 +87,10 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, case kTfLiteBool: copyCast(in, out->data.b, num_elements); break; + case kTfLiteComplex64: + copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64), + num_elements); + break; default: // Unsupported type. return kTfLiteError; @@ -95,6 +114,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return copyToTensor(input->data.f, output, num_elements); case kTfLiteBool: return copyToTensor(input->data.b, output, num_elements); + case kTfLiteComplex64: + return copyToTensor( + reinterpret_cast<std::complex<float>*>(input->data.c64), output, + num_elements); default: // Unsupported type. return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc index 53e2000737..954f998206 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <complex> + #include <gtest/gtest.h> #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" @@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) { ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); } +TEST(CastOpModel, CastComplex64ToFloat) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.PopulateTensor<std::complex<float>>( + m.input(), + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<float>(m.output()), + ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); +} + +TEST(CastOpModel, CastFloatToComplex64) { + CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector<std::complex<float>>(m.output()), + ElementsAreArray( + {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f), + std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f), + std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToInt) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}}); + m.PopulateTensor<std::complex<float>>( + m.input(), + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector<int>(m.output()), + ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(CastOpModel, CastIntToComplex64) { + CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector<std::complex<float>>(m.output()), + ElementsAreArray( + {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f), + std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f), + std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToComplex64) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor<std::complex<float>>( + m.input(), + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector<std::complex<float>>(m.output()), + ElementsAreArray( + {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f), + std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f), + std::complex<float>(5.0f, 15.0f), + std::complex<float>(6.0f, 16.0f)})); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index 518bee1c63..ee2af5b460 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#include <complex> #include <vector> #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" @@ -54,6 +55,13 @@ inline bool* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast<std::complex<float>*>(tensor->data.c64) + : nullptr; +} + template <typename T> inline const T* GetTensorData(const TfLiteTensor* tensor); @@ -87,6 +95,13 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast<const std::complex<float>*>(tensor->data.c64) + : nullptr; +} + inline int RemapDim(int max_dimensions, int d) { return max_dimensions - d - 1; } diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 793a72272d..f54db3af87 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -63,6 +63,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_BOOL: *type = kTfLiteBool; break; + case TensorType_COMPLEX64: + *type = kTfLiteComplex64; + break; default: error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", EnumNameTensorType(tensor_type), tensor_type); diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc index 99c35b9caf..f1f025f777 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -52,6 +52,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteBool"; case kTfLiteInt16: return "kTfLiteInt16"; + case kTfLiteComplex64: + return "kTfLiteComplex64"; } return "(invalid)"; } diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index b283551c45..5554d08fa0 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -92,6 +92,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_OBJECT; case kTfLiteBool: return NPY_BOOL; + case kTfLiteComplex64: + return NPY_COMPLEX64; case kTfLiteNoType: return -1; } @@ -118,6 +120,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { case NPY_STRING: case NPY_UNICODE: return kTfLiteString; + case NPY_COMPLEX64: + return kTfLiteComplex64; } LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; return kTfLiteNoType; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 76ad3ef893..15fb8bbdb8 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -35,6 +35,7 @@ enum TensorType : byte { STRING = 5, BOOL = 6, INT16 = 7, + COMPLEX64 = 8, } // Parameters for converting a quantized tensor back to float. Given a diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index e3ce90aa55..fe0ff9a7a5 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -223,11 +223,12 @@ enum TensorType { TensorType_STRING = 5, TensorType_BOOL = 6, TensorType_INT16 = 7, + TensorType_COMPLEX64 = 8, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_INT16 + TensorType_MAX = TensorType_COMPLEX64 }; -inline TensorType (&EnumValuesTensorType())[8] { +inline TensorType (&EnumValuesTensorType())[9] { static TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -236,7 +237,8 @@ inline TensorType (&EnumValuesTensorType())[8] { TensorType_INT64, TensorType_STRING, TensorType_BOOL, - TensorType_INT16 + TensorType_INT16, + TensorType_COMPLEX64 }; return values; } @@ -251,6 +253,7 @@ inline const char **EnumNamesTensorType() { "STRING", "BOOL", "INT16", + "COMPLEX64", nullptr }; return names; diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index aa05a9bd0e..abe0bf3c54 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ #define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#include <complex> #include <functional> #include <initializer_list> #include <memory> @@ -161,15 +162,16 @@ enum class AxesOrder { // The type of the scalars in an array. // Note that the type does not by itself tell whether the values in the array -// are real (are literally interpreted as real numbers) or quantized (only -// acquire a meaning as real numbers in conjunction with QuantizationParams). +// are non-quantized (can be accessed directly) or quantized (must be +// interpreted in conjunction with QuantizationParams). // // In practice though: -// float values are always real +// float values are never quantized // uint8 values are always quantized -// int32 values are either real or quantized (depending on whether +// int32 values are sometimes quantized (depending on whether // QuantizationParams are present). -// other types are unused at the moment. +// complex values are never quantized +// other types are never quantized at the moment. // // kNone means that we don't know the data type yet, or that we don't care // because we'll be dropping the array anyway (e.g. some exotic array types @@ -187,7 +189,8 @@ enum class ArrayDataType : uint8 { kUint32, kInt64, kUint64, // 10 - kString + kString, + kComplex64, }; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type @@ -241,6 +244,10 @@ template <> struct DataTypeImpl<ArrayDataType::kString> { typedef string Type; }; +template <> +struct DataTypeImpl<ArrayDataType::kComplex64> { + typedef std::complex<float> Type; +}; template <ArrayDataType A> using DataType = typename DataTypeImpl<A>::Type; diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index 42c5d7e8eb..754f0b4b8c 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -100,6 +100,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_STRING; case ArrayDataType::kBool: return ::tflite::TensorType_BOOL; + case ArrayDataType::kComplex64: + return ::tflite::TensorType_COMPLEX64; default: // FLOAT32 is filled for unknown data types. // TODO(ycling): Implement type inference in TF Lite interpreter. @@ -123,6 +125,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kUint8; case ::tflite::TensorType_BOOL: return ArrayDataType::kBool; + case ::tflite::TensorType_COMPLEX64: + return ArrayDataType::kComplex64; default: LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; } @@ -147,6 +151,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize( return CopyBuffer<ArrayDataType::kUint8>(array, builder); case ArrayDataType::kBool: return CopyBoolToBuffer(array, builder); + case ArrayDataType::kComplex64: + return CopyBuffer<ArrayDataType::kComplex64>(array, builder); default: LOG(FATAL) << "Unhandled array data type."; } @@ -172,6 +178,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyBuffer<ArrayDataType::kUint8>(buffer, array); case ::tflite::TensorType_BOOL: return CopyBuffer<ArrayDataType::kBool>(buffer, array); + case ::tflite::TensorType_COMPLEX64: + return CopyBuffer<ArrayDataType::kComplex64>(buffer, array); default: LOG(FATAL) << "Unhandled tensor type."; } diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index 8c6ef95bfa..8e9f30ba3a 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/types.h" +#include <complex> + #include <gmock/gmock.h> #include <gtest/gtest.h> @@ -71,7 +73,8 @@ TEST(DataType, SupportedTypes) { {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}, - {ArrayDataType::kBool, ::tflite::TensorType_BOOL}}; + {ArrayDataType::kBool, ::tflite::TensorType_BOOL}, + {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}}; for (auto x : testdata) { EXPECT_EQ(x.second, DataType::Serialize(x.first)); EXPECT_EQ(x.first, DataType::Deserialize(x.second)); @@ -171,6 +174,14 @@ TEST(DataBuffer, Bool) { ::testing::ElementsAre(true, false, true)); } +TEST(DataBuffer, Complex64) { + Array recovered = ToFlatBufferAndBack<ArrayDataType::kComplex64>( + {std::complex<float>(1.0f, 2.0f), std::complex<float>(3.0f, 4.0f)}); + EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kComplex64>().data, + ::testing::ElementsAre(std::complex<float>(1.0f, 2.0f), + std::complex<float>(3.0f, 4.0f))); +} + TEST(Padding, All) { EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); |