diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 5496e2093e..e861df2b3d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { .copy_shape(activ_temp_shape); } +void ProcessUnidirectionalSequenceLstmOperator( + Model* model, UnidirectionalSequenceLstmOperator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + + // TODO(renjieliu): check the inputs, as well as all kinds of weights. + const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const int batch_size = input_shape.dims(1); + const int timestamp = input_shape.dims(0); + + const auto& recurrent_to_output_weights_array = + model->GetArray(op->inputs[8]); + // Yield until input dims have been resolved. + if (!recurrent_to_output_weights_array.has_shape()) { + return; + } + + constexpr int kInputActivationStateTensor = 18; + constexpr int kInputCellStateTensor = 19; + // b(115961645): This is a hack to work around. + model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset(); + + const auto& output_weights_shape = recurrent_to_output_weights_array.shape(); + const int output_size = output_weights_shape.dims(1); + + Shape* output_shape = output_array.mutable_shape(); + output_shape->ReplaceDims({timestamp, batch_size, output_size}); +} + void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) { ProcessResizeBilinearOperator(model, static_cast<ResizeBilinearOperator*>(op)); break; + case OperatorType::kUnidirectionalSequenceLstm: + ProcessUnidirectionalSequenceLstmOperator( + model, static_cast<UnidirectionalSequenceLstmOperator*>(op)); + break; case OperatorType::kLstmCell: ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op)); break; |