aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc44
1 files changed, 44 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 32f22e1ea0..6b195cc992 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
@@ -2002,6 +2003,48 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
return tensorflow::Status::OK();
}
+// This isn't a TensorFlow builtin op. Currently this node can only be generated
+// with TfLite OpHint API.
+tensorflow::Status ConvertUnidirectionalSequenceLstm(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
+
+ auto* op = new UnidirectionalSequenceLstmOperator();
+ const auto& indices = GetListAttr(node, "_tflite_input_indices");
+ if (indices.i_size() != node.input().size()) {
+ return tensorflow::errors::InvalidArgument("Input size does not match.");
+ }
+
+ // The input size needs to be the same as the TfLite UniDirectionalSequence
+ // Lstm implementation.
+ const int kInputsSize = 20;
+
+ op->inputs.resize(kInputsSize);
+ std::vector<bool> done(kInputsSize);
+ int idx = 0;
+ for (const string& input : node.input()) {
+ int real_index = indices.i(idx);
+ op->inputs[real_index] = (input);
+ done[real_index] = true;
+ idx++;
+ }
+
+ for (int idx = 0; idx < done.size(); idx++) {
+ if (!done[idx]) {
+ string optional_name = node.name() + "_" + std::to_string(idx);
+ model->CreateOptionalArray(optional_name);
+ op->inputs[idx] = optional_name;
+ }
+ }
+
+ // There're three outputs, only the last one is required.
+ op->outputs.push_back(node.name() + ":2");
+ model->operators.emplace_back(op);
+
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -2121,6 +2164,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
{"Unpack", ConvertUnpackOperator},
{"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
+ {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
});
}