aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite/operator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index ed37535fe0..e08a61d357 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -741,6 +741,42 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
}
};
+class UnidirectionalSequenceLstm
+ : public BuiltinOperator<
+ UnidirectionalSequenceLstmOperator,
+ ::tflite::UnidirectionalSequenceLSTMOptions,
+ ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ // Current toco converter only supports tanh, no clip.
+ return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
+ *builder, /*fused_activation_function=*/
+ ::tflite::ActivationFunctionType_TANH,
+ /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ // Only support tanh activation, so check that tflite type is tanh.
+ DCHECK(options.fused_activation_function() ==
+ ::tflite::ActivationFunctionType_TANH);
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+
+ std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const override {
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
+ mutating_input_variables[kInputActivationStateTensor] = true;
+ mutating_input_variables[kInputCellStateTensor] = true;
+ return mutating_input_variables;
+ }
+};
+
class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
@@ -1435,6 +1471,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
OperatorType::kFakeQuant));
ops.push_back(
MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
+ ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
+ ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ OperatorType::kUnidirectionalSequenceLstm));
ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
OperatorType::kOneHot));
ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,