aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-05 12:52:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 12:58:18 -0700
commit3427a3c638fb92a172d390266ed62403f9140f7d (patch)
treeb36ed65be45e98a34de408894b7a8cedd4ce5919
parent3f54f1f60413cbd3e9a5a4126f8ae04bc4e06abc (diff)
Internal change.
PiperOrigin-RevId: 215951354
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD1
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_eval.cc3
2 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 68636fb070..d2d8073abd 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -259,6 +259,7 @@ cc_library(
srcs = ["lstm_eval.cc"],
hdrs = ["lstm_eval.h"],
deps = [
+ ":op_macros",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc
index c6c21eb085..20a4e30009 100644
--- a/tensorflow/contrib/lite/kernels/lstm_eval.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
@@ -599,6 +600,7 @@ TfLiteStatus EvalFloat(
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
const int n_batch = input->dims->data[input->dims->size - 2];
const int n_input = input->dims->data[input->dims->size - 1];
@@ -716,6 +718,7 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
TfLiteTensor* output_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
const int n_batch = input->dims->data[input->dims->size - 2];
const int n_input = input->dims->data[input->dims->size - 1];