Diffstat (limited to 'tensorflow/contrib/lite/toco')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc |  6
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc      | 47
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc                                 | 44
-rw-r--r--  tensorflow/contrib/lite/toco/model.h                                              |  6
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc                                   | 39
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc                                      |  5
6 files changed, 145 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 40cd6dea82..47faa20a29 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -239,6 +239,12 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
}
break;
}
+ case OperatorType::kUnidirectionalSequenceLstm: {
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+ if (data_type != ArrayDataType::kFloat) return ::tensorflow::Status::OK();
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
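The new case above only fires for float models: if input 0 is not float, the transformation returns without touching the outputs; otherwise every output is stamped with the input's type via SetDataTypeForAllOutputs. A self-contained sketch of that pattern, using hypothetical stand-ins for toco's Model and Operator types:

    #include <cassert>
    #include <vector>

    // Hypothetical mini-types; toco's real Model/Operator are far richer.
    enum class ArrayDataType { kNone, kFloat, kInt32 };

    struct MiniOp {
      std::vector<ArrayDataType> input_types;
      std::vector<ArrayDataType> output_types;
    };

    bool PropagateUniLstmTypes(MiniOp* op) {
      const ArrayDataType data_type = op->input_types[0];
      if (data_type != ArrayDataType::kFloat) return false;  // yield: unsupported
      for (auto& out : op->output_types) out = data_type;    // stamp all outputs
      return true;
    }

    int main() {
      MiniOp op{{ArrayDataType::kFloat}, {ArrayDataType::kNone}};
      assert(PropagateUniLstmTypes(&op));
      assert(op.output_types[0] == ArrayDataType::kFloat);
    }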
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 5496e2093e..e861df2b3d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
.copy_shape(activ_temp_shape);
}
+void ProcessUnidirectionalSequenceLstmOperator(
+ Model* model, UnidirectionalSequenceLstmOperator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
+
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return;
+ }
+
+ // TODO(renjieliu): check the inputs, as well as all kinds of weights.
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ const int batch_size = input_shape.dims(1);
+ const int timesteps = input_shape.dims(0);
+
+ const auto& recurrent_to_output_weights_array =
+ model->GetArray(op->inputs[8]);
+ // Yield until input dims have been resolved.
+ if (!recurrent_to_output_weights_array.has_shape()) {
+ return;
+ }
+
+ constexpr int kInputActivationStateTensor = 18;
+ constexpr int kInputCellStateTensor = 19;
+ // b/115961645: Hack: reset the state inputs' buffers so they are not treated as constants.
+ model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
+ model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
+
+ const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
+ const int output_size = output_weights_shape.dims(1);
+
+ Shape* output_shape = output_array.mutable_shape();
+ output_shape->ReplaceDims({timesteps, batch_size, output_size});
+}
+
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
ProcessResizeBilinearOperator(model,
static_cast<ResizeBilinearOperator*>(op));
break;
+ case OperatorType::kUnidirectionalSequenceLstm:
+ ProcessUnidirectionalSequenceLstmOperator(
+ model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
+ break;
case OperatorType::kLstmCell:
ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
break;
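The core of the new shape function is one rule: output shape = {timesteps, batch_size, output_size}, with the first two dims taken from the time-major input and output_size read from dim 1 of the recurrent-to-output weights (input 8). A self-contained sketch of the same arithmetic (UniLstmOutputShape and the sample values are hypothetical):

    #include <array>
    #include <cassert>

    std::array<int, 3> UniLstmOutputShape(
        const std::array<int, 3>& input_shape,          // {timesteps, batch, input_size}
        const std::array<int, 2>& recurrent_to_output)  // {num_cells, output_size}
    {
      return {input_shape[0], input_shape[1], recurrent_to_output[1]};
    }

    int main() {
      // e.g. 28 time steps, batch of 2, 10 features in, 16 output units:
      const auto out = UniLstmOutputShape({28, 2, 10}, {16, 16});
      assert(out == (std::array<int, 3>{28, 2, 16}));
    }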
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 32f22e1ea0..6b195cc992 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
@@ -2002,6 +2003,48 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
return tensorflow::Status::OK();
}
+// This isn't a TensorFlow builtin op; currently this node can only be
+// generated with the TfLite OpHint API.
+tensorflow::Status ConvertUnidirectionalSequenceLstm(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
+
+ auto* op = new UnidirectionalSequenceLstmOperator();
+ const auto& indices = GetListAttr(node, "_tflite_input_indices");
+ if (indices.i_size() != node.input().size()) {
+ return tensorflow::errors::InvalidArgument("Input size does not match.");
+ }
+
+ // The input size needs to match the TfLite UnidirectionalSequenceLstm
+ // implementation, which expects exactly 20 input tensors.
+ constexpr int kInputsSize = 20;
+
+ op->inputs.resize(kInputsSize);
+ std::vector<bool> done(kInputsSize);
+ int idx = 0;
+ for (const string& input : node.input()) {
+ int real_index = indices.i(idx);
+ op->inputs[real_index] = input;
+ done[real_index] = true;
+ idx++;
+ }
+
+ for (int i = 0; i < kInputsSize; i++) {
+ if (!done[i]) {
+ string optional_name = node.name() + "_" + std::to_string(i);
+ model->CreateOptionalArray(optional_name);
+ op->inputs[i] = optional_name;
+ }
+ }
+
+ // There are three outputs; only the last one is required.
+ op->outputs.push_back(node.name() + ":2");
+ model->operators.emplace_back(op);
+
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -2121,6 +2164,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
{"Unpack", ConvertUnpackOperator},
{"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
+ {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
});
}
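The import path above scatters the node's real inputs into the fixed 20-slot layout dictated by the "_tflite_input_indices" attribute and backfills every unused slot with a generated optional-array name. The same logic in a self-contained form (ScatterLstmInputs is a hypothetical stand-in, not toco API):

    #include <string>
    #include <vector>

    std::vector<std::string> ScatterLstmInputs(
        const std::vector<std::string>& real_inputs,
        const std::vector<int>& indices, const std::string& node_name) {
      constexpr int kInputsSize = 20;  // fixed TFLite LSTM input count, as in the diff
      std::vector<std::string> inputs(kInputsSize);
      std::vector<bool> done(kInputsSize, false);
      for (size_t i = 0; i < real_inputs.size(); ++i) {
        inputs[indices[i]] = real_inputs[i];  // scatter by "_tflite_input_indices"
        done[indices[i]] = true;
      }
      for (int i = 0; i < kInputsSize; ++i) {
        // Unused slots become named optional arrays, e.g. "lstm_3" for slot 3.
        if (!done[i]) inputs[i] = node_name + "_" + std::to_string(i);
      }
      return inputs;
    }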
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 61f1f095e9..f3b84430db 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -58,6 +58,7 @@ enum class OperatorType : uint8 {
kL2Normalization,
kL2Pool,
kLstmCell,
+ kUnidirectionalSequenceLstm,
kLocalResponseNormalization,
kLog,
kLogistic,
@@ -635,6 +636,11 @@ struct LstmCellOperator : Operator {
KernelType kernel_type;
};
+struct UnidirectionalSequenceLstmOperator : Operator {
+ UnidirectionalSequenceLstmOperator()
+ : Operator(OperatorType::kUnidirectionalSequenceLstm) {}
+};
+
// Element-wise multiplication operator.
//
// Inputs:
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index ed37535fe0..e08a61d357 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -741,6 +741,42 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
}
};
+class UnidirectionalSequenceLstm
+ : public BuiltinOperator<
+ UnidirectionalSequenceLstmOperator,
+ ::tflite::UnidirectionalSequenceLSTMOptions,
+ ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ // The current toco converter only supports tanh activation and no clipping.
+ return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
+ *builder, /*fused_activation_function=*/
+ ::tflite::ActivationFunctionType_TANH,
+ /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ // Only the tanh activation is supported, so check that the tflite type is tanh.
+ DCHECK(options.fused_activation_function() ==
+ ::tflite::ActivationFunctionType_TANH);
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+
+ std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const override {
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
+ mutating_input_variables[kInputActivationStateTensor] = true;
+ mutating_input_variables[kInputCellStateTensor] = true;
+ return mutating_input_variables;
+ }
+};
+
class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
@@ -1435,6 +1471,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
OperatorType::kFakeQuant));
ops.push_back(
MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
+ ops.push_back(MakeUnique<UnidirectionalSequenceLstm>(
+ ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ OperatorType::kUnidirectionalSequenceLstm));
ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
OperatorType::kOneHot));
ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
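The notable override here is GetMutatingInputVariables: slots 18 and 19 (the activation- and cell-state inputs, matching the constants used in propagate_fixed_sizes.cc) are updated in place by the kernel across invocations, so they are flagged as mutable while all weight and bias inputs stay immutable. A trivial self-contained check of the resulting flag vector (the slot numbers are taken from this diff):

    #include <cassert>
    #include <vector>

    int main() {
      constexpr int kInputActivationStateTensor = 18;  // slot assumed from the diff
      constexpr int kInputCellStateTensor = 19;        // slot assumed from the diff
      std::vector<bool> mutating(20, false);
      mutating[kInputActivationStateTensor] = true;
      mutating[kInputCellStateTensor] = true;
      // Every weight/bias slot stays immutable; only the two state slots mutate.
      for (int i = 0; i < 18; ++i) assert(!mutating[i]);
      assert(mutating[18] && mutating[19]);
    }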
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 083a96ad9d..61aa311212 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -407,6 +407,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
HANDLE_OPERATORTYPENAME_CASE(Unpack)
HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
+ HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -898,12 +899,12 @@ void CheckNoMissingArray(const Model& model) {
void FixNoMissingArray(Model* model) {
for (const auto& op : model->operators) {
for (const auto& input : op->inputs) {
- if (!model->HasArray(input)) {
+ if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
model->GetOrCreateArray(input);
}
}
for (const auto& output : op->outputs) {
- if (!model->HasArray(output)) {
+ if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
model->GetOrCreateArray(output);
}
}
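The tooling_util change closes the loop: the optional arrays created during import have no backing data by design, so FixNoMissingArray must not materialize them as real arrays. A reduced sketch of the guard's behavior, using a hypothetical MiniModel in place of toco's Model:

    #include <set>
    #include <string>

    // Hypothetical minimal model: the real toco Model tracks optional arrays
    // registered via CreateOptionalArray.
    struct MiniModel {
      std::set<std::string> arrays;
      std::set<std::string> optional_arrays;
      bool HasArray(const std::string& n) const { return arrays.count(n) > 0; }
      bool IsOptionalArray(const std::string& n) const {
        return optional_arrays.count(n) > 0;
      }
      void GetOrCreateArray(const std::string& n) { arrays.insert(n); }
    };

    void FixNoMissingArray(MiniModel* model, const std::string& input) {
      // Mirrors the guard in the diff: only create a backing array for names
      // that are neither present nor registered as optional.
      if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
        model->GetOrCreateArray(input);
      }
    }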