diff options
author | 2018-08-21 19:41:46 -0700 | |
---|---|---|
committer | 2018-08-21 19:45:00 -0700 | |
commit | d7f704424858c30a23c12c39d0841db326b8d148 (patch) | |
tree | 052d8ef2ccbca29946cf2b940845ac3828cd8707 /tensorflow/contrib/lite/kernels/svdf.cc | |
parent | 496023e9dc84a076caeb2e5e8e13b6a3d819ad6d (diff) |
Support Variable Tensor API in SVDF kernel.
TFLite SVDF now supports 5 inputs (with variable tensor) and 4 inputs.
PiperOrigin-RevId: 209702845
Diffstat (limited to 'tensorflow/contrib/lite/kernels/svdf.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/svdf.cc | 107 |
1 files changed, 73 insertions, 34 deletions
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 6d4912ce3a..9e8ed3cbf3 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -40,19 +40,22 @@ namespace { struct OpData { int scratch_tensor_index; bool float_weights_time_initialized; + + int activation_state_tensor_index; }; static inline void ApplyTimeWeightsBiasAndActivation( int batch_size, int memory_size, int num_filters, int num_units, int rank, const TfLiteTensor* weights_time, const TfLiteTensor* bias, - TfLiteFusedActivation activation, TfLiteTensor* state, + TfLiteFusedActivation activation, TfLiteTensor* activation_state, TfLiteTensor* scratch, TfLiteTensor* output) { // Compute matmul(state, weights_time). // The right most column is used to save temporary output (with the size of - // num_filters). This is achieved by starting at state->data.f and having the - // stride equal to memory_size. + // num_filters). This is achieved by starting at activation_state->data.f, + // and having the stride equal to memory_size. for (int b = 0; b < batch_size; ++b) { - float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* state_ptr_batch = + activation_state->data.f + b * memory_size * num_filters; float* scratch_ptr_batch = scratch->data.f + b * num_filters; tensor_utils::BatchVectorBatchVectorDotProduct( weights_time->data.f, state_ptr_batch, memory_size, num_filters, @@ -82,13 +85,14 @@ static inline void ApplyTimeWeightsBiasAndActivation( activation, output_ptr_batch); } - // Left shift the state to make room for next cycle's activation. + // Left shift the activation_state to make room for next cycle's activation. // TODO(alanchiao): explore collapsing this into a single loop. for (int b = 0; b < batch_size; ++b) { - float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* state_ptr_batch = + activation_state->data.f + b * memory_size * num_filters; for (int f = 0; f < num_filters; ++f) { tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, - /*shift_value=*/0.0); + /*shift_value=*/0.0f); state_ptr_batch += memory_size; } } @@ -96,10 +100,19 @@ static inline void ApplyTimeWeightsBiasAndActivation( } // namespace +// Input tensors. constexpr int kInputTensor = 0; constexpr int kWeightsFeatureTensor = 1; constexpr int kWeightsTimeTensor = 2; constexpr int kBiasTensor = 3; + +// * If the node has 5 inputs the following tensor is used as state tensor. +// This is defined to be a variable tensor, and will be modified by this op. +constexpr int kInputActivationStateTensor = 4; + +// Output tensors. +// * If node has 4 inputs, kStateTensor will be used as state tensor. +// * If node has 5 inputs, kStateTensor is ignored. constexpr int kStateTensor = 0; constexpr int kOutputTensor = 1; @@ -121,8 +134,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + bool use_input_variable_states; + if (node->inputs->size == 5) { + use_input_variable_states = true; + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + } else if (node->inputs->size == 4) { + use_input_variable_states = false; + op_data->activation_state_tensor_index = node->outputs->data[kStateTensor]; + } else { + context->ReportError(context, + "The SVDF kernel expects 4 or 5 inputs. Got %d inputs", + node->inputs->size); + return kTfLiteError; + } const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = @@ -148,22 +174,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); } - TfLiteTensor* state = GetOutput(context, node, kStateTensor); + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - // Resize state. - // For each batch, the state is a 2-D tensor: memory_size * num_filters - // The left most column is used to save current cycle activation. - // The right most column is used to save temporary output which will be - // reduced to num_units outputs. - TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); - state_size_array->data[0] = batch_size; - state_size_array->data[1] = memory_size * num_filters; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, state, state_size_array)); - - // Mark state as a persistent tensor. - state->allocation_type = kTfLiteArenaRwPersistent; + if (use_input_variable_states) { + // Check the shape of input state tensors. + TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), + batch_size); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1), + memory_size * num_filters); + } else { + // Resize activation_state. + // For each batch, the state is a 2-D tensor: memory_size * num_filters + // The left most column is used to save current cycle activation. + // The right most column is used to save temporary output which will be + // reduced to num_units outputs. + TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); + state_size_array->data[0] = batch_size; + state_size_array->data[1] = memory_size * num_filters; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, + state_size_array)); + + // Mark state as a persistent tensor. + activation_state->allocation_type = kTfLiteArenaRwPersistent; + } // Resize output. TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); @@ -220,8 +256,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scaling_factors_size)); } - // Used to store dequantized weights_time matrix for hybrid computation - // of matmul(state, weights_time), which occurs in floating point. + // Used to store dequantized weights_time matrix for hybrid computation of + // matmul(activation_state, weights_time), which occurs in floating point. node->temporaries->data[3] = scratch_tensor_index + 3; TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3); float_weights_time->type = kTfLiteFloat32; @@ -253,13 +289,13 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, const int memory_size = weights_time->dims->data[1]; // Clear the activation (state left most column). - // TODO(ghodrat): Add a test which initialize state with invalid values in - // left most column and make sure it passes. + // TODO(ghodrat): Add a test which initialize activation_state with invalid + // values in left most column and make sure it passes. for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; for (int c = 0; c < num_filters; ++c) { float* state_ptr = state_ptr_batch + c * memory_size; - state_ptr[memory_size - 1] = 0.0; + state_ptr[memory_size - 1] = 0.0f; } } @@ -307,7 +343,7 @@ TfLiteStatus EvalHybrid( // Clear the activation (state left most column). // TODO(ghodrat): Add a test which initialize state with invalid values in - // left most column and make sure it passes. + // the left most column and make sure it passes. for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; for (int c = 0; c < num_filters; ++c) { @@ -329,9 +365,10 @@ TfLiteStatus EvalHybrid( } // Compute conv1d(inputs, weights_feature). - // The state right most column is used to save current cycle activation. - // This is achieved by starting at state->data.f[memory_size - 1] and having - // the stride equal to memory_size. + // The rightmost column of state is used to save the current cycle + // activation. + // This is achieved by starting at state->data.f[memory_size - 1] + // and having the stride equal to memory_size. tensor_utils::MatrixBatchVectorMultiplyAccumulate( weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch, scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1], @@ -359,13 +396,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); - TfLiteTensor* state = GetOutput(context, node, kStateTensor); + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); switch (weights_feature->type) { case kTfLiteFloat32: { return EvalFloat(context, node, input, weights_feature, weights_time, - bias, params, scratch, state, output); + bias, params, scratch, activation_state, output); break; } case kTfLiteUInt8: { @@ -392,7 +430,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } return EvalHybrid(context, node, input, weights_feature, float_weights_time, bias, params, scratch, - scaling_factors, input_quantized, state, output); + scaling_factors, input_quantized, activation_state, + output); break; } default: |