diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 4b2ef756cc..9380168f30 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1053,6 +1053,23 @@ class Shape int GetVersion(const Operator& op) const override { return 1; } }; +class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions, + ::tflite::BuiltinOptions_OneHotOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateOneHotOptions(*builder, op.axis); + } + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -1278,6 +1295,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { OperatorType::kFakeQuant)); ops.emplace_back( new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.emplace_back( + new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); // Custom Operators. ops.emplace_back( @@ -1331,6 +1350,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back( new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice)); ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow)); + ops.emplace_back(new SimpleOperator<LogicalOrOperator>( + "LOGICAL_OR", OperatorType::kLogicalOr)); // Element-wise operator ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin)); ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog)); |