diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 79 |
1 files changed, 74 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 8377ba6a03..4b2ef756cc 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -290,8 +290,8 @@ class FakeQuant flatbuffers::Offset<TfLiteOptions> WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateFakeQuantOptions(*builder, op.minmax->min, - op.minmax->max, op.num_bits); + return ::tflite::CreateFakeQuantOptions( + *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { @@ -300,9 +300,13 @@ class FakeQuant minmax->max = options.max(); op->minmax.reset(minmax); op->num_bits = options.num_bits(); + op->narrow_range = options.narrow_range(); } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& fq_op = static_cast<const FakeQuantOperator&>(op); + return fq_op.narrow_range ? 2 : 1; + } }; class FullyConnected @@ -366,12 +370,13 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions, flatbuffers::Offset<TfLiteOptions> WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateGatherOptions(*builder, op.axis); + int axis = op.axis ? op.axis.value() : 0; + return ::tflite::CreateGatherOptions(*builder, axis); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->axis = options.axis(); + op->axis = {options.axis()}; } int GetVersion(const Operator& op) const override { return 1; } @@ -763,6 +768,44 @@ class Sum int GetVersion(const Operator& op) const override { return 1; } }; +class ReduceMax + : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, + ::tflite::BuiltinOptions_ReducerOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ReduceProd + : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, + ::tflite::BuiltinOptions_ReducerOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class ResizeBilinear : public BuiltinOperator<ResizeBilinearOperator, ::tflite::ResizeBilinearOptions, @@ -970,6 +1013,26 @@ class ExpandDims int GetVersion(const Operator& op) const override { return 1; } }; +class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions, + ::tflite::BuiltinOptions_PackOptions> { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->values_count = options.values_count(); + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class Shape : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions, ::tflite::BuiltinOptions_ShapeOptions> { @@ -1179,6 +1242,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); + ops.emplace_back(new ReduceProd(::tflite::BuiltinOperator_REDUCE_PROD, + OperatorType::kReduceProd)); + ops.emplace_back(new ReduceMax(::tflite::BuiltinOperator_REDUCE_MAX, + OperatorType::kReduceMax)); ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); ops.emplace_back( @@ -1209,6 +1276,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); ops.emplace_back(new FakeQuant(::tflite::BuiltinOperator_FAKE_QUANT, OperatorType::kFakeQuant)); + ops.emplace_back( + new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); // Custom Operators. ops.emplace_back( |