aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/svdf.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-22 12:37:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 12:41:08 -0700
commit8fea38677b26743be0a26cf62f3b45f2814792df (patch)
treee4c8a3f8880d8f0c6525cb636b6b3cb92515f313 /tensorflow/contrib/lite/kernels/svdf.cc
parent6039d18b87d316c14a7908dbcefd438cea223b5c (diff)
Disable Non-Variable Tensor API in SVDF kernel.
TFLite SVDF now supports 5 inputs (with one variable tensor representing the state). PiperOrigin-RevId: 209811478
Diffstat (limited to 'tensorflow/contrib/lite/kernels/svdf.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc57
1 files changed, 12 insertions, 45 deletions
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 9e8ed3cbf3..6ba7959752 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -105,16 +105,11 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
-
-// * If the node has 5 inputs the following tensor is used as state tensor.
-// This is defined to be a variable tensor, and will be modified by this op.
+// This is a variable tensor, and will be modified by this op.
constexpr int kInputActivationStateTensor = 4;
-// Output tensors.
-// * If node has 4 inputs, kStateTensor will be used as state tensor.
-// * If node has 5 inputs, kStateTensor is ignored.
-constexpr int kStateTensor = 0;
-constexpr int kOutputTensor = 1;
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
@@ -134,21 +129,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
- bool use_input_variable_states;
- if (node->inputs->size == 5) {
- use_input_variable_states = true;
- op_data->activation_state_tensor_index =
- node->inputs->data[kInputActivationStateTensor];
- } else if (node->inputs->size == 4) {
- use_input_variable_states = false;
- op_data->activation_state_tensor_index = node->outputs->data[kStateTensor];
- } else {
- context->ReportError(context,
- "The SVDF kernel expects 4 or 5 inputs. Got %d inputs",
- node->inputs->size);
- return kTfLiteError;
- }
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
@@ -178,28 +162,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
&context->tensors[op_data->activation_state_tensor_index];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (use_input_variable_states) {
- // Check the shape of input state tensors.
- TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0),
- batch_size);
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1),
- memory_size * num_filters);
- } else {
- // Resize activation_state.
- // For each batch, the state is a 2-D tensor: memory_size * num_filters
- // The left most column is used to save current cycle activation.
- // The right most column is used to save temporary output which will be
- // reduced to num_units outputs.
- TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
- state_size_array->data[0] = batch_size;
- state_size_array->data[1] = memory_size * num_filters;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
- state_size_array));
-
- // Mark state as a persistent tensor.
- activation_state->allocation_type = kTfLiteArenaRwPersistent;
- }
+ // Check the shape of input state tensors.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), batch_size);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1),
+ memory_size * num_filters);
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);