aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/toco/BUILD5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc185
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc171
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc97
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h102
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc442
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc24
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc9
11 files changed, 1048 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 20c156a932..864d0254f2 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -188,7 +188,10 @@ cc_library(
"graph_transformations/identify_l2_normalization.cc",
"graph_transformations/identify_l2_pool.cc",
"graph_transformations/identify_lstm.cc",
+ "graph_transformations/identify_lstm_merge_inputs.cc",
+ "graph_transformations/identify_lstm_split_inputs.cc",
"graph_transformations/identify_relu1.cc",
+ "graph_transformations/lstm_utils.cc",
"graph_transformations/make_initial_dequantize_operator.cc",
"graph_transformations/propagate_array_data_types.cc",
"graph_transformations/propagate_fixed_sizes.cc",
@@ -235,6 +238,7 @@ cc_library(
],
hdrs = [
"graph_transformations/graph_transformations.h",
+ "graph_transformations/lstm_utils.h",
],
visibility = ["//visibility:public"],
deps = [
@@ -245,6 +249,7 @@ cc_library(
":tooling_util",
":types_proto_cc",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index cf90ebe996..5d7ada1b74 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -124,6 +124,8 @@ DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
+DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
+DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
new file mode 100644
index 0000000000..45335fd78c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -0,0 +1,185 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+// Rewrites an extended LstmCell (one weight/bias input per gate) into the
+// compact LstmCell form that takes LstmCellOperator::NUM_INPUTS inputs, by
+// packing the per-gate tensors into a single weights array and a single
+// biases array.
+//
+// Arguments:
+//   model: the model being transformed
+//   op_index: index into model->operators of the operator to inspect
+// Returns true if the operator was replaced, false if nothing was changed.
+bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
+  // Find lstm cell.
+  auto op_it = model->operators.begin() + op_index;
+  auto src_op = op_it->get();
+  if (src_op->type != OperatorType::kLstmCell) {
+    return false;
+  }
+
+  // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs,
+  // do not need to merge cell inputs.
+  if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) {
+    return false;
+  }
+
+  // Identify prev_activ_input, prev_state_input as required Op inputs,
+  // using the rnn_states in the model flag.
+  string prev_activ_input;
+  if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor],
+                           &prev_activ_input)) {
+    return false;
+  }
+  string prev_state_input;
+  if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor],
+                           &prev_state_input)) {
+    return false;
+  }
+
+  // Get LstmCell's cell, input, output size.
+  const int num_cell =
+      model->GetArray(src_op->inputs[kInputToInputWeightsTensor])
+          .shape()
+          .dims(0);
+  const int num_input =
+      model->GetArray(src_op->inputs[kInputToInputWeightsTensor])
+          .shape()
+          .dims(1);
+  const int num_output =
+      model->GetArray(src_op->inputs[kRecurrentToInputWeightsTensor])
+          .shape()
+          .dims(1);
+
+  // Make sure n_cell and n_output are equal as there is no projection.
+  CHECK_EQ(num_cell, num_output);
+
+  // Describes where one per-gate tensor lands inside a merged tensor.
+  struct GateBlock {
+    int input_index;  // Index into src_op->inputs.
+    int row_offset;   // First row of the block in the merged tensor.
+    int col_offset;   // First column of the block in the merged tensor.
+  };
+
+  // Create tensorflow_graphdef style's one big weight tensor.
+  const string base_name(FindLongestCommonPrefix(
+      src_op->outputs[kOutputTensor], src_op->outputs[kCellStateTensor]));
+  string merged_weights = AvailableArrayName(*model, base_name + "weights");
+  auto& array = model->GetOrCreateArray(merged_weights);
+  array.data_type = ArrayDataType::kFloat;
+  const int weights_dim1 = 4 * num_cell;
+  const int weights_dim2 = num_input + num_output;
+  Shape shape = Shape({weights_dim1, weights_dim2});
+  array.copy_shape(shape);
+  auto& buffer = array.GetMutableBuffer<ArrayDataType::kFloat>();
+  buffer.data.resize(weights_dim1 * weights_dim2);
+
+  // Merge 8 small weight tensors to 1 weight tensor. Rows are laid out as
+  // [input gate; cell gate; forget gate; output gate]; the input weights
+  // occupy the first num_input columns, the recurrent weights the rest.
+  const GateBlock weight_blocks[] = {
+      {kInputToInputWeightsTensor, 0, 0},
+      {kInputToCellWeightsTensor, num_cell, 0},
+      {kInputToForgetWeightsTensor, num_cell * 2, 0},
+      {kInputToOutputWeightsTensor, num_cell * 3, 0},
+      {kRecurrentToInputWeightsTensor, 0, num_input},
+      {kRecurrentToCellWeightsTensor, num_cell, num_input},
+      {kRecurrentToForgetWeightsTensor, num_cell * 2, num_input},
+      {kRecurrentToOutputWeightsTensor, num_cell * 3, num_input},
+  };
+  for (const GateBlock& block : weight_blocks) {
+    CopyArrayToSubArray(buffer, weights_dim2,
+                        model->GetArray(src_op->inputs[block.input_index]),
+                        block.row_offset, block.col_offset);
+  }
+
+  // Create tensorflow_graphdef style's one big bias tensor.
+  string merged_biases = AvailableArrayName(*model, base_name + "biases");
+  auto& bias_array = model->GetOrCreateArray(merged_biases);
+  bias_array.data_type = ArrayDataType::kFloat;
+  bias_array.copy_shape(Shape({weights_dim1}));
+  auto& bias_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
+  bias_buffer.data.resize(weights_dim1);
+
+  // Merge 4 small bias tensors into a big one. The merged bias is 1-D, so
+  // its stride is 1 (not the weight tensor's column count).
+  const GateBlock bias_blocks[] = {
+      {kInputGateBiasTensor, 0, 0},
+      {kCellGateBiasTensor, num_cell, 0},
+      {kForgetGateBiasTensor, num_cell * 2, 0},
+      {kOutputGateBiasTensor, num_cell * 3, 0},
+  };
+  for (const GateBlock& block : bias_blocks) {
+    CopyArrayToSubArray(bias_buffer, /*tensor_stride=*/1,
+                        model->GetArray(src_op->inputs[block.input_index]),
+                        block.row_offset, block.col_offset);
+  }
+
+  // Emplace a new LSTM cell operator (use basic 5 inputs kernel).
+  auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+
+  // Compact LstmCell's 5 inputs.
+  lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);
+  lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] =
+      src_op->inputs[kInputTensor];
+  lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] = merged_weights;
+  lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] = merged_biases;
+  lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input;
+  lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input;
+
+  // Reorder LstmCell's 4 outputs.
+  lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
+  lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
+      src_op->outputs[kOutputTensor];
+  lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
+      src_op->outputs[kCellStateTensor];
+  lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] =
+      src_op->outputs[kScratchBufferTensor];
+  lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] =
+      src_op->outputs[kOutputStateTensor];
+
+  // Add the op into model. op_it must not be used after this point: inserting
+  // into the operator vector invalidates it. src_op stays valid because it
+  // points at the operator object, not at the vector slot.
+  model->operators.emplace(op_it, std::move(lstm_cell_op));
+  AddMessageF("Creating compact LstmCell replacing previous lstm cell");
+
+  // Delete the operator and the arrays replaced by the compact LSTM cell.
+  // Order is important - DeleteArrayIfUnused() only succeeds if dependent
+  // operators have been removed first, so the old op is erased before its
+  // arrays are released. Erasing destroys src_op, hence the names are copied
+  // out beforehand.
+  const auto replaced_inputs = src_op->inputs;
+  model->operators.erase(FindOp(*model, src_op));
+  const int merged_input_indices[] = {
+      kInputToInputWeightsTensor,      kInputToForgetWeightsTensor,
+      kInputToCellWeightsTensor,       kInputToOutputWeightsTensor,
+      kRecurrentToInputWeightsTensor,  kRecurrentToForgetWeightsTensor,
+      kRecurrentToCellWeightsTensor,   kRecurrentToOutputWeightsTensor,
+      kInputGateBiasTensor,            kForgetGateBiasTensor,
+      kCellGateBiasTensor,             kOutputGateBiasTensor,
+  };
+  for (const int input_index : merged_input_indices) {
+    DeleteArrayIfUnused(replaced_inputs[input_index], model);
+  }
+
+  return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
new file mode 100644
index 0000000000..eca717680a
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -0,0 +1,171 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+// Rewrites a compact LstmCell (single merged weights and biases inputs) into
+// the extended LstmCell form with kExtendedLstmInputCount inputs, by slicing
+// the merged tensors into per-gate weight and bias arrays.
+//
+// Arguments:
+//   model: the model being transformed
+//   op_index: index into model->operators of the operator to inspect
+// Returns true if the operator was replaced, false if nothing was changed.
+bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
+  // Find lstm cell.
+  auto op_it = model->operators.begin() + op_index;
+  auto curr_op = op_it->get();
+  if (curr_op->type != OperatorType::kLstmCell) {
+    return false;
+  }
+
+  // Already an extended LstmCell with kExtendedLstmInputCount of inputs,
+  // do not need to split cell inputs.
+  if (curr_op->inputs.size() == kExtendedLstmInputCount) {
+    return false;
+  }
+
+  // Make sure the WEIGHTS_INPUT and BIASES_INPUT are constant arrays,
+  // that are able to be split into smaller weight and bias tensors.
+  if (!IsConstantParameterArray(
+          *model, curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]) ||
+      !IsConstantParameterArray(
+          *model, curr_op->inputs[LstmCellOperator::BIASES_INPUT])) {
+    return false;
+  }
+
+  // Make sure propagate_fixed_sizes has defined the size of the output.
+  if (!model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT])
+           .has_shape()) {
+    return false;
+  }
+
+  // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc).
+  auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+  lstm_cell_op->inputs.resize(kExtendedLstmInputCount);
+  const int num_input =
+      model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT])
+          .shape()
+          .dims(1);
+
+  // n_cell and n_output have the same size when there is no projection.
+  const int num_cell =
+      model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT])
+          .shape()
+          .dims(1);
+  const int num_output = num_cell;
+
+  // Data input. Index with the input enumerator (DATA_INPUT), not an output
+  // enumerator, since this reads from curr_op->inputs.
+  lstm_cell_op->inputs[kInputTensor] =
+      curr_op->inputs[LstmCellOperator::DATA_INPUT];
+
+  // Describes one per-gate slice to cut out of a merged tensor.
+  struct GateSlice {
+    int input_index;     // Index into lstm_cell_op->inputs.
+    const char* suffix;  // Appended to base_name to form the tensor name.
+    int num_cols;        // Width of the slice (1 marks a bias slice).
+    int row_offset;      // First row of the slice in the merged tensor.
+    int col_offset;      // First column of the slice in the merged tensor.
+  };
+
+  // Get original weight tensor and decompose 1 tensor to 8 sub tensors.
+  Array& kernel =
+      model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
+  const string base_name(FindLongestCommonPrefix(
+      curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT],
+      curr_op->outputs[LstmCellOperator::STATE_OUTPUT]));
+
+  // Input weight tensors of size {n_cell, n_input} come first, followed by
+  // recurrent weight tensors of size {n_cell, n_output}.
+  const GateSlice weight_slices[] = {
+      {kInputToInputWeightsTensor, "weight_i_i", num_input, 0, 0},
+      {kInputToCellWeightsTensor, "weight_c_i", num_input, num_cell, 0},
+      {kInputToForgetWeightsTensor, "weight_f_i", num_input, num_cell * 2, 0},
+      {kInputToOutputWeightsTensor, "weight_o_i", num_input, num_cell * 3, 0},
+      {kRecurrentToInputWeightsTensor, "weight_i_r", num_output, 0, num_input},
+      {kRecurrentToCellWeightsTensor, "weight_c_r", num_output, num_cell,
+       num_input},
+      {kRecurrentToForgetWeightsTensor, "weight_f_r", num_output, num_cell * 2,
+       num_input},
+      {kRecurrentToOutputWeightsTensor, "weight_o_r", num_output, num_cell * 3,
+       num_input},
+  };
+  for (const GateSlice& slice : weight_slices) {
+    CopySubArrayToArray(model, &(lstm_cell_op->inputs[slice.input_index]),
+                        base_name + slice.suffix, num_cell, slice.num_cols,
+                        kernel, slice.row_offset, slice.col_offset);
+  }
+
+  // Peephole (optional).
+  CreateOptionalArray(model, &(lstm_cell_op->inputs[kCellToInputWeightsTensor]),
+                      base_name + "peephole_c_i");
+  CreateOptionalArray(model,
+                      &(lstm_cell_op->inputs[kCellToForgetWeightsTensor]),
+                      base_name + "peephole_c_f");
+  CreateOptionalArray(model,
+                      &(lstm_cell_op->inputs[kCellToOutputWeightsTensor]),
+                      base_name + "peephole_c_o");
+
+  // Get original bias tensor and decompose 1 tensor to 4 sub tensors.
+  Array& bias =
+      model->GetArray(curr_op->inputs[LstmCellOperator::BIASES_INPUT]);
+  const GateSlice bias_slices[] = {
+      {kInputGateBiasTensor, "bias_i", 1, 0, 0},
+      {kCellGateBiasTensor, "bias_c", 1, num_cell, 0},
+      {kForgetGateBiasTensor, "bias_f", 1, num_cell * 2, 0},
+      {kOutputGateBiasTensor, "bias_o", 1, num_cell * 3, 0},
+  };
+  for (const GateSlice& slice : bias_slices) {
+    CopySubArrayToArray(model, &(lstm_cell_op->inputs[slice.input_index]),
+                        base_name + slice.suffix, num_cell, slice.num_cols,
+                        bias, slice.row_offset, slice.col_offset);
+  }
+
+  // Projection (optional).
+  CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionWeightsTensor]),
+                      base_name + "proj_weight");
+  CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionBiasTensor]),
+                      base_name + "proj_bias");
+
+  // Reorder LstmCell's outputs.
+  lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
+  lstm_cell_op->outputs[kScratchBufferTensor] =
+      curr_op->outputs[LstmCellOperator::CONCAT_TEMP];
+  lstm_cell_op->outputs[kOutputStateTensor] =
+      curr_op->outputs[LstmCellOperator::ACTIV_TEMP];
+  lstm_cell_op->outputs[kCellStateTensor] =
+      curr_op->outputs[LstmCellOperator::STATE_OUTPUT];
+  lstm_cell_op->outputs[kOutputTensor] =
+      curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT];
+
+  // Add the op into model. op_it must not be used after this point: inserting
+  // into the operator vector invalidates it.
+  model->operators.emplace(op_it, std::move(lstm_cell_op));
+  AddMessageF("Creating extended LstmCell replacing previous lstm cell");
+
+  // Delete arrays and operators replaced by the LSTM cell operator. Order is
+  // important - DeleteArrayIfUnused() only succeeds if dependent operators
+  // have been removed first. Start at the output and work towards the input.
+  // Erase curr lstm op being replaced.
+  // NOTE(review): these calls run while curr_op is still in the model, so by
+  // the contract stated above they presumably never delete anything. If the
+  // deletion is wanted, copy the names and erase curr_op first - but confirm
+  // beforehand that the PREV_ACTIV/PREV_STATE arrays are not still required
+  // by the rnn_states model flags.
+  DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model);
+  DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
+  DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT],
+                      model);
+  DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT],
+                      model);
+  model->operators.erase(FindOp(*model, curr_op));
+
+  return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc
new file mode 100644
index 0000000000..910a960589
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc
@@ -0,0 +1,97 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
+
+namespace toco {
+
+// Calls model->CreateOptionalArray(array_name) and records the same name in
+// *input_array_buffer, so the caller's operator input slot refers to it.
+void CreateOptionalArray(Model* model, string* input_array_buffer,
+                         const string& array_name) {
+  model->CreateOptionalArray(array_name);
+  input_array_buffer->assign(array_name);
+}
+
+// Copies a dim1_copy_size x dim2_copy_size rectangle of floats from
+// src_buffer into dst_buffer. Both buffers are addressed as row-major
+// matrices whose row length is the given stride; the rectangle starts at
+// (src_start_idx1, src_start_idx2) in the source and at
+// (dst_start_idx1, dst_start_idx2) in the destination. No bounds checking is
+// performed here; the caller must size both buffers appropriately.
+void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
+                   int src_stride, int src_start_idx1, int src_start_idx2,
+                   Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride,
+                   int dst_start_idx1, int dst_start_idx2, int dim1_copy_size,
+                   int dim2_copy_size) {
+  // Flat index of the top-left element of the rectangle in each buffer.
+  const int src_origin = src_start_idx1 * src_stride + src_start_idx2;
+  const int dst_origin = dst_start_idx1 * dst_stride + dst_start_idx2;
+  for (int row = 0; row < dim1_copy_size; ++row) {
+    for (int col = 0; col < dim2_copy_size; ++col) {
+      const int from = src_origin + row * src_stride + col;
+      const int to = dst_origin + row * dst_stride + col;
+      dst_buffer->data[to] = src_buffer.data[from];
+    }
+  }
+}
+
+// Creates a float array of the given shape in `model` and returns a pointer
+// to its buffer, resized to RequiredBufferSizeForShape(shape) elements.
+// On entry *array_name holds the desired name; on return it holds the name
+// actually used, as chosen by AvailableArrayName().
+Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
+                                                      string* array_name,
+                                                      const Shape& shape) {
+  *array_name = AvailableArrayName(*model, *array_name);
+  auto& new_array = model->GetOrCreateArray(*array_name);
+  new_array.data_type = ArrayDataType::kFloat;
+  new_array.copy_shape(shape);
+  auto& float_buffer = new_array.GetMutableBuffer<ArrayDataType::kFloat>();
+  float_buffer.data.resize(RequiredBufferSizeForShape(shape));
+  return &float_buffer;
+}
+
+// Creates a new float array in `model` holding a dim1_size x dim2_size slice
+// of original_array, starting at (start_idx1, start_idx2).
+//
+// Arguments:
+//   model: the model that receives the new array
+//   array_name: in/out. If non-empty on entry it is used as the requested
+//     name; if empty, tensor_name is used instead. On return it holds the
+//     name actually assigned to the new array.
+//   tensor_name: requested name for the new array when *array_name is empty
+//   dim1_size, dim2_size: extent of the slice; dim2_size == 1 is treated as
+//     a bias slice and produces a 1-D array of shape {dim1_size}
+//   original_array: the array to copy from
+//   start_idx1, start_idx2: top-left corner of the slice in original_array
+void CopySubArrayToArray(Model* model, string* array_name,
+                         const string& tensor_name, int dim1_size,
+                         int dim2_size, const Array& original_array,
+                         int start_idx1, int start_idx2) {
+  // Callers typically pass a freshly resized (empty) operator input slot;
+  // without this the new array would be named after an empty string and
+  // tensor_name would be ignored entirely.
+  if (array_name->empty()) {
+    *array_name = tensor_name;
+  }
+
+  // Determine whether it's bias or not, create shape, buffer.
+  const bool is_bias = dim2_size == 1;
+  Shape shape = is_bias ? Shape({dim1_size}) : Shape({dim1_size, dim2_size});
+  Buffer<ArrayDataType::kFloat>* buffer =
+      CreateFloatArrayBuffer(model, array_name, shape);
+  auto& orig_buffer = original_array.GetBuffer<ArrayDataType::kFloat>();
+
+  // Copy data from big tensor. A bias source is addressed with stride 1,
+  // a weight source with its own column count.
+  CopyArrayData(orig_buffer, is_bias ? 1 : original_array.shape().dims(1),
+                start_idx1, start_idx2, buffer, dim2_size, 0, 0, dim1_size,
+                dim2_size);
+}
+
+// Writes all of sub_array into tensor_buffer, placing its first element at
+// (start_idx1, start_idx2). A sub_array with a single dimension is treated
+// as a bias: it is copied as a column of width 1 and the destination is
+// addressed with stride 1 instead of tensor_stride.
+void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
+                         int tensor_stride, const Array& sub_array,
+                         int start_idx1, int start_idx2) {
+  // Work out the extent of the data to copy.
+  const bool is_bias = sub_array.shape().dims().size() == 1;
+  const int num_rows = sub_array.shape().dims()[0];
+  int num_cols = 1;
+  int dst_stride = 1;
+  if (!is_bias) {
+    num_cols = sub_array.shape().dims(1);
+    dst_stride = tensor_stride;
+  }
+
+  // Copy data from sub tensor; the source is densely packed, so its stride
+  // equals its column count.
+  CopyArrayData(sub_array.GetBuffer<ArrayDataType::kFloat>(), num_cols, 0, 0,
+                &tensor_buffer, dst_stride, start_idx1, start_idx2, num_rows,
+                num_cols);
+}
+
+// Searches model->flags.rnn_states() for the first entry whose
+// back_edge_source_array equals `back_edge_source_array`. On a match, stores
+// that entry's state_array in *rnn_array and returns true; otherwise leaves
+// *rnn_array untouched and returns false.
+bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array,
+                         string* rnn_array) {
+  const auto& rnn_states = model->flags.rnn_states();
+  for (auto it = rnn_states.begin(); it != rnn_states.end(); ++it) {
+    if (it->back_edge_source_array() != back_edge_source_array) {
+      continue;
+    }
+    *rnn_array = it->state_array();
+    return true;
+  }
+  return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
new file mode 100644
index 0000000000..881c2d4dc8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+// For consistency with the parameters defined in extended LstmCell's kernel
+// (tensorflow/contrib/lite/kernels/lstm.cc),
+// use lowercase for these constants.
+
+// Positions of the inputs of an extended LstmCell operator.
+// kExtendedLstmInputCount is the total number of inputs, not an input itself.
+enum ExtendedLstmCellInputs {
+  kInputTensor = 0,
+  kInputToInputWeightsTensor = 1,  // Optional
+  kInputToForgetWeightsTensor = 2,
+  kInputToCellWeightsTensor = 3,
+  kInputToOutputWeightsTensor = 4,
+  kRecurrentToInputWeightsTensor = 5,  // Optional
+  kRecurrentToForgetWeightsTensor = 6,
+  kRecurrentToCellWeightsTensor = 7,
+  kRecurrentToOutputWeightsTensor = 8,
+  kCellToInputWeightsTensor = 9,    // Optional
+  kCellToForgetWeightsTensor = 10,  // Optional
+  kCellToOutputWeightsTensor = 11,  // Optional
+  kInputGateBiasTensor = 12,        // Optional
+  kForgetGateBiasTensor = 13,
+  kCellGateBiasTensor = 14,
+  kOutputGateBiasTensor = 15,
+  kProjectionWeightsTensor = 16,  // Optional
+  kProjectionBiasTensor = 17,     // Optional
+  kExtendedLstmInputCount = 18
+};
+
+// Positions of the outputs of an extended LstmCell operator.
+enum ExtendedLstmCellOutputs {
+  kScratchBufferTensor = 0,
+  kOutputStateTensor = 1,
+  kCellStateTensor = 2,
+  kOutputTensor = 3
+};
+
+// Create optional array used for optional tensor in ExtendedLstmCell inputs.
+// The name of the optional array is also written to *input_array_buffer.
+void CreateOptionalArray(Model* model, string* input_array_buffer,
+                         const string& array_name);
+
+// Create float array and get its buffer.
+// *array_name is the requested name on entry and the name actually used on
+// return.
+Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model,
+                                                      string* array_name,
+                                                      const Shape& shape);
+
+// Copy data from one array to the other one (supports 1D and 2D array),
+// for 1D array, the 2nd dim's size is 1.
+// Arguments:
+//   src_buffer: the source buffer
+//   src_stride: the stride of source buffer, i.e., 2nd dim's size
+//   src_start_idx1: the 1st dim index of start point in src matrix
+//   src_start_idx2: the 2nd dim index of start point in src matrix
+//   dst_buffer: the destination buffer
+//   dst_stride: the stride of destination buffer, i.e., 2nd dim's size
+//   dst_start_idx1: the 1st dim index of start point in dst matrix
+//   dst_start_idx2: the 2nd dim index of start point in dst matrix
+//   dim1_copy_size: 1st dim size of copy data
+//   dim2_copy_size: 2nd dim size of copy data
+void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
+                   int src_stride, int src_start_idx1, int src_start_idx2,
+                   Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride,
+                   int dst_start_idx1, int dst_start_idx2, int dim1_copy_size,
+                   int dim2_copy_size);
+
+// Copy a subset of array data and create a smaller array,
+// mostly used for splitting weights and bias for Lstm cell.
+// A dim2_size of 1 marks a bias and yields a 1D array of size dim1_size.
+void CopySubArrayToArray(Model* model, string* array_name,
+                         const string& tensor_name, int dim1_size,
+                         int dim2_size, const Array& original_array,
+                         int start_idx1, int start_idx2);
+
+// Copy array data to a large array's submatrix,
+// mostly used for merging weights and bias for Lstm cell.
+// A 1D sub_array is treated as a bias and written with a stride of 1.
+void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
+                         int tensor_stride, const Array& sub_array,
+                         int start_idx1, int start_idx2);
+
+// Get matching rnn array inputs using rnn_states flag.
+// Returns true and sets *rnn_array when an rnn_state whose
+// back_edge_source_array equals the given name is found, false otherwise.
+bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array,
+                         string* rnn_array);
+
+}  // namespace toco
+
+#endif  // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_LSTM_UTILS_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index fa7e70d90b..2a44849ffa 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -728,9 +728,8 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
}
void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
- // I/O arrays should be allocated on creation of op.
- QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
- QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
+ // Only required for compact LstmCell with default NUM_INPUTS of inputs.
+ if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;
const auto& input_array =
model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index 8931498782..2f94f9cd8a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -18,6 +18,17 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "lstm_utils_test",
+ srcs = ["lstm_utils_test.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc
new file mode 100644
index 0000000000..6aae0775d3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc
@@ -0,0 +1,442 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+// Builds one gmock FloatNear matcher per entry of `values`, each accepting a
+// float within `max_abs_error` of the corresponding entry. The returned vector
+// is meant to be handed to ElementsAreArray.
+std::vector<testing::Matcher<float>> ArrayFloatNear(
+    const std::vector<float>& values, float max_abs_error = 1e-5) {
+  std::vector<testing::Matcher<float>> per_element;
+  per_element.reserve(values.size());
+  for (auto it = values.begin(); it != values.end(); ++it) {
+    per_element.emplace_back(testing::FloatNear(*it, max_abs_error));
+  }
+  return per_element;
+}
+} // namespace
+
+// Fixture for the CopyArrayData tests. It asks CreateFloatArrayBuffer for a
+// source and a destination float buffer inside a caller-provided Model and
+// keeps raw pointers to them; both pointers are null until PrepareBuffers runs.
+class CopyArrayDataTest : public ::testing::Test {
+ public:
+  CopyArrayDataTest() = default;
+
+  // Creates the source and destination buffers (requested array names
+  // "src_array" and "dst_array") and fills them with `src_data` / `dst_data`.
+  // A second dimension of 1 selects a 1D shape {dim_1}; any other value selects
+  // a 2D shape {dim_1, dim_2}.
+  void PrepareBuffers(Model* model, std::initializer_list<float> src_data,
+                      int src_dim_1, int src_dim_2,
+                      std::initializer_list<float> dst_data, int dst_dim_1,
+                      int dst_dim_2) {
+    string src_array = "src_array";
+    src_buffer_ = CreateFloatArrayBuffer(
+        model, &src_array,
+        src_dim_2 == 1 ? Shape({src_dim_1}) : Shape({src_dim_1, src_dim_2}));
+    PopulateBuffer(src_buffer_, src_data);
+    string dst_array = "dst_array";
+    dst_buffer_ = CreateFloatArrayBuffer(
+        model, &dst_array,
+        dst_dim_2 == 1 ? Shape({dst_dim_1}) : Shape({dst_dim_1, dst_dim_2}));
+    PopulateBuffer(dst_buffer_, dst_data);
+  }
+
+  // Accessors for the buffers set up by the last PrepareBuffers call
+  // (nullptr if PrepareBuffers has not been called yet).
+  Buffer<ArrayDataType::kFloat>* GetSrcBuffer() { return src_buffer_; }
+  Buffer<ArrayDataType::kFloat>* GetDstBuffer() { return dst_buffer_; }
+
+  // Overwrites the first init_data.size() elements of `buffer` without
+  // resizing it. Fails the test instead of writing out of bounds when
+  // `buffer` is null or too small.
+  void PopulateBuffer(Buffer<ArrayDataType::kFloat>* buffer,
+                      const std::vector<float>& init_data) {
+    ASSERT_TRUE(buffer != nullptr);
+    ASSERT_LE(init_data.size(), buffer->data.size());
+    for (size_t i = 0; i < init_data.size(); ++i) {
+      buffer->data[i] = init_data[i];
+    }
+  }
+
+  // Resizes `buffer` to exactly data.size() elements, then copies `data` in.
+  void UpdateBuffer(Buffer<ArrayDataType::kFloat>* buffer,
+                    std::initializer_list<float> data) {
+    ASSERT_TRUE(buffer != nullptr);
+    buffer->data.resize(data.size());
+    PopulateBuffer(buffer, data);
+  }
+
+ private:
+  // Non-owning as far as this fixture is concerned: it never frees them.
+  Buffer<ArrayDataType::kFloat>* src_buffer_ = nullptr;
+  Buffer<ArrayDataType::kFloat>* dst_buffer_ = nullptr;
+};
+
+// Copies eight blocks (four 4x3, then four 4x4) out of one 16x7 2D array.
+TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes2D) {
+ // Source data for the 16x7 src buffer; dst starts as a zero-filled 4x3.
+ Model model;
+ std::initializer_list<float> large_tf_weight_data = {
+ -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769,
+ -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458,
+ 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578,
+ 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517,
+ 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789,
+ -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274,
+ 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331,
+ -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007,
+ 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317,
+ -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823,
+ -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957,
+ -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447,
+ -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774,
+ 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606,
+ 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602,
+ 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733,
+ -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869,
+ 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403,
+ -0.446775, -0.418363, 0.019743, 0.042025};
+ std::initializer_list<float> tflite_lstm_input_weight = {0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0};
+ PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16,
+ /*src_dim_2=*/7, tflite_lstm_input_weight,
+ /*dst_dim_1=*/4, /*dst_dim_2=*/3);
+
+ // Copy from src starting at (0,0), size (4,3); dst is overwritten each time.
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+ std::vector<float> expected = {-0.320407, -0.108683, 0.406358, 0.170866,
+ 0.084135, 0.201878, 0.045578, 0.149816,
+ -0.447073, -0.001985, 0.402193, 0.315517};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (4,0), size (4,3).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/4,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+ expected = {0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923,
+ -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (8,0), size (4,3).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/8,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+ expected = {0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957,
+ -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (12,0), size (4,3).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/12,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+ expected = {0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609,
+ -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Call PrepareBuffers again: same src data, new zero-filled 4x4 dst.
+ std::initializer_list<float> tflite_lstm_recurrent_weight = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16,
+ /*src_dim_2=*/7, tflite_lstm_recurrent_weight,
+ /*dst_dim_1=*/4, /*dst_dim_2=*/4);
+
+ // Copy from src starting at (0,3), size (4,4).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+ expected = {-0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458,
+ 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326,
+ 0.38258, 0.43599, 0.11986, 0.465195};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (4,3), size (4,4).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/4,
+ /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+ expected = {0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439,
+ 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817,
+ 0.204032, -0.375317, -0.041911, 0.051664};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (8,3), size (4,4).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/8,
+ /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+ expected = {-0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939,
+ -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309,
+ 0.048654, -0.38582, 0.411018, -0.315606};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (12,3), size (4,4).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/7, /*src_start_idx1=*/12,
+ /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+ expected = {-0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733,
+ -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247,
+ -0.446775, -0.418363, 0.019743, 0.042025};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+}
+
+// Copies four consecutive 4-element chunks out of one 16-element 1D array.
+TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes1D) {
+ // Source data (16 floats); dst starts as 4 zeros. Both shapes are 1D.
+ Model model;
+ std::initializer_list<float> large_tf_bias_data = {
+ 0.980304, 0.419808, 0.080278, 0.728548, 0.581674, 0.672433,
+ 0.434190, 0.844357, 0.229587, 0.785629, 0.022065, 0.753082,
+ 0.422080, 0.539481, 0.878386, 0.168965};
+ std::initializer_list<float> tflite_lstm_i_bias = {0, 0, 0, 0};
+ PrepareBuffers(&model, large_tf_bias_data, /*src_dim_1=*/16,
+ /*src_dim_2=*/1, tflite_lstm_i_bias,
+ /*dst_dim_1=*/4, /*dst_dim_2=*/1);
+
+ // Copy from src starting at (0,), size (4,); dst is overwritten each time.
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+ std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (4,), size (4,).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/4,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+ expected = {0.581674, 0.672433, 0.434190, 0.844357};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (8,), size (4,).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/8,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+ expected = {0.229587, 0.785629, 0.022065, 0.753082};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+
+ // Copy from src starting at (12,), size (4,).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/12,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+ expected = {0.422080, 0.539481, 0.878386, 0.168965};
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+}
+
+// Copy from 8 small 2D arrays (four 4x3, four 4x4) into 1 big 16x7 array.
+TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray2D) {
+ // Zero data used to initialise the 16x7 destination buffer.
+ Model model;
+ std::initializer_list<float> large_tf_weights_data = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+ // Copy into dst starting at (0, 0), size (4, 3).
+ std::initializer_list<float> tflite_lstm_i2i_weight = {
+ -0.320407, -0.108683, 0.406358, 0.170866, 0.084135, 0.201878,
+ 0.045578, 0.149816, -0.447073, -0.001985, 0.402193, 0.315517};
+ PrepareBuffers(&model, tflite_lstm_i2i_weight, /*src_dim_1=*/4,
+ /*src_dim_2=*/3, large_tf_weights_data,
+ /*dst_dim_1=*/16, /*dst_dim_2=*/7);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/3, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+
+ // Copy into dst starting at (4, 0), size (4, 3); src is refilled in place.
+ std::initializer_list<float> tflite_lstm_i2c_weight = {
+ 0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923,
+ -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2c_weight);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/3, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/4, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+
+ // Copy into dst starting at (8, 0), size (4, 3).
+ std::initializer_list<float> tflite_lstm_i2f_weight = {
+ 0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957,
+ -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2f_weight);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/3, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/8, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+
+ // Copy into dst starting at (12, 0), size (4, 3).
+ std::initializer_list<float> tflite_lstm_i2o_weight = {
+ 0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609,
+ -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2o_weight);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/3, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/12, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/3);
+
+ // Copy into dst starting at (0, 3), size (4, 4).
+ std::initializer_list<float> tflite_lstm_i2r_weight = {
+ -0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458,
+ 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326,
+ 0.38258, 0.43599, 0.11986, 0.465195};
+ UpdateBuffer(GetSrcBuffer(), tflite_lstm_i2r_weight);  // Resizes src to 16.
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/4, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/3,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+
+ // Copy into dst starting at (4, 3), size (4, 4).
+ std::initializer_list<float> tflite_lstm_c2r_weight = {
+ 0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439,
+ 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817,
+ 0.204032, -0.375317, -0.041911, 0.051664};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_c2r_weight);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/4, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/4, /*dst_start_idx2=*/3,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+
+ // Copy into dst starting at (8, 3), size (4, 4).
+ std::initializer_list<float> tflite_lstm_f2r_weight = {
+ -0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939,
+ -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309,
+ 0.048654, -0.38582, 0.411018, -0.315606};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_f2r_weight);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/4, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/8, /*dst_start_idx2=*/3,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+
+ // Copy into dst starting at (12, 3), size (4, 4).
+ std::initializer_list<float> tflite_lstm_o2r_weight = {
+ -0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733,
+ -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247,
+ -0.446775, -0.418363, 0.019743, 0.042025};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_o2r_weight);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/4, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7,
+ /*dst_start_idx1=*/12, /*dst_start_idx2=*/3,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/4);
+
+ std::vector<float> expected = {
+ -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769,
+ -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458,
+ 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578,
+ 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517,
+ 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789,
+ -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274,
+ 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331,
+ -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007,
+ 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317,
+ -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823,
+ -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957,
+ -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447,
+ -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774,
+ 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606,
+ 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602,
+ 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733,
+ -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869,
+ 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403,
+ -0.446775, -0.418363, 0.019743, 0.042025};
+
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+}
+
+// Copy from 4 small 1D arrays (4 floats each) into 1 big 16-element array.
+TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray1D) {
+ // Zero data for the 16-element dst; the first 4-element src chunk follows.
+ Model model;
+ std::initializer_list<float> large_tf_bias_data = {0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0};
+
+ std::initializer_list<float> tflite_lstm_i_bias = {0.980304, 0.419808,
+ 0.080278, 0.728548};
+
+ PrepareBuffers(&model, tflite_lstm_i_bias, /*src_dim_1=*/4,
+ /*src_dim_2=*/1, large_tf_bias_data,
+ /*dst_dim_1=*/16, /*dst_dim_2=*/1);
+
+ // Copy into dst starting at (0,), size (4,).
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/0, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+
+ // Copy into dst starting at (4,), size (4,); src is refilled in place.
+ std::initializer_list<float> tflite_lstm_cell_bias = {0.581674, 0.672433,
+ 0.434190, 0.844357};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_cell_bias);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/4, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+
+ // Copy into dst starting at (8,), size (4,).
+ std::initializer_list<float> tflite_lstm_forget_bias = {0.229587, 0.785629,
+ 0.022065, 0.753082};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_forget_bias);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/8, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+
+ // Copy into dst starting at (12,), size (4,).
+ std::initializer_list<float> tflite_lstm_output_bias = {0.422080, 0.539481,
+ 0.878386, 0.168965};
+ PopulateBuffer(GetSrcBuffer(), tflite_lstm_output_bias);
+ CopyArrayData(*(GetSrcBuffer()),
+ /*src_stride=*/1, /*src_start_idx1=*/0,
+ /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1,
+ /*dst_start_idx1=*/12, /*dst_start_idx2=*/0,
+ /*dim1_copy_size=*/4, /*dim2_copy_size=*/1);
+
+ std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548,
+ 0.581674, 0.672433, 0.434190, 0.844357,
+ 0.229587, 0.785629, 0.022065, 0.753082,
+ 0.422080, 0.539481, 0.878386, 0.168965};
+
+ EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected)));
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 04aaedd59d..ff54b350bf 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -524,6 +524,28 @@ class Transpose
TocoOperator* op) const override {}
};
+// Reads and writes ::tflite::LSTMOptions for toco's LstmCellOperator.
+class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
+                                    ::tflite::BuiltinOptions_LSTMOptions> {
+ public:
+  using BuiltinOperator::BuiltinOperator;
+
+  // Always writes the same options regardless of `op`: tanh activation with
+  // both clip values set to 0.0.
+  flatbuffers::Offset<TfLiteOptions> WriteOptions(
+      const TocoOperator& op,
+      flatbuffers::FlatBufferBuilder* builder) const override {
+    // Current toco converter only supports tanh, no clip.
+    const auto activation = ::tflite::ActivationFunctionType_TANH;
+    return ::tflite::CreateLSTMOptions(*builder, activation,
+                                       /*cell_clip=*/0.0,
+                                       /*proj_clip=*/0.0);
+  }
+
+  // Copies nothing into `op`; it only verifies the stored activation.
+  void ReadOptions(const TfLiteOptions& options,
+                   TocoOperator* op) const override {
+    // Only support tanh activation, so check that tflite type is tanh.
+    CHECK(options.fused_activation_function() ==
+          ::tflite::ActivationFunctionType_TANH);
+  }
+};
+
class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
::tflite::BuiltinOptions_MeanOptions> {
public:
@@ -779,6 +801,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
OperatorType::kStridedSlice));
+ ops.emplace_back(
+ new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell));
// Custom Operators.
ops.emplace_back(new Cast("CAST", OperatorType::kCast));
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index ac77284be0..47da8b68ea 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -107,7 +107,8 @@ bool SupportsFusedActivationFunction(FileFormat format) {
}
bool SupportsLstmCell(FileFormat format) {
- return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT);
+ return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT ||
+ format == TFLITE);
}
bool SupportsPreallocatedWorkspace(FileFormat format) {
@@ -225,9 +226,13 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
}
transformations.Add(new ConvertPureConvToDepthwise);
- // TFLite export does not yet support fused LSTM cell.
if (SupportsLstmCell(output_format)) {
transformations.Add(new IdentifyLstmCell);
+ if (output_format == TFLITE) {
+ transformations.Add(new toco::SplitLstmCellInputs);
+ } else {
+ transformations.Add(new toco::MergeLstmCellInputs);
+ }
}
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",