diff options
author | Alan Chiao <alanchiao@google.com> | 2018-07-12 12:16:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 12:25:21 -0700 |
commit | e8a65666c6aadbbbd2b19b9322d841b1547dbd35 (patch) | |
tree | eadeba71d8ad2d4c0c147d59f563d3246a7f6a65 /tensorflow/contrib | |
parent | 746a51b76742574d81783a4efe437e1824073d88 (diff) |
LSTM CHECK_OK on input tensor checks.
PiperOrigin-RevId: 204341675
Diffstat (limited to 'tensorflow/contrib')
4 files changed, 40 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 3425288f02..14a19aeef3 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -276,27 +276,33 @@ TfLiteStatus CheckLstmTensorDimensions( TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - CheckLstmTensorDimensions( - context, node, n_input, n_output, n_cell, kFwInputToInputWeightsTensor, - kFwInputToForgetWeightsTensor, kFwInputToCellWeightsTensor, - kFwInputToOutputWeightsTensor, kFwRecurrentToInputWeightsTensor, - kFwRecurrentToForgetWeightsTensor, kFwRecurrentToCellWeightsTensor, - kFwRecurrentToOutputWeightsTensor, kFwCellToInputWeightsTensor, - kFwCellToForgetWeightsTensor, kFwCellToOutputWeightsTensor, - kFwInputGateBiasTensor, kFwForgetGateBiasTensor, kFwCellGateBiasTensor, - kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, - kFwProjectionBiasTensor); - - CheckLstmTensorDimensions( - context, node, n_input, n_output, n_cell, kBwInputToInputWeightsTensor, - kBwInputToForgetWeightsTensor, kBwInputToCellWeightsTensor, - kBwInputToOutputWeightsTensor, kBwRecurrentToInputWeightsTensor, - kBwRecurrentToForgetWeightsTensor, kBwRecurrentToCellWeightsTensor, - kBwRecurrentToOutputWeightsTensor, kBwCellToInputWeightsTensor, - kBwCellToForgetWeightsTensor, kBwCellToOutputWeightsTensor, - kBwInputGateBiasTensor, kBwForgetGateBiasTensor, kBwCellGateBiasTensor, - kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, - kBwProjectionBiasTensor); + TF_LITE_ENSURE_OK( + context, + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, + kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor, + kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor, + kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor, + kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor, + kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor, + kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor, + kFwForgetGateBiasTensor, kFwCellGateBiasTensor, + kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, + kFwProjectionBiasTensor)); + + TF_LITE_ENSURE_OK( + context, + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, + kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor, + kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor, + kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor, + kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor, + kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor, + kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor, + kBwForgetGateBiasTensor, kBwCellGateBiasTensor, + kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, + kBwProjectionBiasTensor)); // Check if Forward and Backward tensors match along required dimensions. return kTfLiteOk; @@ -334,7 +340,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell); + TF_LITE_ENSURE_OK( + context, CheckInputTensorDimensions(context, node, n_input, n_fw_output, + n_fw_cell)); // Get the pointer to output, state and scratch buffer tensors. TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); @@ -404,7 +412,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell); + TF_LITE_ENSURE_OK( + context, CheckInputTensorDimensions(context, node, n_input, n_bw_output, + n_bw_cell)); // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 3577ae6caa..4dfc891548 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -306,7 +306,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); // Get the pointer to output, activation_state and cell_state tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 0b7c56133e..0266f5fe57 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op. +// +// TODO(alanchiao): add unit test with invalid input dimensions for this and its +// variants. #include <memory> #include <vector> diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 32daf2bb02..c48b470f92 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -274,7 +274,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); |