aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-03 21:04:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-03 21:07:59 -0700
commitc1385b0ddf823a6eabee865cc90e5d6147691add (patch)
tree3fdf196c77ff28b85ed27a78d83c223a39d150f6 /tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
parent0bdb4dd6b67d5bfa91f688e57cb293a617cb4b45 (diff)
Introduce auxiliary input and allow "cross-linking" in the bidirectional RNN Op.
This introduces a connection between forward and backward cells across subsequent layers when stacking bidirectional RNN Ops on top of each other. In more detail: Previously, the Op had only one input that was fed into the layer in the following way: INPUT (INPUT_REVERSED) | | --------------------- | FW_RNN BW_RNN | <----- bidi-RNN cell (with one input / two outpus) --------------------- | | FW_OUT BW_OUT Now, the Op can have an (optional) auxiliary input in the following way: AUX_INPUT (AUX_INPUT_REVERSED) | | INPUT | (INPUT_R'D.)| | | | | ----------------------- | \ / \ / | | FW_RNN BW_RNN | <----- bidi-RNN cell (with 2 inputs / 2 outpus) ----------------------- | | FW_OUT BW_OUT When stacking these Ops, previously, only the following flow was allowed: Input / \ FW_RNN1 BW_RNN1 | | | | FW_RNN2 BW RNN2 | | | | FW_RNN3 BW_RNN3 \ / Output With the introduction of an auxiliary input to the bidi-RNN layer, the forward (FW_RNNi) output of the ith layer is fed into as the input to the next layer (hence, inputs to both FW_RNN{i+1} and BW_RNN{i+1}) and the backward output is fed as the auxiliary inputs to both FW_RNN{i+1} and BW_RNN{i+1}). This way, the stacking can be changed to allow for the "cross-linking" between subsequent layer in the following way: Input / \ FW_RNN1 BW_RNN1 | \ / | | / \ | FW_RNN2 BW RNN2 | \ / | | / \ | FW_RNN3 BW_RNN3 \ / Output PiperOrigin-RevId: 211401475
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc16
1 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 03236dbcdc..3e34ba6196 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -665,12 +665,18 @@ class BidirectionalRNNOpModel : public SingleOpModel {
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
fw_bias_ = AddInput(TensorType_FLOAT32);
fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
- fw_output_ = AddOutput(TensorType_FLOAT32);
bw_weights_ = AddInput(TensorType_FLOAT32);
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
bw_bias_ = AddInput(TensorType_FLOAT32);
bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
+
+ aux_input_ = AddNullInput();
+ aux_fw_weights_ = AddNullInput();
+ aux_bw_weights_ = AddNullInput();
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
bw_output_ = AddOutput(TensorType_FLOAT32);
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
CreateSequenceRNNOptions(builder_, /*time_major=*/false,
@@ -685,7 +691,10 @@ class BidirectionalRNNOpModel : public SingleOpModel {
{bw_units_, input_size_}, // bw_weights
{bw_units_, bw_units_}, // bw_recurrent_weights
{bw_units_}, // bw_bias
- {batches_, bw_units_} // bw_hidden_state
+ {batches_, bw_units_}, // bw_hidden_state
+ {batches_, sequence_len_, 0}, // aux_input
+ {fw_units_, 0}, // aux_fw_weights
+ {bw_units_, 0}, // aux_bw_weights
});
}
@@ -742,6 +751,9 @@ class BidirectionalRNNOpModel : public SingleOpModel {
int bw_bias_;
int bw_hidden_state_;
int bw_output_;
+ int aux_input_;
+ int aux_fw_weights_;
+ int aux_bw_weights_;
int batches_;
int sequence_len_;