diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-26 14:16:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-26 14:19:02 -0700 |
commit | 346611325add9da16d9a74b905228dc3068b30c1 (patch) | |
tree | 387c500172b0415edf64cd74c956b39ac3393a5c /tensorflow/contrib/lite/toco/tooling_util.cc | |
parent | 1081683bf67f353dacc34c220c808a0080281f7f (diff) |
Un-fused quantized Babelfish LSTM cell support in TFLite
including support for shuffled-weights fully-connected op.
PiperOrigin-RevId: 202192299
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index a52c812ef4..3d9fa732bd 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -2200,4 +2200,51 @@ void UseArraysExtraInfo(Model* model, bool quantize_output) { } } +void UndoWeightsShuffling(Model* model) { + for (const auto& op : model->operators) { + if (op->type != toco::OperatorType::kFullyConnected) { + continue; + } + const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op); + if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) { + continue; + } + const string& weights_name = fc_op.inputs[1]; + QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1); + auto& weights_array = model->GetArray(weights_name); + QCHECK(weights_array.data_type == ArrayDataType::kUint8); + auto& weights_data = + weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data; + const auto& weights_shape = weights_array.shape(); + QCHECK_EQ(weights_shape.dimensions_count(), 2); + const int rows = weights_shape.dims(0); + const int cols = weights_shape.dims(1); + QCHECK_EQ(rows % 4, 0); + QCHECK_EQ(cols % 16, 0); + CHECK_EQ(rows * cols, weights_data.size()); + // Compute the de-shuffled weights + std::vector<uint8> deshuffled_data(weights_data.size()); + uint8* shuffled_data_ptr = weights_data.data(); + for (int r = 0; r < rows; r += 4) { + for (int c = 0; c < cols; c += 16) { + for (int i = 0; i < 4; i++) { + uint8* deshuffled_data_ptr = + deshuffled_data.data() + (r + i) * cols + c; + for (int j = 0; j < 16; j++) { + uint8 shuffled_val = *shuffled_data_ptr++; + // Deshuffling isn't only about deshuffling the storage layout, + // it's also about undoing the flipping of the sign bit, which is + // performed on the shuffled weights. + uint8 deshuffled_val = shuffled_val ^ 0x80; + *deshuffled_data_ptr++ = deshuffled_val; + } + } + } + } + CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols); + // Switch this FC op to using the deshuffled weights. + weights_data = std::move(deshuffled_data); + } +} + } // namespace toco |