aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite/operator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc31
1 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 84a5410839..a8518adefc 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -626,11 +626,21 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
+ ::tflite::LSTMKernelType kernel_type;
+ switch (op.kernel_type) {
+ case LstmCellOperator::KERNEL_BASIC:
+ kernel_type = ::tflite::LSTMKernelType_BASIC;
+ break;
+ case LstmCellOperator::KERNEL_FULL:
+ kernel_type = ::tflite::LSTMKernelType_FULL;
+ break;
+ }
+
// Current toco converter only supports tanh, no clip.
return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
::tflite::ActivationFunctionType_TANH,
/*cell_clip=*/0.0,
- /*proj_clip=*/0.0);
+ /*proj_clip=*/0.0, kernel_type);
}
void ReadOptions(const TfLiteOptions& options,
@@ -638,9 +648,26 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
// Only support tanh activation, so check that tflite type is tanh.
CHECK(options.fused_activation_function() ==
::tflite::ActivationFunctionType_TANH);
+
+ switch (options.kernel_type()) {
+ case ::tflite::LSTMKernelType_BASIC:
+ op->kernel_type = LstmCellOperator::KERNEL_BASIC;
+ break;
+ case ::tflite::LSTMKernelType_FULL:
+ op->kernel_type = LstmCellOperator::KERNEL_FULL;
+ break;
+ }
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+ switch (lstm_op.kernel_type) {
+ case LstmCellOperator::KERNEL_FULL:
+ return 1;
+ case LstmCellOperator::KERNEL_BASIC:
+ return 2;
+ }
+ }
};
class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,