diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-14 14:49:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-14 14:56:06 -0700 |
commit | 899ac1ca6cda8aeb894b7baf1821138c9302c479 (patch) | |
tree | 0a6fabb7a8d86c29e6fc0e62f82e297f25e8bffe /tensorflow/contrib/lite/kernels/basic_rnn.cc | |
parent | c328db1698ca1c4029219b7bf85274ff4b7c66c8 (diff) |
Internal change.
PiperOrigin-RevId: 196570742
Diffstat (limited to 'tensorflow/contrib/lite/kernels/basic_rnn.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/basic_rnn.cc | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index d812cd7bf0..0907547f9f 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -63,6 +63,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type); TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -194,13 +196,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // We already checked that weight types are consistent, so branch on one. switch (input_weights->type) { case kTfLiteFloat32: return EvalFloat(input, input_weights, recurrent_weights, bias, params, hidden_state, output); case kTfLiteUInt8: { // TODO(mirkov): implement eval with quantized inputs as well. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); return EvalQuantized(input, input_weights, recurrent_weights, bias, |