aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-26 14:16:37 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-28 21:37:43 -0700
commit110ddc2103d7c86084ff52994998575113862542 (patch)
tree503f944630e7ea4d2cd9ca5fbca4621c7f555db6 /tensorflow/contrib/lite/toco/tflite
parent92221c68cdcf27607969089e5b6c06fdeeae8ae8 (diff)
Un-fused quantized Babelfish LSTM cell support in TFLite
including support for shuffled-weights fully-connected op. PiperOrigin-RevId: 202192299
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc35
2 files changed, 35 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc
index d1867bd4fa..1dd4915b31 100644
--- a/tensorflow/contrib/lite/toco/tflite/import.cc
+++ b/tensorflow/contrib/lite/toco/tflite/import.cc
@@ -221,6 +221,8 @@ std::unique_ptr<Model> Import(const ModelFlags& model_flags,
model.get());
ImportIOTensors(*input_model, tensors_table, model.get());
+ UndoWeightsShuffling(model.get());
+
return model;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 290a925c1e..2d7a4a7a4c 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -314,16 +314,47 @@ class FullyConnected
flatbuffers::FlatBufferBuilder* builder) const override {
auto activation_function =
ActivationFunction::Serialize(op.fused_activation_function);
- return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
+ ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format;
+ switch (op.weights_format) {
+ case FullyConnectedWeightsFormat::kDefault:
+ tflite_weights_format =
+ ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ break;
+ case FullyConnectedWeightsFormat::kShuffled4x16Int8:
+ tflite_weights_format =
+ ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ break;
+ default:
+ LOG(ERROR) << "Unhandled FC weights format";
+ tflite_weights_format =
+ ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ }
+ return ::tflite::CreateFullyConnectedOptions(*builder, activation_function,
+ tflite_weights_format);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
+ switch (options.weights_format()) {
+ case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
+ op->weights_format = FullyConnectedWeightsFormat::kDefault;
+ break;
+ case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
+ break;
+ default:
+ LOG(ERROR) << "Unhandled FC weights format";
+ op->weights_format = FullyConnectedWeightsFormat::kDefault;
+ }
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& fc_op = static_cast<const FullyConnectedOperator&>(op);
+ return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1
+ : 2;
+ }
};
class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,