aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/basic_rnn.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-14 14:49:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-14 14:56:06 -0700
commit899ac1ca6cda8aeb894b7baf1821138c9302c479 (patch)
tree0a6fabb7a8d86c29e6fc0e62f82e297f25e8bffe /tensorflow/contrib/lite/kernels/basic_rnn.cc
parentc328db1698ca1c4029219b7bf85274ff4b7c66c8 (diff)
Internal change.
PiperOrigin-RevId: 196570742
Diffstat (limited to 'tensorflow/contrib/lite/kernels/basic_rnn.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index d812cd7bf0..0907547f9f 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -63,6 +63,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]);
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]);
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@@ -194,13 +196,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ // We already checked that weight types are consistent, so branch on one.
switch (input_weights->type) {
case kTfLiteFloat32:
return EvalFloat(input, input_weights, recurrent_weights, bias, params,
hidden_state, output);
case kTfLiteUInt8: {
// TODO(mirkov): implement eval with quantized inputs as well.
- TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
return EvalQuantized(input, input_weights, recurrent_weights, bias,