aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite/operator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 4b2ef756cc..9380168f30 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1053,6 +1053,23 @@ class Shape
int GetVersion(const Operator& op) const override { return 1; }
};
+class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
+ ::tflite::BuiltinOptions_OneHotOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateOneHotOptions(*builder, op.axis);
+ }
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->axis = options.axis();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -1278,6 +1295,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kFakeQuant));
ops.emplace_back(
new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
+ ops.emplace_back(
+ new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot));
// Custom Operators.
ops.emplace_back(
@@ -1331,6 +1350,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow));
+ ops.emplace_back(new SimpleOperator<LogicalOrOperator>(
+ "LOGICAL_OR", OperatorType::kLogicalOr));
// Element-wise operator
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));