Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations')
 tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc |  6 ++++++
 tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc      | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 40cd6dea82..47faa20a29 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -239,6 +239,12 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
       }
       break;
     }
+    case OperatorType::kUnidirectionalSequenceLstm: {
+      const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+      if (data_type != ArrayDataType::kFloat) return ::tensorflow::Status::OK();
+      SetDataTypeForAllOutputs(model, op, data_type);
+      break;
+    }
     default: {
       // These operators produce outputs with the same type as their 1st input
       CHECK_GT(op->inputs.size(), 0);
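
Note: the case added above follows the switch's existing pattern: read the data type of input 0 and stamp it onto every output, skipping propagation entirely when the LSTM is not float. Below is a minimal standalone sketch of that rule, using hypothetical stand-in types (a name-to-Array map) rather than the real toco Model/Operator classes:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

enum class ArrayDataType { kNone, kFloat, kInt32 };
struct Array { ArrayDataType data_type = ArrayDataType::kNone; };
struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};
// Stand-in for toco's Model: a name -> Array map.
using Model = std::unordered_map<std::string, Array>;

void SetDataTypeForAllOutputs(Model* model, Op* op, ArrayDataType type) {
  for (const auto& name : op->outputs) (*model)[name].data_type = type;
}

// Mirrors the kUnidirectionalSequenceLstm case: propagate only float types.
void PropagateLstmDataType(Model* model, Op* op) {
  const ArrayDataType data_type = (*model)[op->inputs[0]].data_type;
  if (data_type != ArrayDataType::kFloat) return;
  SetDataTypeForAllOutputs(model, op, data_type);
}

int main() {
  Model model;
  model["input"].data_type = ArrayDataType::kFloat;
  Op op{{"input"}, {"output"}};
  PropagateLstmDataType(&model, &op);
  assert(model["output"].data_type == ArrayDataType::kFloat);
  return 0;
}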
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 5496e2093e..e861df2b3d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
       .copy_shape(activ_temp_shape);
 }

+void ProcessUnidirectionalSequenceLstmOperator(
+    Model* model, UnidirectionalSequenceLstmOperator* op) {
+  auto& output_array = model->GetArray(op->outputs[0]);
+  if (output_array.has_shape()) {
+    // Shape already propagated
+    return;
+  }
+
+  if (output_array.data_type == ArrayDataType::kNone) {
+    // Yield until the output type has been set by PropagateArrayDataTypes
+    return;
+  }
+
+  // TODO(renjieliu): check the inputs, as well as all kinds of weights.
+  const auto& input_array = model->GetArray(op->inputs[0]);
+  // Yield until input dims have been resolved.
+  if (!input_array.has_shape()) {
+    return;
+  }
+  const auto& input_shape = input_array.shape();
+  const int batch_size = input_shape.dims(1);
+  const int timestamp = input_shape.dims(0);
+
+  const auto& recurrent_to_output_weights_array =
+      model->GetArray(op->inputs[8]);
+  // Yield until input dims have been resolved.
+  if (!recurrent_to_output_weights_array.has_shape()) {
+    return;
+  }
+
+  constexpr int kInputActivationStateTensor = 18;
+  constexpr int kInputCellStateTensor = 19;
+  // b(115961645): This is a hack to work around.
+  model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
+  model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
+
+  const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
+  const int output_size = output_weights_shape.dims(1);
+
+  Shape* output_shape = output_array.mutable_shape();
+  output_shape->ReplaceDims({timestamp, batch_size, output_size});
+}
+
 void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
   const auto& input_array = model->GetArray(op->inputs[0]);
   // Yield until input dims have been resolved.
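
Note: the shape arithmetic in ProcessUnidirectionalSequenceLstmOperator reduces to three numbers: the input is read as time-major, [timestep, batch, input_depth], and dims(1) of recurrent_to_output_weights (input 8) supplies the output size, so the output becomes [timestep, batch, output_size]. Below is a standalone sketch of just that computation, with plain vectors standing in for toco's Shape; the [num_units, output_size] weight layout is an assumption, not stated in the diff itself:

#include <cassert>
#include <vector>

using Shape = std::vector<int>;  // stand-in for toco's Shape

// Assumes a time-major input [timestep, batch, input_depth] and
// recurrent-to-output weights shaped [num_units, output_size].
Shape UnidirectionalSequenceLstmOutputShape(
    const Shape& input_shape, const Shape& recurrent_to_output_weights_shape) {
  const int timestamp = input_shape[0];   // dims(0), as in the diff
  const int batch_size = input_shape[1];  // dims(1)
  const int output_size = recurrent_to_output_weights_shape[1];  // dims(1)
  return {timestamp, batch_size, output_size};
}

int main() {
  // e.g. 28 timesteps, batch of 2, 40 input features, 20 output units:
  const Shape out =
      UnidirectionalSequenceLstmOutputShape({28, 2, 40}, {20, 20});
  assert((out == Shape{28, 2, 20}));
  return 0;
}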
@@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
       ProcessResizeBilinearOperator(model,
                                     static_cast<ResizeBilinearOperator*>(op));
       break;
+    case OperatorType::kUnidirectionalSequenceLstm:
+      ProcessUnidirectionalSequenceLstmOperator(
+          model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
+      break;
     case OperatorType::kLstmCell:
       ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
       break;
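
Note: both hunks rely on the "yield until resolved" convention visible in the comments above: a handler either makes progress or returns early, and the transformation driver re-runs handlers until a full pass changes nothing. Below is a minimal sketch of such a fixed-point loop with hypothetical stand-in types; the real GraphTransformation interface in toco differs:

#include <cassert>
#include <functional>
#include <optional>
#include <vector>

struct Array { std::optional<std::vector<int>> shape; };

// A handler returns true if it resolved something, false if it had to yield.
using Handler = std::function<bool(std::vector<Array>&)>;

// Re-run all handlers until a whole pass makes no progress (a fixed point).
void PropagateToFixedPoint(std::vector<Array>& arrays,
                           const std::vector<Handler>& handlers) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (const auto& handler : handlers) changed = handler(arrays) || changed;
  }
}

int main() {
  std::vector<Array> arrays(2);
  // Handler a needs arrays[0], which handler b seeds, so a must yield on the
  // first pass and succeed on the second, like the LSTM handler above.
  Handler a = [](std::vector<Array>& v) {
    if (!v[0].shape || v[1].shape) return false;  // yield, or already done
    v[1].shape = v[0].shape;                      // propagate the shape
    return true;
  };
  Handler b = [](std::vector<Array>& v) {
    if (v[0].shape) return false;
    v[0].shape = std::vector<int>{28, 2, 20};
    return true;
  };
  PropagateToFixedPoint(arrays, {a, b});
  assert(arrays[1].shape.has_value());
  return 0;
}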