aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-06 21:00:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-06 21:06:15 -0700
commit7fa6a6b42bc9d562e2b1cc765ca78d281b51f734 (patch)
tree19398616e50223b0876e6e9a987cc9908d404592 /tensorflow/contrib/lite/kernels
parente93a18954689b6d522560f5273f6d3320d545b2e (diff)
Add SequenceLSTMOptions to schema to decouple the sequential Op from the LSTM.
PiperOrigin-RevId: 216066634
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc11
2 files changed, 17 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index ec9cf38b83..89d57e4599 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -431,7 +431,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params =
+ reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
@@ -482,6 +484,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ // Copy out the LSTM specific params so they can be passed in the function.
+ TfLiteLSTMParams lstm_params;
+ lstm_params.activation = params->activation;
+ lstm_params.cell_clip = params->cell_clip;
+ lstm_params.proj_clip = params->proj_clip;
+
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
return lstm_eval::EvalFloat(
@@ -496,7 +504,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
- projection_bias, params, /*forward_sequence=*/true,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
output);
}
@@ -523,7 +531,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
- projection_bias, params, /*forward_sequence=*/true,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index cd3aac0532..c97b0fdd61 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -110,11 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
- .Union());
+ SetBuiltinOp(
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+ CreateUnidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
+ .Union());
BuildInterpreter(input_shapes);
}