diff options
author | 2018-08-22 12:37:57 -0700 | |
---|---|---|
committer | 2018-08-22 12:41:08 -0700 | |
commit | 8fea38677b26743be0a26cf62f3b45f2814792df (patch) | |
tree | e4c8a3f8880d8f0c6525cb636b6b3cb92515f313 /tensorflow/contrib/lite/kernels/svdf.cc | |
parent | 6039d18b87d316c14a7908dbcefd438cea223b5c (diff) |
Disable Non-Variable Tensor API in SVDF kernel.
TFLite SVDF now supports 5 inputs (with one variable tensor representing the state).
PiperOrigin-RevId: 209811478
Diffstat (limited to 'tensorflow/contrib/lite/kernels/svdf.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/svdf.cc | 57 |
1 files changed, 12 insertions, 45 deletions
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 9e8ed3cbf3..6ba7959752 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -105,16 +105,11 @@ constexpr int kInputTensor = 0; constexpr int kWeightsFeatureTensor = 1; constexpr int kWeightsTimeTensor = 2; constexpr int kBiasTensor = 3; - -// * If the node has 5 inputs the following tensor is used as state tensor. -// This is defined to be a variable tensor, and will be modified by this op. +// This is a variable tensor, and will be modified by this op. constexpr int kInputActivationStateTensor = 4; -// Output tensors. -// * If node has 4 inputs, kStateTensor will be used as state tensor. -// * If node has 5 inputs, kStateTensor is ignored. -constexpr int kStateTensor = 0; -constexpr int kOutputTensor = 1; +// Output tensor. +constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); @@ -134,21 +129,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); - bool use_input_variable_states; - if (node->inputs->size == 5) { - use_input_variable_states = true; - op_data->activation_state_tensor_index = - node->inputs->data[kInputActivationStateTensor]; - } else if (node->inputs->size == 4) { - use_input_variable_states = false; - op_data->activation_state_tensor_index = node->outputs->data[kStateTensor]; - } else { - context->ReportError(context, - "The SVDF kernel expects 4 or 5 inputs. Got %d inputs", - node->inputs->size); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = @@ -178,28 +162,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { &context->tensors[op_data->activation_state_tensor_index]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (use_input_variable_states) { - // Check the shape of input state tensors. - TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); - TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), - batch_size); - TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1), - memory_size * num_filters); - } else { - // Resize activation_state. - // For each batch, the state is a 2-D tensor: memory_size * num_filters - // The left most column is used to save current cycle activation. - // The right most column is used to save temporary output which will be - // reduced to num_units outputs. - TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); - state_size_array->data[0] = batch_size; - state_size_array->data[1] = memory_size * num_filters; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, - state_size_array)); - - // Mark state as a persistent tensor. - activation_state->allocation_type = kTfLiteArenaRwPersistent; - } + // Check the shape of input state tensors. + TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), batch_size); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1), + memory_size * num_filters); // Resize output. TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); |