aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-26 14:16:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 14:19:02 -0700
commit346611325add9da16d9a74b905228dc3068b30c1 (patch)
tree387c500172b0415edf64cd74c956b39ac3393a5c /tensorflow/contrib/lite/toco/tooling_util.cc
parent1081683bf67f353dacc34c220c808a0080281f7f (diff)
Un-fused quantized Babelfish LSTM cell support in TFLite
including support for shuffled-weights fully-connected op. PiperOrigin-RevId: 202192299
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index a52c812ef4..3d9fa732bd 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -2200,4 +2200,51 @@ void UseArraysExtraInfo(Model* model, bool quantize_output) {
}
}
+void UndoWeightsShuffling(Model* model) {
+ for (const auto& op : model->operators) {
+ if (op->type != toco::OperatorType::kFullyConnected) {
+ continue;
+ }
+ const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
+ if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
+ continue;
+ }
+ const string& weights_name = fc_op.inputs[1];
+ QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
+ auto& weights_array = model->GetArray(weights_name);
+ QCHECK(weights_array.data_type == ArrayDataType::kUint8);
+ auto& weights_data =
+ weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
+ const auto& weights_shape = weights_array.shape();
+ QCHECK_EQ(weights_shape.dimensions_count(), 2);
+ const int rows = weights_shape.dims(0);
+ const int cols = weights_shape.dims(1);
+ QCHECK_EQ(rows % 4, 0);
+ QCHECK_EQ(cols % 16, 0);
+ CHECK_EQ(rows * cols, weights_data.size());
+ // Compute the de-shuffled weights
+ std::vector<uint8> deshuffled_data(weights_data.size());
+ uint8* shuffled_data_ptr = weights_data.data();
+ for (int r = 0; r < rows; r += 4) {
+ for (int c = 0; c < cols; c += 16) {
+ for (int i = 0; i < 4; i++) {
+ uint8* deshuffled_data_ptr =
+ deshuffled_data.data() + (r + i) * cols + c;
+ for (int j = 0; j < 16; j++) {
+ uint8 shuffled_val = *shuffled_data_ptr++;
+ // Deshuffling isn't only about deshuffling the storage layout,
+ // it's also about undoing the flipping of the sign bit, which is
+ // performed on the shuffled weights.
+ uint8 deshuffled_val = shuffled_val ^ 0x80;
+ *deshuffled_data_ptr++ = deshuffled_val;
+ }
+ }
+ }
+ }
+ CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
+ // Switch this FC op to using the deshuffled weights.
+ weights_data = std::move(deshuffled_data);
+ }
+}
+
} // namespace toco