diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-26 14:16:37 -0700 |
---|---|---|
committer | Gunhan Gulsoy <gunan@google.com> | 2018-06-28 21:37:43 -0700 |
commit | 110ddc2103d7c86084ff52994998575113862542 (patch) | |
tree | 503f944630e7ea4d2cd9ca5fbca4621c7f555db6 /tensorflow/contrib/lite/toco/tflite | |
parent | 92221c68cdcf27607969089e5b6c06fdeeae8ae8 (diff) |
Un-fused quantized Babelfish LSTM cell support in TFLite
including support for shuffled-weights fully-connected op.
PiperOrigin-RevId: 202192299
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/import.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 35 |
2 files changed, 35 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index d1867bd4fa..1dd4915b31 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -221,6 +221,8 @@ std::unique_ptr<Model> Import(const ModelFlags& model_flags, model.get()); ImportIOTensors(*input_model, tensors_table, model.get()); + UndoWeightsShuffling(model.get()); + return model; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 290a925c1e..2d7a4a7a4c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -314,16 +314,47 @@ class FullyConnected flatbuffers::FlatBufferBuilder* builder) const override { auto activation_function = ActivationFunction::Serialize(op.fused_activation_function); - return ::tflite::CreateFullyConnectedOptions(*builder, activation_function); + ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format; + switch (op.weights_format) { + case FullyConnectedWeightsFormat::kDefault: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + break; + case FullyConnectedWeightsFormat::kShuffled4x16Int8: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + } + return ::tflite::CreateFullyConnectedOptions(*builder, activation_function, + tflite_weights_format); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); + switch (options.weights_format()) { + case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT: + op->weights_format = FullyConnectedWeightsFormat::kDefault; + break; + case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + op->weights_format = FullyConnectedWeightsFormat::kDefault; + } } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& fc_op = static_cast<const FullyConnectedOperator&>(op); + return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1 + : 2; + } }; class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions, |