diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 125 |
1 files changed, 109 insertions, 16 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 7e55ae92bd..4b2ef756cc 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -282,25 +282,31 @@ class DepthToSpace : public CustomOperator<DepthToSpaceOperator> { int GetVersion(const Operator& op) const override { return 1; } }; -class FakeQuant : public CustomOperator<FakeQuantOperator> { +class FakeQuant + : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions, + ::tflite::BuiltinOptions_FakeQuantOptions> { public: - using CustomOperator::CustomOperator; - void WriteOptions(const TocoOperator& op, - flexbuffers::Builder* fbb) const override { - fbb->Float("min", op.minmax->min); - fbb->Float("max", op.minmax->max); - fbb->Int("num_bits", op.num_bits); + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateFakeQuantOptions( + *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range); } - void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { auto* minmax = new MinMax; - minmax->min = m["min"].AsFloat(); - minmax->max = m["max"].AsFloat(); + minmax->min = options.min(); + minmax->max = options.max(); op->minmax.reset(minmax); - const auto& num_bits = m["num_bits"]; - op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8; + op->num_bits = options.num_bits(); + op->narrow_range = options.narrow_range(); } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& fq_op = static_cast<const FakeQuantOperator&>(op); + return fq_op.narrow_range ? 2 : 1; + } }; class FullyConnected @@ -364,12 +370,13 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions, flatbuffers::Offset<TfLiteOptions> WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateGatherOptions(*builder, op.axis); + int axis = op.axis ? op.axis.value() : 0; + return ::tflite::CreateGatherOptions(*builder, axis); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->axis = options.axis(); + op->axis = {options.axis()}; } int GetVersion(const Operator& op) const override { return 1; } @@ -761,6 +768,44 @@ class Sum int GetVersion(const Operator& op) const override { return 1; } }; +class ReduceMax + : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, + ::tflite::BuiltinOptions_ReducerOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ReduceProd + : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, + ::tflite::BuiltinOptions_ReducerOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class ResizeBilinear : public BuiltinOperator<ResizeBilinearOperator, ::tflite::ResizeBilinearOptions, @@ -885,6 +930,25 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions, int GetVersion(const Operator& op) const override { return 1; } }; +class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions, + ::tflite::BuiltinOptions_ArgMinOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateArgMinOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.output_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TransposeConv : public BuiltinOperator<TransposeConvOperator, ::tflite::TransposeConvOptions, @@ -949,6 +1013,26 @@ class ExpandDims int GetVersion(const Operator& op) const override { return 1; } }; +class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions, + ::tflite::BuiltinOptions_PackOptions> { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->values_count = options.values_count(); + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class Shape : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions, ::tflite::BuiltinOptions_ShapeOptions> { @@ -1158,6 +1242,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); + ops.emplace_back(new ReduceProd(::tflite::BuiltinOperator_REDUCE_PROD, + OperatorType::kReduceProd)); + ops.emplace_back(new ReduceMax(::tflite::BuiltinOperator_REDUCE_MAX, + OperatorType::kReduceMax)); ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); ops.emplace_back( @@ -1175,6 +1263,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); ops.emplace_back( + new ArgMin(::tflite::BuiltinOperator_ARG_MIN, OperatorType::kArgMin)); + ops.emplace_back( new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, OperatorType::kExpandDims)); @@ -1184,11 +1274,14 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { OperatorType::kSparseToDense)); ops.emplace_back( new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); + ops.emplace_back(new FakeQuant(::tflite::BuiltinOperator_FAKE_QUANT, + OperatorType::kFakeQuant)); + ops.emplace_back( + new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); // Custom Operators. ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); - ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported)); |