aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-05 10:48:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 10:53:05 -0700
commit5032036e1f2a7060848aed64bce94a1f882142d5 (patch)
treec4e6753dd5a1d2bbf8b4b0c3e8a6ca5c3901983c /tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
parent7fa693209fe238478739b3982f652a7e35be91f3 (diff)
Introduce auxiliary input and allow "cross-linking" in the bidirectional LSTM Op.
This introduces a connection between forward and backward cells across subsequent layers when stacking bidirectional LSTM Ops on top of each other. In more detail: Previously, the Op had only one input that was fed into the layer in the following way: INPUT (INPUT_REVERSED) | | ----------------------- | FW_LSTM BW_LSTM | <----- bidi-LSTM cell (with one input / two outputs) ----------------------- | | FW_OUT BW_OUT Now, the Op can have an (optional) auxiliary input in the following way: AUX_INPUT (AUX_INPUT_REVERSED) | | INPUT | (INPUT_R'D.)| | | | | ------------------------- | \ / \ / | | FW_LSTM BW_LSTM | <----- bidi-LSMT cell (with 2 inputs / 2 outputs) ------------------------- | | FW_OUT BW_OUT When stacking these Ops, previously, only the following flow was allowed: Input / \ FW_LSTM1 BW_LSTM1 | | | | FW_LSTM2 BW_LSTM2 | | | | FW_LSTM3 BW_LSTM3 \ / Output With the introduction of an auxiliary input to the bidi-LSTM layer, the forward (FW_LSTMi) output of the ith layer is fed into as the input to the next layer (hence, inputs to both FW_LSTM{i+1} and BW_LSTM{i+1}) and the backward output is fed as the auxiliary inputs to both FW_LSTM{i+1} and BW_LSTM{i+1}). This way, the stacking can be changed to allow for the "cross-linking" between subsequent layer in the following way: Input / \ FW_LSTM1 BW_LSTM1 | \ / | | / \ | FW_LSTM2 BW_LSTM2 | \ / | | / \ | FW_LSTM3 BW_LSTM3 \ / Output PiperOrigin-RevId: 211659472
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc70
1 files changed, 70 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index d058fab529..74ba8021c2 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -177,6 +177,16 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_output_ = AddOutput(TensorType_FLOAT32);
+ aux_input_ = AddNullInput();
+ fw_aux_input_to_input_weights_ = AddNullInput();
+ fw_aux_input_to_forget_weights_ = AddNullInput();
+ fw_aux_input_to_cell_weights_ = AddNullInput();
+ fw_aux_input_to_output_weights_ = AddNullInput();
+ bw_aux_input_to_input_weights_ = AddNullInput();
+ bw_aux_input_to_forget_weights_ = AddNullInput();
+ bw_aux_input_to_cell_weights_ = AddNullInput();
+ bw_aux_input_to_output_weights_ = AddNullInput();
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
@@ -340,6 +350,16 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int fw_output_;
int bw_output_;
+ int aux_input_;
+ int fw_aux_input_to_input_weights_;
+ int fw_aux_input_to_forget_weights_;
+ int fw_aux_input_to_cell_weights_;
+ int fw_aux_input_to_output_weights_;
+ int bw_aux_input_to_input_weights_;
+ int bw_aux_input_to_forget_weights_;
+ int bw_aux_input_to_cell_weights_;
+ int bw_aux_input_to_output_weights_;
+
int n_batch_;
int n_input_;
int n_fw_cell_;
@@ -415,6 +435,16 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -562,6 +592,16 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -709,6 +749,16 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -848,6 +898,16 @@ TEST(LSTMOpTest,
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -987,6 +1047,16 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_batch, n_output}, // activation_state tensor
{n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights(