aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-30 18:05:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-30 18:08:13 -0700
commit85d30bfcf412bd1ca06fa33548344bf40eedb4ac (patch)
tree5201c3fe5d1da4e0dd4379df588fd5216377b11a
parent40721422bfc9cec546537799e16dd75f443d2db2 (diff)
Internal change.
PiperOrigin-RevId: 194877173
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc70
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc49
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/optional_tensor_test.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc47
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc4
-rw-r--r--tensorflow/contrib/lite/models/speech_test.cc16
-rw-r--r--tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc9
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h8
12 files changed, 139 insertions, 107 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a64ac42bc4..3ac0210f36 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -96,15 +96,23 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
constexpr int kBwProjectionBiasTensor = 34; // Optional
// Output tensors.
-constexpr int kFwScratchBufferTensor = 0;
-constexpr int kFwOutputStateTensor = 1;
-constexpr int kFwCellStateTensor = 2;
-constexpr int kFwOutputTensor = 3;
+constexpr int kFwOutputStateTensor = 0;
+constexpr int kFwCellStateTensor = 1;
+constexpr int kFwOutputTensor = 2;
+
+constexpr int kBwOutputStateTensor = 3;
+constexpr int kBwCellStateTensor = 4;
+constexpr int kBwOutputTensor = 5;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 2, scratch_tensor_index);
+ return scratch_tensor_index;
+}
-constexpr int kBwScratchBufferTensor = 4;
-constexpr int kBwOutputStateTensor = 5;
-constexpr int kBwCellStateTensor = 6;
-constexpr int kBwOutputTensor = 7;
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckLstmTensorDimensions(
@@ -296,9 +304,11 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Resize the output, state and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 8);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -330,12 +340,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* fw_output_state =
GetOutput(context, node, kFwOutputStateTensor);
TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
@@ -349,13 +355,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
fw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
fw_cell_size->data[0] = n_batch;
fw_cell_size->data[1] = n_fw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ &context->tensors[node->temporaries->data[0]];
+ fw_scratch_buffer->type = input->type;
+ fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -392,17 +406,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
TfLiteTensor* bw_output_state =
GetOutput(context, node, kBwOutputStateTensor);
TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = max_time;
bw_output_size->data[1] = n_batch;
@@ -416,13 +426,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
bw_output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
bw_cell_size->data[0] = n_batch;
bw_cell_size->data[1] = n_bw_cell;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ // Create a scratch buffer tensor.
+ node->temporaries->data[1] = *(scratch_tensor_index) + 1;
+ TfLiteTensor* bw_scratch_buffer =
+ &context->tensors[node->temporaries->data[1]];
+ bw_scratch_buffer->type = input->type;
+ bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -553,7 +569,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* fw_scratch_buffer =
- GetOutput(context, node, kFwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[0]];
float* fw_input_gate_scratch = nullptr;
float* fw_cell_scratch = nullptr;
float* fw_forget_gate_scratch = nullptr;
@@ -624,7 +640,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* bw_scratch_buffer =
- GetOutput(context, node, kBwScratchBufferTensor);
+ &context->tensors[node->temporaries->data[1]];
float* bw_input_gate_scratch = nullptr;
float* bw_cell_scratch = nullptr;
float* bw_forget_gate_scratch = nullptr;
@@ -691,9 +707,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace bidirectional_sequence_lstm
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- bidirectional_sequence_lstm::Prepare,
- bidirectional_sequence_lstm::Eval};
+ static TfLiteRegistration r = {
+ bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
+ bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index cca857bac0..a18e1bce34 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -102,9 +102,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_projection_bias_ = AddNullInput();
}
- fw_scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
fw_output_state_ = AddOutput(TensorType_FLOAT32);
fw_cell_state_ = AddOutput(TensorType_FLOAT32);
fw_output_ = AddOutput(TensorType_FLOAT32);
@@ -164,9 +161,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_projection_bias_ = AddNullInput();
}
- bw_scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
bw_output_state_ = AddOutput(TensorType_FLOAT32);
bw_cell_state_ = AddOutput(TensorType_FLOAT32);
bw_output_ = AddOutput(TensorType_FLOAT32);
@@ -349,12 +343,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int fw_output_;
int fw_output_state_;
int fw_cell_state_;
- int fw_scratch_buffer_;
int bw_output_;
int bw_output_state_;
int bw_cell_state_;
- int bw_scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 8cf1165135..668226e674 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -66,10 +66,19 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
constexpr int kProjectionBiasTensor = 17; // Optional
// Output tensors.
-constexpr int kScratchBufferTensor = 0;
-constexpr int kOutputStateTensor = 1;
-constexpr int kCellStateTensor = 2;
-constexpr int kOutputTensor = 3;
+constexpr int kOutputStateTensor = 0;
+constexpr int kCellStateTensor = 1;
+constexpr int kOutputTensor = 2;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
@@ -220,12 +229,15 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
-// tensors. Also check that the size of the input tensors match each other.
+// Resize the output, state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
@@ -250,15 +262,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
@@ -271,13 +280,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state, output_state_size));
- // Resize the output, state and scratch buffer tensors.
TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
cell_size->data[0] = n_batch;
cell_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
output_state->allocation_type = kTfLiteArenaRwPersistent;
cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -362,7 +378,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_peephole = (cell_to_output_weights != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -433,8 +450,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace lstm
TfLiteRegistration* Register_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- lstm::Prepare, lstm::Eval};
+ static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
+ lstm::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index c068286b0d..d81220d8d3 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -97,9 +97,6 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -233,7 +230,6 @@ class LSTMOpModel : public SingleOpModel {
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index cee3ec6197..bcad58406a 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -95,9 +95,6 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -235,7 +232,6 @@ class LSTMOpModel : public SingleOpModel {
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 42941a97db..3c1256d3a6 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -66,10 +66,19 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
constexpr int kProjectionBiasTensor = 17; // Optional
// Output tensors.
-constexpr int kScratchBufferTensor = 0;
-constexpr int kOutputStateTensor = 1;
-constexpr int kCellStateTensor = 2;
-constexpr int kOutputTensor = 3;
+constexpr int kOutputStateTensor = 0;
+constexpr int kCellStateTensor = 1;
+constexpr int kOutputTensor = 2;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, 1, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
@@ -220,12 +229,15 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
-// tensors. Also check that the size of the input tensors match each other.
+// Resize the output and state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -251,15 +263,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, state and scratch buffer tensors.
+ // Get the pointer to output, output_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // TODO(ghodrat): Modify this as soon as we have a finalized method for
- // scratch buffers.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
- // Resize the output and output_state tensors.
+ // Resize the output, output_state and cell_state tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = max_time;
output_size->data[1] = n_batch;
@@ -273,13 +282,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, output_state, output_state_size));
- // Resize the scratch buffer tensor.
TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
cell_size->data[0] = n_batch;
cell_size->data[1] = n_cell;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
+ // Create a scratch buffer tensor.
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[0] = *scratch_tensor_index;
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
// Mark state tensors as persistent tensors.
output_state->allocation_type = kTfLiteArenaRwPersistent;
cell_state->allocation_type = kTfLiteArenaRwPersistent;
@@ -365,7 +381,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_peephole = (cell_to_output_weights != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+ TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]];
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -439,7 +455,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace unidirectional_sequence_lstm
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
+ unidirectional_sequence_lstm::Free,
unidirectional_sequence_lstm::Prepare,
unidirectional_sequence_lstm::Eval};
return &r;
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index 93b635ae57..5881ced7c7 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -100,9 +100,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- scratch_buffer_ = AddOutput(TensorType_FLOAT32);
- // TODO(ghodrat): Modify these states when we have a permanent solution for
- // persistent buffer.
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -238,7 +235,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int output_;
int output_state_;
int cell_state_;
- int scratch_buffer_;
int n_batch_;
int n_input_;
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
index a354179a94..206de1962d 100644
--- a/tensorflow/contrib/lite/models/speech_test.cc
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -131,8 +131,8 @@ TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
ASSERT_TRUE(ConvertCsvData(
"speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
"speech_speakerid_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"66",
- /*persistent_tensors=*/"19,20,40,41,61,62",
+ /*output_tensor=*/"63",
+ /*persistent_tensors=*/"18,19,38,39,58,59",
/*sequence_size=*/80, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
@@ -144,8 +144,8 @@ TEST_P(SpeechTest, AsrAmTest) {
ASSERT_TRUE(
ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
"speech_asr_am_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"109",
- /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104",
+ /*output_tensor=*/"104",
+ /*persistent_tensors=*/"18,19,38,39,58,59,78,79,98,99",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
@@ -170,8 +170,8 @@ TEST_P(SpeechTest, EndpointerTest) {
ASSERT_TRUE(ConvertCsvData(
"speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
"speech_endpointer_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"58",
- /*persistent_tensors=*/"28,29,49,50",
+ /*output_tensor=*/"56",
+ /*persistent_tensors=*/"27,28,47,48",
/*sequence_size=*/320, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
@@ -183,8 +183,8 @@ TEST_P(SpeechTest, TtsTest) {
ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
"speech_tts_model_in.csv",
"speech_tts_model_out.csv", /*input_tensor=*/"0",
- /*output_tensor=*/"74",
- /*persistent_tensors=*/"25,26,46,47,67,68,73",
+ /*output_tensor=*/"71",
+ /*persistent_tensors=*/"24,25,44,45,64,65,70",
/*sequence_size=*/334, &os));
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations()))
diff --git a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
index 5812de4b30..f7f518b75f 100644
--- a/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
+++ b/tensorflow/contrib/lite/models/testdata/speech_asr_lm_model.test_spec
@@ -1,5 +1,5 @@
load_model: "speech_asr_lm_model.tflite"
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 3
input: "63982"
@@ -18,7 +18,7 @@ invoke {
input: "63981"
output: "-0.314846"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 6
input: "63982"
@@ -31,7 +31,7 @@ invoke {
input: "3082"
output: "-3.63721"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 8
input: "63982"
@@ -44,7 +44,7 @@ invoke {
input: "18965"
output: "-6.93985"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 13
input: "63982"
@@ -63,7 +63,7 @@ invoke {
input: "63981"
output: "-3.82091"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 19
input: "63982"
@@ -88,7 +88,7 @@ invoke {
input: "63981"
output: "-0.677399"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 26
input: "63982"
@@ -113,7 +113,7 @@ invoke {
input: "63981"
output: "0.415889"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 30
input: "63982"
@@ -131,7 +131,7 @@ invoke {
input: "51923"
output: "-14.1147"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 34
input: "63982"
@@ -144,7 +144,7 @@ invoke {
input: "16318"
output: "-1.54815"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 36
input: "63982"
@@ -157,7 +157,7 @@ invoke {
input: "28303"
output: "-14.0947"
}
-init_state: "21,22,42,43,63,64"
+init_state: "20,21,40,41,60,61"
invoke {
id: 38
input: "63982"
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 45335fd78c..3f768bfee1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -146,16 +146,19 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input;
lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input;
- // Reorder LstmCell's 4 outputs.
+ // Reorder LstmCell's 3 outputs.
lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
src_op->outputs[kOutputTensor];
lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
src_op->outputs[kCellStateTensor];
- lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] =
- src_op->outputs[kScratchBufferTensor];
lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] =
src_op->outputs[kOutputStateTensor];
+ // Create a new temp array for the fourth output.
+ const string& concat_temp_array_name =
+ AvailableArrayName(*model, base_name + "concat_temp");
+ model->GetOrCreateArray(concat_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
// Add the op into model.
model->operators.emplace(op_it, std::move(lstm_cell_op));
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index eca717680a..8e66323bd7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -138,10 +138,9 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionBiasTensor]),
base_name + "proj_bias");
- // Reorder LstmCell's outputs.
- lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
- lstm_cell_op->outputs[kScratchBufferTensor] =
- curr_op->outputs[LstmCellOperator::CONCAT_TEMP];
+ // Reorder and resize LstmCell's outputs.
+ lstm_cell_op->outputs.resize(
+ ExtendedLstmCellOutputs::kExtendedLstmOutputCount);
lstm_cell_op->outputs[kOutputStateTensor] =
curr_op->outputs[LstmCellOperator::ACTIV_TEMP];
lstm_cell_op->outputs[kCellStateTensor] =
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
index 4a9974ed4e..1c32a78169 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
@@ -51,10 +51,10 @@ enum ExtendedLstmCellInputs {
};
enum ExtendedLstmCellOutputs {
- kScratchBufferTensor = 0,
- kOutputStateTensor = 1,
- kCellStateTensor = 2,
- kOutputTensor = 3
+ kOutputStateTensor = 0,
+ kCellStateTensor = 1,
+ kOutputTensor = 2,
+ kExtendedLstmOutputCount = 3
};
// Create optional array used for optional tensor in ExtendedLstmCell inputs.