aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-09 20:05:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 20:08:54 -0700
commit854ae599743a1e92a31ad49cfe42c6454cefd3b9 (patch)
tree1ff75695f61c5eb3353e739295e81f76bbe28a64 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent58fcfc98cd59ae3952399fc55380b8733df08df9 (diff)
Use Ophints to support TfLite UnidirectionalSequenceLstm and add an e2e test.
Support peephole and num_proj as well. PiperOrigin-RevId: 216467578
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc47
1 file changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 5496e2093e..e861df2b3d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
.copy_shape(activ_temp_shape);
}
+void ProcessUnidirectionalSequenceLstmOperator(
+ Model* model, UnidirectionalSequenceLstmOperator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // Output shape already propagated; nothing to do.
+ return;
+ }
+
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes.
+ return;
+ }
+
+ // TODO(renjieliu): check the inputs, as well as all kinds of weights.
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ const int batch_size = input_shape.dims(1);
+ const int timestamp = input_shape.dims(0);
+
+ const auto& recurrent_to_output_weights_array =
+ model->GetArray(op->inputs[8]);
+ // Yield until the recurrent-to-output weights' dims have been resolved.
+ if (!recurrent_to_output_weights_array.has_shape()) {
+ return;
+ }
+
+ constexpr int kInputActivationStateTensor = 18;
+ constexpr int kInputCellStateTensor = 19;
+ // TODO(b/115961645): Hack — clear any buffer cached on the state inputs, presumably so they are not treated as constant arrays downstream; remove once the underlying bug is fixed.
+ model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
+ model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
+
+ const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
+ const int output_size = output_weights_shape.dims(1);
+
+ Shape* output_shape = output_array.mutable_shape();
+ output_shape->ReplaceDims({timestamp, batch_size, output_size});
+}
+
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
ProcessResizeBilinearOperator(model,
static_cast<ResizeBilinearOperator*>(op));
break;
+ case OperatorType::kUnidirectionalSequenceLstm:
+ ProcessUnidirectionalSequenceLstmOperator(
+ model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
+ break;
case OperatorType::kLstmCell:
ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
break;