diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 32f22e1ea0..6b195cc992 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" @@ -2002,6 +2003,48 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator( return tensorflow::Status::OK(); } +// This isn't a TensorFlow builtin op. Currently this node can only be generated +// with TfLite OpHint API. +tensorflow::Status ConvertUnidirectionalSequenceLstm( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm"); + + auto* op = new UnidirectionalSequenceLstmOperator(); + const auto& indices = GetListAttr(node, "_tflite_input_indices"); + if (indices.i_size() != node.input().size()) { + return tensorflow::errors::InvalidArgument("Input size does not match."); + } + + // The input size needs to be the same as the TfLite UniDirectionalSequence + // Lstm implementation. + const int kInputsSize = 20; + + op->inputs.resize(kInputsSize); + std::vector<bool> done(kInputsSize); + int idx = 0; + for (const string& input : node.input()) { + int real_index = indices.i(idx); + op->inputs[real_index] = (input); + done[real_index] = true; + idx++; + } + + for (int idx = 0; idx < done.size(); idx++) { + if (!done[idx]) { + string optional_name = node.name() + "_" + std::to_string(idx); + model->CreateOptionalArray(optional_name); + op->inputs[idx] = optional_name; + } + } + + // There're three outputs, only the last one is required. + op->outputs.push_back(node.name() + ":2"); + model->operators.emplace_back(op); + + return tensorflow::Status::OK(); +} + } // namespace namespace internal { @@ -2121,6 +2164,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Transpose", ConvertSimpleOperator<TransposeOperator, 2>}, {"Unpack", ConvertUnpackOperator}, {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>}, + {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm}, }); } |