diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco')
6 files changed, 145 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 40cd6dea82..47faa20a29 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -239,6 +239,12 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, } break; } + case OperatorType::kUnidirectionalSequenceLstm: { + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + if (data_type != ArrayDataType::kFloat) return ::tensorflow::Status::OK(); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 5496e2093e..e861df2b3d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { .copy_shape(activ_temp_shape); } +void ProcessUnidirectionalSequenceLstmOperator( + Model* model, UnidirectionalSequenceLstmOperator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + + // TODO(renjieliu): check the inputs, as well as all kinds of weights. + const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const int batch_size = input_shape.dims(1); + const int timestamp = input_shape.dims(0); + + const auto& recurrent_to_output_weights_array = + model->GetArray(op->inputs[8]); + // Yield until input dims have been resolved. + if (!recurrent_to_output_weights_array.has_shape()) { + return; + } + + constexpr int kInputActivationStateTensor = 18; + constexpr int kInputCellStateTensor = 19; + // b(115961645): This is a hack to work around. + model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset(); + + const auto& output_weights_shape = recurrent_to_output_weights_array.shape(); + const int output_size = output_weights_shape.dims(1); + + Shape* output_shape = output_array.mutable_shape(); + output_shape->ReplaceDims({timestamp, batch_size, output_size}); +} + void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) { ProcessResizeBilinearOperator(model, static_cast<ResizeBilinearOperator*>(op)); break; + case OperatorType::kUnidirectionalSequenceLstm: + ProcessUnidirectionalSequenceLstmOperator( + model, static_cast<UnidirectionalSequenceLstmOperator*>(op)); + break; case OperatorType::kLstmCell: ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op)); break; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 32f22e1ea0..6b195cc992 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" @@ -2002,6 +2003,48 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator( return tensorflow::Status::OK(); } +// This isn't a TensorFlow builtin op. Currently this node can only be generated +// with TfLite OpHint API. +tensorflow::Status ConvertUnidirectionalSequenceLstm( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm"); + + auto* op = new UnidirectionalSequenceLstmOperator(); + const auto& indices = GetListAttr(node, "_tflite_input_indices"); + if (indices.i_size() != node.input().size()) { + return tensorflow::errors::InvalidArgument("Input size does not match."); + } + + // The input size needs to be the same as the TfLite UniDirectionalSequence + // Lstm implementation. + const int kInputsSize = 20; + + op->inputs.resize(kInputsSize); + std::vector<bool> done(kInputsSize); + int idx = 0; + for (const string& input : node.input()) { + int real_index = indices.i(idx); + op->inputs[real_index] = (input); + done[real_index] = true; + idx++; + } + + for (int idx = 0; idx < done.size(); idx++) { + if (!done[idx]) { + string optional_name = node.name() + "_" + std::to_string(idx); + model->CreateOptionalArray(optional_name); + op->inputs[idx] = optional_name; + } + } + + // There're three outputs, only the last one is required. + op->outputs.push_back(node.name() + ":2"); + model->operators.emplace_back(op); + + return tensorflow::Status::OK(); +} + } // namespace namespace internal { @@ -2121,6 +2164,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Transpose", ConvertSimpleOperator<TransposeOperator, 2>}, {"Unpack", ConvertUnpackOperator}, {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>}, + {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm}, }); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 61f1f095e9..f3b84430db 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -58,6 +58,7 @@ enum class OperatorType : uint8 { kL2Normalization, kL2Pool, kLstmCell, + kUnidirectionalSequenceLstm, kLocalResponseNormalization, kLog, kLogistic, @@ -635,6 +636,11 @@ struct LstmCellOperator : Operator { KernelType kernel_type; }; +struct UnidirectionalSequenceLstmOperator : Operator { + UnidirectionalSequenceLstmOperator() + : Operator(OperatorType::kUnidirectionalSequenceLstm) {} +}; + // Element-wise multiplication operator. // // Inputs: diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index ed37535fe0..e08a61d357 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -741,6 +741,42 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, } }; +class UnidirectionalSequenceLstm + : public BuiltinOperator< + UnidirectionalSequenceLstmOperator, + ::tflite::UnidirectionalSequenceLSTMOptions, + ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateUnidirectionalSequenceLSTMOptions( + *builder, /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*cell_clip=*/0.0, + /*proj_clip=*/0.0); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + // Only support tanh activation, so check that tflite type is tanh. + DCHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); + } + + int GetVersion(const Operator& op) const override { return 1; } + + std::vector<bool> GetMutatingInputVariables( + const Operator& op) const override { + std::vector<bool> mutating_input_variables(op.inputs.size(), false); + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + return mutating_input_variables; + } +}; + class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions, ::tflite::BuiltinOptions_ReducerOptions> { public: @@ -1435,6 +1471,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( OperatorType::kFakeQuant)); ops.push_back( MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>( + ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + OperatorType::kUnidirectionalSequenceLstm)); ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK, diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 083a96ad9d..61aa311212 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -407,6 +407,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) HANDLE_OPERATORTYPENAME_CASE(Unpack) HANDLE_OPERATORTYPENAME_CASE(ZerosLike) + HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -898,12 +899,12 @@ void CheckNoMissingArray(const Model& model) { void FixNoMissingArray(Model* model) { for (const auto& op : model->operators) { for (const auto& input : op->inputs) { - if (!model->HasArray(input)) { + if (!model->HasArray(input) && !model->IsOptionalArray(input)) { model->GetOrCreateArray(input); } } for (const auto& output : op->outputs) { - if (!model->HasArray(output)) { + if (!model->HasArray(output) && !model->IsOptionalArray(output)) { model->GetOrCreateArray(output); } } |