aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite/operator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc79
1 files changed, 74 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 8377ba6a03..4b2ef756cc 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -290,8 +290,8 @@ class FakeQuant
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
- return ::tflite::CreateFakeQuantOptions(*builder, op.minmax->min,
- op.minmax->max, op.num_bits);
+ return ::tflite::CreateFakeQuantOptions(
+ *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
@@ -300,9 +300,13 @@ class FakeQuant
minmax->max = options.max();
op->minmax.reset(minmax);
op->num_bits = options.num_bits();
+ op->narrow_range = options.narrow_range();
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& fq_op = static_cast<const FakeQuantOperator&>(op);
+ return fq_op.narrow_range ? 2 : 1;
+ }
};
class FullyConnected
@@ -366,12 +370,13 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
- return ::tflite::CreateGatherOptions(*builder, op.axis);
+ int axis = op.axis ? op.axis.value() : 0;
+ return ::tflite::CreateGatherOptions(*builder, axis);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
- op->axis = options.axis();
+ op->axis = {options.axis()};
}
int GetVersion(const Operator& op) const override { return 1; }
@@ -763,6 +768,44 @@ class Sum
int GetVersion(const Operator& op) const override { return 1; }
};
+class ReduceMax
+ : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->keep_dims = options.keep_dims();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
+class ReduceProd
+ : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->keep_dims = options.keep_dims();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class ResizeBilinear
: public BuiltinOperator<ResizeBilinearOperator,
::tflite::ResizeBilinearOptions,
@@ -970,6 +1013,26 @@ class ExpandDims
int GetVersion(const Operator& op) const override { return 1; }
};
+class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
+ ::tflite::BuiltinOptions_PackOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->values_count = options.values_count();
+ op->axis = options.axis();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class Shape
: public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
::tflite::BuiltinOptions_ShapeOptions> {
@@ -1179,6 +1242,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
+ ops.emplace_back(new ReduceProd(::tflite::BuiltinOperator_REDUCE_PROD,
+ OperatorType::kReduceProd));
+ ops.emplace_back(new ReduceMax(::tflite::BuiltinOperator_REDUCE_MAX,
+ OperatorType::kReduceMax));
ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
OperatorType::kResizeBilinear));
ops.emplace_back(
@@ -1209,6 +1276,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
ops.emplace_back(new FakeQuant(::tflite::BuiltinOperator_FAKE_QUANT,
OperatorType::kFakeQuant));
+ ops.emplace_back(
+ new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
// Custom Operators.
ops.emplace_back(