diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-09 20:05:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 20:08:54 -0700 |
commit | 854ae599743a1e92a31ad49cfe42c6454cefd3b9 (patch) | |
tree | 1ff75695f61c5eb3353e739295e81f76bbe28a64 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 58fcfc98cd59ae3952399fc55380b8733df08df9 (diff) |
Use Ophints to support TfLite UnidirectionaSequenceLstm and add an e2e test.
Support peephole and num_proj as well.
PiperOrigin-RevId: 216467578
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 5496e2093e..e861df2b3d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { .copy_shape(activ_temp_shape); } +void ProcessUnidirectionalSequenceLstmOperator( + Model* model, UnidirectionalSequenceLstmOperator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + + // TODO(renjieliu): check the inputs, as well as all kinds of weights. + const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const int batch_size = input_shape.dims(1); + const int timestamp = input_shape.dims(0); + + const auto& recurrent_to_output_weights_array = + model->GetArray(op->inputs[8]); + // Yield until input dims have been resolved. + if (!recurrent_to_output_weights_array.has_shape()) { + return; + } + + constexpr int kInputActivationStateTensor = 18; + constexpr int kInputCellStateTensor = 19; + // b(115961645): This is a hack to work around. + model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset(); + + const auto& output_weights_shape = recurrent_to_output_weights_array.shape(); + const int output_size = output_weights_shape.dims(1); + + Shape* output_shape = output_array.mutable_shape(); + output_shape->ReplaceDims({timestamp, batch_size, output_size}); +} + void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) { ProcessResizeBilinearOperator(model, static_cast<ResizeBilinearOperator*>(op)); break; + case OperatorType::kUnidirectionalSequenceLstm: + ProcessUnidirectionalSequenceLstmOperator( + model, static_cast<UnidirectionalSequenceLstmOperator*>(op)); + break; case OperatorType::kLstmCell: ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op)); break; |