diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-04 07:45:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-04 07:50:20 -0700 |
commit | 3fbec92fa4997fc834807b57b229fa9ada179f6c (patch) | |
tree | 34a7c15af0428e5ecf2ee2deb16e7c26b3d4320b | |
parent | 71b19430e8484b136e0b872f6a543aff8a242587 (diff) |
Turn Cast into a proper builtin operator.
PiperOrigin-RevId: 191590230
-rw-r--r-- | tensorflow/contrib/lite/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/builtin_op_data.h | 7 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/cast.cc | 9 | ||||
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 70 | ||||
-rw-r--r-- | tensorflow/contrib/lite/schema/schema.fbs | 2 | ||||
-rwxr-xr-x | tensorflow/contrib/lite/schema/schema_generated.h | 38 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 26 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator_test.cc | 2 |
8 files changed, 117 insertions, 38 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index ac269d540a..9c4533079c 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -89,6 +89,7 @@ cc_library( hdrs = [ "builtin_op_data.h", ], + deps = [":context"], ) cc_library( diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 5fc8954743..2b6c24768c 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -17,6 +17,8 @@ limitations under the License. #include <stdint.h> +#include "tensorflow/contrib/lite/context.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -174,6 +176,11 @@ typedef struct { int block_size; } TfLiteSpaceToDepthParams; +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + typedef enum { kTfLiteCombinerTypeSum = 0, kTfLiteCombinerTypeMean = 1, diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 19942de7bc..17ef2c572e 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -34,6 +34,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): these two checks would make the new implementation + // incompatible with some existing models, where params is not specified. It + // is OK not to have them because toco would have set input and output types + // to match the parameters. + // auto* params = reinterpret_cast<TfLiteCastParams*>(node->builtin_data); + // TF_LITE_ENSURE_EQ(context, input->type, params->in_data_type); + // TF_LITE_ENSURE_EQ(context, output->type, params->out_data_type); + return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims)); } diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 791d1378f3..606f4a5635 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -32,6 +32,32 @@ namespace tflite { const char* kEmptyTensorName = ""; +TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, + ErrorReporter* error_reporter) { + switch (tensor_type) { + case TensorType_FLOAT32: + *type = kTfLiteFloat32; + break; + case TensorType_INT32: + *type = kTfLiteInt32; + break; + case TensorType_UINT8: + *type = kTfLiteUInt8; + break; + case TensorType_INT64: + *type = kTfLiteInt64; + break; + case TensorType_STRING: + *type = kTfLiteString; + break; + default: + error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor_type), tensor_type); + return kTfLiteError; + } + return kTfLiteOk; +} + // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename, @@ -307,10 +333,25 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_EXP: case BuiltinOperator_TOPK_V2: case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_CAST: case BuiltinOperator_DEQUANTIZE: case BuiltinOperator_PRELU: break; + case BuiltinOperator_CAST: { + TfLiteCastParams* params = MallocPOD<TfLiteCastParams>(); + if (auto* schema_params = op->builtin_options_as_CastOptions()) { + auto in_status = + ConvertTensorType(schema_params->in_data_type(), + ¶ms->in_data_type, error_reporter); + auto out_status = + ConvertTensorType(schema_params->out_data_type(), + ¶ms->out_data_type, error_reporter); + if (in_status != kTfLiteOk || out_status != kTfLiteOk) { + break; + } + } + builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = MallocPOD<TfLiteLSHProjectionParams>(); @@ -707,29 +748,10 @@ TfLiteStatus InterpreterBuilder::ParseTensors( } TfLiteType type; - switch (tensor->type()) { - case TensorType_FLOAT32: - type = kTfLiteFloat32; - break; - case TensorType_INT32: - type = kTfLiteInt32; - break; - case TensorType_UINT8: - type = kTfLiteUInt8; - break; - case TensorType_INT64: - type = kTfLiteInt64; - break; - case TensorType_STRING: - type = kTfLiteString; - break; - default: - // tensorType = ArrayType::NONE; - error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n", - EnumNameTensorType(tensor->type()), - tensor->type()); - status = kTfLiteError; - continue; + if (ConvertTensorType(tensor->type(), &type, error_reporter_) != + kTfLiteOk) { + status = kTfLiteError; + continue; } auto get_readonly_data = [&](const char** buffer_data, size_t* buffer_size) { diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 7d2e00fe32..c63bfb28cc 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -381,6 +381,8 @@ table LogSoftmaxOptions { } table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; } table DequantizeOptions { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 66a97a1460..0735be5c8f 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -3702,14 +3702,30 @@ flatbuffers::Offset<LogSoftmaxOptions> CreateLogSoftmaxOptions(flatbuffers::Flat struct CastOptionsT : public flatbuffers::NativeTable { typedef CastOptions TableType; - CastOptionsT() { + TensorType in_data_type; + TensorType out_data_type; + CastOptionsT() + : in_data_type(TensorType_FLOAT32), + out_data_type(TensorType_FLOAT32) { } }; struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef CastOptionsT NativeTableType; + enum { + VT_IN_DATA_TYPE = 4, + VT_OUT_DATA_TYPE = 6 + }; + TensorType in_data_type() const { + return static_cast<TensorType>(GetField<int8_t>(VT_IN_DATA_TYPE, 0)); + } + TensorType out_data_type() const { + return static_cast<TensorType>(GetField<int8_t>(VT_OUT_DATA_TYPE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_IN_DATA_TYPE) && + VerifyField<int8_t>(verifier, VT_OUT_DATA_TYPE) && verifier.EndTable(); } CastOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -3720,6 +3736,12 @@ struct CastOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct CastOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; + void add_in_data_type(TensorType in_data_type) { + fbb_.AddElement<int8_t>(CastOptions::VT_IN_DATA_TYPE, static_cast<int8_t>(in_data_type), 0); + } + void add_out_data_type(TensorType out_data_type) { + fbb_.AddElement<int8_t>(CastOptions::VT_OUT_DATA_TYPE, static_cast<int8_t>(out_data_type), 0); + } explicit CastOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -3733,8 +3755,12 @@ struct CastOptionsBuilder { }; inline flatbuffers::Offset<CastOptions> CreateCastOptions( - flatbuffers::FlatBufferBuilder &_fbb) { + flatbuffers::FlatBufferBuilder &_fbb, + TensorType in_data_type = TensorType_FLOAT32, + TensorType out_data_type = TensorType_FLOAT32) { CastOptionsBuilder builder_(_fbb); + builder_.add_out_data_type(out_data_type); + builder_.add_in_data_type(in_data_type); return builder_.Finish(); } @@ -5727,6 +5753,8 @@ inline CastOptionsT *CastOptions::UnPack(const flatbuffers::resolver_function_t inline void CastOptions::UnPackTo(CastOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; + { auto _e = in_data_type(); _o->in_data_type = _e; }; + { auto _e = out_data_type(); _o->out_data_type = _e; }; } inline flatbuffers::Offset<CastOptions> CastOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5737,8 +5765,12 @@ inline flatbuffers::Offset<CastOptions> CreateCastOptions(flatbuffers::FlatBuffe (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CastOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _in_data_type = _o->in_data_type; + auto _out_data_type = _o->out_data_type; return tflite::CreateCastOptions( - _fbb); + _fbb, + _in_data_type, + _out_data_type); } inline DequantizeOptionsT *DequantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 0cb348bda5..f991529569 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -204,17 +204,22 @@ class BatchToSpaceND TocoOperator* op) const override {} }; -class Cast : public CustomOperator<CastOperator> { +class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions, + ::tflite::BuiltinOptions_CastOptions> { public: - using CustomOperator::CustomOperator; - void WriteOptions(const TocoOperator& op, - flexbuffers::Builder* fbb) const override { - fbb->Int("src_data_type", DataType::Serialize(op.src_data_type)); - fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type)); + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateCastOptions(*builder, + DataType::Serialize(op.src_data_type), + DataType::Serialize(op.dst_data_type)); } - void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { - op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64()); - op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64()); + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->src_data_type = DataType::Deserialize(options.in_data_type()); + op->dst_data_type = DataType::Deserialize(options.out_data_type()); } }; @@ -827,9 +832,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2)); ops.emplace_back( new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell)); + ops.emplace_back( + new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); // Custom Operators. - ops.emplace_back(new Cast("CAST", OperatorType::kCast)); ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index f7a213ecfc..4783843b7f 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -131,7 +131,7 @@ TEST_F(OperatorTest, BuiltinMean) { EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); } -TEST_F(OperatorTest, CustomCast) { +TEST_F(OperatorTest, BuiltinCast) { CastOperator op; op.src_data_type = ArrayDataType::kFloat; op.dst_data_type = ArrayDataType::kUint8; |