path: root/tensorflow/contrib/lite
author    Yu-Cheng Ling <ycling@google.com>  2018-06-01 16:27:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-01 16:30:28 -0700
commit    b31498a054d55ce328a2820fd403af764c482500 (patch)
tree      91b8513149a36ae042e2a1b51f9e284701bbdcec /tensorflow/contrib/lite
parent    73ec24e8b75ba4f73a06756502d8bf86b2a6828b (diff)
Support 5-input LSTM kernel in TFLite (float only).
PiperOrigin-RevId: 198943559
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r-- tensorflow/contrib/lite/builtin_op_data.h | 10
-rw-r--r-- tensorflow/contrib/lite/kernels/lstm.cc | 190
-rw-r--r-- tensorflow/contrib/lite/kernels/register.cc | 3
-rw-r--r-- tensorflow/contrib/lite/model.cc | 8
-rw-r--r-- tensorflow/contrib/lite/schema/schema.fbs | 12
-rwxr-xr-x tensorflow/contrib/lite/schema/schema_generated.h | 52
-rw-r--r-- tensorflow/contrib/lite/testing/BUILD | 1
-rw-r--r-- tensorflow/contrib/lite/testing/generate_examples.py | 13
-rw-r--r-- tensorflow/contrib/lite/testing/tflite_driver.cc | 25
-rw-r--r-- tensorflow/contrib/lite/toco/args.h | 1
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc | 8
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc | 8
-rw-r--r-- tensorflow/contrib/lite/toco/model.h | 10
-rw-r--r-- tensorflow/contrib/lite/toco/tflite/operator.cc | 31
-rw-r--r-- tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 6
-rw-r--r-- tensorflow/contrib/lite/toco/toco_flags.proto | 6
-rw-r--r-- tensorflow/contrib/lite/toco/toco_tooling.cc | 2
17 files changed, 355 insertions, 31 deletions
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 52ab9ee640..c1cc4476fb 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -148,10 +148,20 @@ typedef struct {
float beta;
} TfLiteLocalResponseNormParams;
+typedef enum {
+ kTfLiteLSTMFullKernel = 0,
+ kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
typedef struct {
+ // Parameters for LSTM version 1.
TfLiteFusedActivation activation;
float cell_clip;
float proj_clip;
+
+ // Parameters for LSTM version 2.
+ // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+ TfLiteLSTMKernelType kernel_type;
} TfLiteLSTMParams;
typedef struct {
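
A minimal sketch of filling in the extended struct from client code, assuming only the declarations added above (kTfLiteActTanh is defined earlier in the same header; the snippet is illustrative, not part of this commit):

  TfLiteLSTMParams params;
  params.activation = kTfLiteActTanh;
  params.cell_clip = 0.0f;   // 0.0 means no clipping
  params.proj_clip = 0.0f;
  params.kernel_type = kTfLiteLSTMBasicKernel;  // requires LSTM op version >= 2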
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 990b3da055..9aae3e571b 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -34,6 +36,17 @@ namespace ops {
namespace builtin {
namespace lstm {
+struct OpData {
+ // Which kernel type to use. Full kernel (18 inputs) or basic kernel
+ // (5 inputs).
+ TfLiteLSTMKernelType kernel_type;
+ // Only used by full kernel.
+ int scratch_tensor_index;
+};
+
+// For the full kernel (18 inputs).
+namespace full {
+
// Input Tensors of size {n_batch, n_input}
constexpr int kInputTensor = 0;
@@ -71,13 +84,10 @@ constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
- return scratch_tensor_index;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<int*>(buffer);
+ auto* op_data = new OpData;
+ op_data->kernel_type = kTfLiteLSTMFullKernel;
+ context->AddTensors(context, 1, &op_data->scratch_tensor_index);
+ return op_data;
}
// Check that input tensor dimensions match each other.
@@ -233,7 +243,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
@@ -289,7 +299,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Create a scratch buffer tensor.
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[0] = *scratch_tensor_index;
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
@@ -447,6 +457,168 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+} // namespace full
+
+// For the basic kernel (5 inputs).
+namespace basic {
+
+enum InputTensor {
+ kInputData = 0,
+ kInputPrevActivation = 1,
+ kInputWeights = 2,
+ kInputBiases = 3,
+ kInputPrevState = 4,
+ kInputNum = 5,
+};
+
+enum OutputTensor {
+ kOutputActivation = 0,
+ kOutputState = 1,
+ kOutputConcatTemp = 2,
+ kOutputActivationTemp = 3,
+ kOutputNum = 4,
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+ op_data->kernel_type = kTfLiteLSTMBasicKernel;
+ // `scratch_tensor_index` is unused in this kernel.
+ op_data->scratch_tensor_index = -1;
+ return op_data;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
+ TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
+
+ // Only Float32 is supported currently.
+ // TODO(ycling): Implement quantized uint8 support.
+ for (int index = 0; index < node->inputs->size; ++index) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32);
+ }
+
+ const TfLiteTensor* input = GetInput(context, node, kInputData);
+ const TfLiteTensor* prev_activation =
+ GetInput(context, node, kInputPrevActivation);
+ const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+ const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+ const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
+ const int num_batches = input->dims->data[0];
+
+ TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
+
+ TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
+
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
+
+ TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+ TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+ TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+ TfLiteTensor* activation_temp =
+ GetOutput(context, node, kOutputActivationTemp);
+
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(
+ context, activation_out,
+ TfLiteIntArrayCopy(prev_activation->dims)));
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, state_out,
+ TfLiteIntArrayCopy(prev_state->dims)));
+ TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
+ concat_temp_size->data[0] = num_batches;
+ concat_temp_size->data[1] = weights->dims->data[1];
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, concat_temp, concat_temp_size));
+ TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
+ activation_temp_size->data[0] = num_batches;
+ activation_temp_size->data[1] = weights->dims->data[0];
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
+ activation_temp_size));
+
+ // Set the state tensors as persistent.
+ for (auto index : {kInputPrevActivation, kInputPrevState}) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ tensor->allocation_type = kTfLiteArenaRwPersistent;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputData);
+ const TfLiteTensor* prev_activation =
+ GetInput(context, node, kInputPrevActivation);
+ const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+ const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+ const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+ TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+ TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+ TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+ TfLiteTensor* activation_temp =
+ GetOutput(context, node, kOutputActivationTemp);
+
+ optimized_ops::LstmCell(
+ // Inputs.
+ GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(prev_activation), GetTensorDims(prev_activation),
+ GetTensorData<float>(weights), GetTensorDims(weights),
+ GetTensorData<float>(bias), GetTensorDims(bias),
+ GetTensorData<float>(prev_state), GetTensorDims(prev_state),
+ // Outputs.
+ GetTensorData<float>(state_out), GetTensorDims(state_out),
+ GetTensorData<float>(activation_out), GetTensorDims(activation_out),
+ GetTensorData<float>(concat_temp), GetTensorDims(concat_temp),
+ GetTensorData<float>(activation_temp), GetTensorDims(activation_temp));
+
+ // TODO(ycling): Investigate if this copy can be avoided with the 5-input
+ // LSTM kernel.
+ memcpy(prev_activation->data.raw, activation_out->data.raw,
+ activation_out->bytes);
+ memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);
+
+ return kTfLiteOk;
+}
+
+} // namespace basic
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
+ switch (params->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Init(context, buffer, length);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Init(context, buffer, length);
+ }
+}
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Prepare(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Prepare(context, node);
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Eval(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Eval(context, node);
+ }
+}
+
} // namespace lstm
TfLiteRegistration* Register_LSTM() {
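
The memcpy calls at the end of basic::Eval() write each step's outputs back into the persistent prev_activation and prev_state inputs, so recurrent state is carried inside the graph. A hedged sketch of the caller's step loop under that contract (FillInputForStep is a hypothetical helper):

  for (int t = 0; t < num_steps; ++t) {
    FillInputForStep(interpreter, t);  // hypothetical: writes kInputData for step t
    interpreter->Invoke();             // basic::Eval() copies activation_out and
                                       // state_out into the two persistent input
                                       // tensors, so no manual state plumbing.
  }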
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index c7d72738d6..184b02dcec 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -126,7 +126,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
Register_LOCAL_RESPONSE_NORMALIZATION());
- AddBuiltin(BuiltinOperator_LSTM, Register_LSTM());
+ AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
Register_BIDIRECTIONAL_SEQUENCE_LSTM());
AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index ca115a1c59..8d8d74adfb 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -558,6 +558,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
parse_activation(lstm_params->fused_activation_function());
params->cell_clip = lstm_params->cell_clip();
params->proj_clip = lstm_params->proj_clip();
+ switch (lstm_params->kernel_type()) {
+ case LSTMKernelType_FULL:
+ params->kernel_type = kTfLiteLSTMFullKernel;
+ break;
+ case LSTMKernelType_BASIC:
+ params->kernel_type = kTfLiteLSTMBasicKernel;
+ break;
+ }
}
*builtin_data = reinterpret_cast<void*>(params);
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 7d76134e3d..7dbb36c864 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -315,11 +315,23 @@ table LocalResponseNormalizationOptions {
beta:float;
}
+enum LSTMKernelType : byte {
+ // Full LSTM kernel which supports peephole and projection.
+ FULL = 0,
+ // Basic LSTM kernel. Equivalent to TensorFlow BasicLSTMCell.
+ BASIC = 1,
+}
+
// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
table LSTMOptions {
+ // Parameters for LSTM version 1 or above.
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
proj_clip: float; // Optional, 0.0 means no clipping
+
+ // Parameters for LSTM version 2 or above.
+ // Basic kernel is only supported in version 2 or above.
+ kernel_type: LSTMKernelType = FULL;
}
table ResizeBilinearOptions {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 0a60fcd3d0..b1beb39b28 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -1428,6 +1428,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) {
return EnumNamesLSHProjectionType()[index];
}
+enum LSTMKernelType {
+ LSTMKernelType_FULL = 0,
+ LSTMKernelType_BASIC = 1,
+ LSTMKernelType_MIN = LSTMKernelType_FULL,
+ LSTMKernelType_MAX = LSTMKernelType_BASIC
+};
+
+inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
+ static LSTMKernelType values[] = {
+ LSTMKernelType_FULL,
+ LSTMKernelType_BASIC
+ };
+ return values;
+}
+
+inline const char **EnumNamesLSTMKernelType() {
+ static const char *names[] = {
+ "FULL",
+ "BASIC",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameLSTMKernelType(LSTMKernelType e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesLSTMKernelType()[index];
+}
+
enum CombinerType {
CombinerType_SUM = 0,
CombinerType_MEAN = 1,
@@ -2865,10 +2894,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable {
ActivationFunctionType fused_activation_function;
float cell_clip;
float proj_clip;
+ LSTMKernelType kernel_type;
LSTMOptionsT()
: fused_activation_function(ActivationFunctionType_NONE),
cell_clip(0.0f),
- proj_clip(0.0f) {
+ proj_clip(0.0f),
+ kernel_type(LSTMKernelType_FULL) {
}
};
@@ -2877,7 +2908,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
enum {
VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6,
- VT_PROJ_CLIP = 8
+ VT_PROJ_CLIP = 8,
+ VT_KERNEL_TYPE = 10
};
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@@ -2888,11 +2920,15 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
float proj_clip() const {
return GetField<float>(VT_PROJ_CLIP, 0.0f);
}
+ LSTMKernelType kernel_type() const {
+ return static_cast<LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+ VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
verifier.EndTable();
}
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2912,6 +2948,9 @@ struct LSTMOptionsBuilder {
void add_proj_clip(float proj_clip) {
fbb_.AddElement<float>(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
}
+ void add_kernel_type(LSTMKernelType kernel_type) {
+ fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
+ }
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2928,10 +2967,12 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
flatbuffers::FlatBufferBuilder &_fbb,
ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
float cell_clip = 0.0f,
- float proj_clip = 0.0f) {
+ float proj_clip = 0.0f,
+ LSTMKernelType kernel_type = LSTMKernelType_FULL) {
LSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip);
+ builder_.add_kernel_type(kernel_type);
builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish();
}
@@ -6226,6 +6267,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
{ auto _e = cell_clip(); _o->cell_clip = _e; };
{ auto _e = proj_clip(); _o->proj_clip = _e; };
+ { auto _e = kernel_type(); _o->kernel_type = _e; };
}
inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -6239,11 +6281,13 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
auto _fused_activation_function = _o->fused_activation_function;
auto _cell_clip = _o->cell_clip;
auto _proj_clip = _o->proj_clip;
+ auto _kernel_type = _o->kernel_type;
return tflite::CreateLSTMOptions(
_fbb,
_fused_activation_function,
_cell_clip,
- _proj_clip);
+ _proj_clip,
+ _kernel_type);
}
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
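
Because VT_KERNEL_TYPE is read with a default of 0 (FULL), buffers written by older converters, which lack the field, deserialize as full-kernel LSTMs, keeping version-1 models valid. A minimal call against the generated builder API above (illustrative only):

  flatbuffers::FlatBufferBuilder fbb;
  auto options = tflite::CreateLSTMOptions(
      fbb, tflite::ActivationFunctionType_TANH,
      /*cell_clip=*/0.0f, /*proj_clip=*/0.0f,
      tflite::LSTMKernelType_BASIC);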
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 74fc32a12b..80e4c5a4dd 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -155,6 +155,7 @@ cc_library(
deps = [
":split",
":test_runner",
+ "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index f07e36fc7d..9bb7a4600d 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -118,6 +118,8 @@ class ExtraTocoOptions(object):
self.allow_custom_ops = False
# Rnn states that are used to support rnn / lstm cells.
self.rnn_states = None
+ # Split the LSTM inputs from 5 tensors to 18 tensors for TFLite.
+ self.split_tflite_lstm_inputs = None
def toco_options(data_types,
@@ -155,6 +157,11 @@ def toco_options(data_types,
s += " --allow_custom_ops"
if extra_toco_options.rnn_states:
s += (" --rnn_states='" + extra_toco_options.rnn_states + "'")
+ if extra_toco_options.split_tflite_lstm_inputs is not None:
+ if extra_toco_options.split_tflite_lstm_inputs:
+ s += " --split_tflite_lstm_inputs=true"
+ else:
+ s += " --split_tflite_lstm_inputs=false"
return s
@@ -461,6 +468,11 @@ def make_zip_of_tests(zip_path,
sess,
tf.global_variables() + inputs +
outputs) if use_frozen_graph else sess.graph_def
+
+ if "split_tflite_lstm_inputs" in param_dict_real:
+ extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
+ "split_tflite_lstm_inputs"]
+
tflite_model_binary, toco_log = toco_convert(
graph_def.SerializeToString(), input_tensors, output_tensors,
extra_toco_options)
@@ -2019,6 +2031,7 @@ def make_lstm_tests(zip_path):
"time_step_size": [1],
"input_vec_size": [3],
"num_cells": [4],
+ "split_tflite_lstm_inputs": [True, False],
},
]
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 8cab6cd8cd..fc28faf524 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <iostream>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -290,12 +291,24 @@ void TfLiteDriver::ResetLSTMStateTensors() {
const auto& node_and_reg = interpreter_->node_and_registration(node_index);
const auto& node = node_and_reg->first;
const auto& registration = node_and_reg->second;
- if (registration.builtin_code == tflite::BuiltinOperator_LSTM &&
- node.outputs->size >= 2) {
- // The first 2 outputs of LSTM are state tensors.
- for (int i = 0; i < 2; ++i) {
- int node_index = node.outputs->data[i];
- ResetTensor(node_index);
+
+ if (registration.builtin_code == tflite::BuiltinOperator_LSTM) {
+ const auto* params =
+ reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
+ if (params->kernel_type == kTfLiteLSTMFullKernel &&
+ node.outputs->size >= 2) {
+ // The first 2 outputs of LSTM are state tensors.
+ for (int i = 0; i < 2; ++i) {
+ int node_index = node.outputs->data[i];
+ ResetTensor(node_index);
+ }
+ } else if (params->kernel_type == kTfLiteLSTMBasicKernel &&
+ node.inputs->size == 5) {
+ // The 2nd and 5th inputs are state tensors.
+ for (int i : {1, 4}) {
+ int node_index = node.inputs->data[i];
+ ResetTensor(node_index);
+ }
}
}
}
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 6c0311af0a..77bc54f191 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -242,6 +242,7 @@ struct ParsedTocoFlags {
Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
+ Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 3f768bfee1..5b6a984ee1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -33,9 +33,10 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
return false;
}
- // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs,
- // do not need to merge cell inputs.
- if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) {
+ // Already a compact LstmCell. Do not need to merge cell inputs.
+ const auto* src_lstm_op = static_cast<LstmCellOperator*>(src_op);
+ if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL ||
+ src_lstm_op->inputs.size() != kExtendedLstmInputCount) {
return false;
}
@@ -136,6 +137,7 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Emplace a new LSTM cell operator (uses the basic 5-input kernel).
auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+ lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_BASIC;
// Compact LstmCell's 5 inputs.
lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index 8e66323bd7..e6e3dfa1de 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -33,9 +33,10 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
return false;
}
- // Already an extended LstmCell with kExtendedLstmInputCount of inputs,
- // do not need to split cell inputs.
- if (curr_op->inputs.size() == kExtendedLstmInputCount) {
+ const auto* curr_lstm_op = static_cast<LstmCellOperator*>(curr_op);
+ // Already an extended LstmCell. Do not need to split cell inputs.
+ if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC ||
+ curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) {
return false;
}
@@ -56,6 +57,7 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Emplace a new LstmCell operator with extended inputs (kernels/lstm.cc).
auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+ lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_FULL;
lstm_cell_op->inputs.resize(kExtendedLstmInputCount);
int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT])
.shape()
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 9062c03c73..1a4f87e363 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -527,7 +527,15 @@ struct LstmCellOperator : Operator {
ACTIV_TEMP = 3,
NUM_OUTPUTS = 4
};
- LstmCellOperator() : Operator(OperatorType::kLstmCell) {}
+ enum KernelType {
+ KERNEL_BASIC = 0,
+ KERNEL_FULL = 1,
+ };
+
+ LstmCellOperator()
+ : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {}
+
+ KernelType kernel_type;
};
// Element-wise multiplication operator.
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 84a5410839..a8518adefc 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -626,11 +626,21 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
+ ::tflite::LSTMKernelType kernel_type;
+ switch (op.kernel_type) {
+ case LstmCellOperator::KERNEL_BASIC:
+ kernel_type = ::tflite::LSTMKernelType_BASIC;
+ break;
+ case LstmCellOperator::KERNEL_FULL:
+ kernel_type = ::tflite::LSTMKernelType_FULL;
+ break;
+ }
+
// Current toco converter only supports tanh, no clip.
return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
::tflite::ActivationFunctionType_TANH,
/*cell_clip=*/0.0,
- /*proj_clip=*/0.0);
+ /*proj_clip=*/0.0, kernel_type);
}
void ReadOptions(const TfLiteOptions& options,
@@ -638,9 +648,26 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
// Only support tanh activation, so check that tflite type is tanh.
CHECK(options.fused_activation_function() ==
::tflite::ActivationFunctionType_TANH);
+
+ switch (options.kernel_type()) {
+ case ::tflite::LSTMKernelType_BASIC:
+ op->kernel_type = LstmCellOperator::KERNEL_BASIC;
+ break;
+ case ::tflite::LSTMKernelType_FULL:
+ op->kernel_type = LstmCellOperator::KERNEL_FULL;
+ break;
+ }
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+ switch (lstm_op.kernel_type) {
+ case LstmCellOperator::KERNEL_FULL:
+ return 1;
+ case LstmCellOperator::KERNEL_BASIC:
+ return 2;
+ }
+ }
};
class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
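
GetVersion() pairs with the registration range in register.cc: the converter stamps each serialized LSTM op with the version returned here, and the runtime resolves the op only if that version falls inside the registered [min_version, max_version] window, widened to [1, 2] by this commit. A hedged restatement (not actual TFLite internals):

  bool Resolvable(int written_version) {
    // FULL kernels are written as version 1, BASIC kernels as version 2.
    return written_version >= 1 && written_version <= 2;
  }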
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index 7786a4ada3..9c6ad673ab 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -153,6 +153,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.dedupe_array_min_size_bytes.default_value(),
"Minimum size of constant arrays to deduplicate; arrays smaller "
"will not be deduplicated."),
+ Flag("split_tflite_lstm_inputs",
+ parsed_flags.split_tflite_lstm_inputs.bind(),
+ parsed_flags.split_tflite_lstm_inputs.default_value(),
+ "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
+ "Ignored if the output format is not TFLite."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -245,6 +250,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
FlagRequirement::kNone);
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
+ READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
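
An illustrative invocation (other flags elided; the flag defaults to true, so it only needs to be passed explicitly to keep the compact 5-input form):

  toco ... --output_format=TFLITE --split_tflite_lstm_inputs=false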
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 8589ca361d..15f755c104 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 19.
+// Next ID to use: 20.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -165,4 +165,8 @@ message TocoFlags {
// Minimum size of constant arrays to deduplicate; arrays smaller will not be
// deduplicated.
optional int64 dedupe_array_min_size_bytes = 18 [default = 64];
+
+ // Split the LSTM inputs from 5 tensors to 18 tensors for TFLite.
+ // Ignored if the output format is not TFLite.
+ optional bool split_tflite_lstm_inputs = 19 [default = true];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index b5531ca2f4..a648883d1f 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -263,7 +263,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
transformations.Add(new IdentifyLstmCell);
}
- if (output_format == TFLITE) {
+ if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
transformations.Add(new toco::SplitLstmCellInputs);
} else {
transformations.Add(new toco::MergeLstmCellInputs);
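
A hedged restatement of the transformation choice above:

  // output_format == TFLITE && split_tflite_lstm_inputs (the default)
  //   -> SplitLstmCellInputs: expand to the 18-input full kernel (op version 1).
  // otherwise
  //   -> MergeLstmCellInputs: keep the compact 5-input basic kernel
  //      (written as op version 2 when the output is TFLite).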