aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 5496e2093e..e861df2b3d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
.copy_shape(activ_temp_shape);
}
+void ProcessUnidirectionalSequenceLstmOperator(
+ Model* model, UnidirectionalSequenceLstmOperator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
+
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return;
+ }
+
+ // TODO(renjieliu): check the inputs, as well as all kinds of weights.
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ const int batch_size = input_shape.dims(1);
+ const int timestamp = input_shape.dims(0);
+
+ const auto& recurrent_to_output_weights_array =
+ model->GetArray(op->inputs[8]);
+ // Yield until input dims have been resolved.
+ if (!recurrent_to_output_weights_array.has_shape()) {
+ return;
+ }
+
+ constexpr int kInputActivationStateTensor = 18;
+ constexpr int kInputCellStateTensor = 19;
+ // b(115961645): This is a hack to work around.
+ model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
+ model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
+
+ const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
+ const int output_size = output_weights_shape.dims(1);
+
+ Shape* output_shape = output_array.mutable_shape();
+ output_shape->ReplaceDims({timestamp, batch_size, output_size});
+}
+
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
ProcessResizeBilinearOperator(model,
static_cast<ResizeBilinearOperator*>(op));
break;
+ case OperatorType::kUnidirectionalSequenceLstm:
+ ProcessUnidirectionalSequenceLstmOperator(
+ model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
+ break;
case OperatorType::kLstmCell:
ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
break;