author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-21 19:41:46 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-21 19:45:00 -0700
commit d7f704424858c30a23c12c39d0841db326b8d148 (patch)
tree   052d8ef2ccbca29946cf2b940845ac3828cd8707 /tensorflow/contrib/lite/kernels/svdf.cc
parent 496023e9dc84a076caeb2e5e8e13b6a3d819ad6d (diff)
Support Variable Tensor API in SVDF kernel.
TFLite SVDF now supports both the 5-input form (with a variable state tensor) and the legacy 4-input form.

PiperOrigin-RevId: 209702845
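In outline, the kernel now resolves which tensor holds the activation state from the node's arity. A condensed sketch of the new dispatch in Prepare(), using the kernel's own constants (shown in full in the diff below):

    if (node->inputs->size == 5) {
      // Variable-tensor form: input 4 holds the activation state and is
      // modified in place by the op.
      op_data->activation_state_tensor_index =
          node->inputs->data[kInputActivationStateTensor];
    } else if (node->inputs->size == 4) {
      // Legacy form: output 0 doubles as the state tensor.
      op_data->activation_state_tensor_index =
          node->outputs->data[kStateTensor];
    } else {
      return kTfLiteError;  // SVDF expects exactly 4 or 5 inputs.
    }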
Diffstat (limited to 'tensorflow/contrib/lite/kernels/svdf.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/svdf.cc | 107
1 file changed, 73 insertions(+), 34 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 6d4912ce3a..9e8ed3cbf3 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -40,19 +40,22 @@ namespace {
struct OpData {
int scratch_tensor_index;
bool float_weights_time_initialized;
+
+ int activation_state_tensor_index;
};
static inline void ApplyTimeWeightsBiasAndActivation(
int batch_size, int memory_size, int num_filters, int num_units, int rank,
const TfLiteTensor* weights_time, const TfLiteTensor* bias,
- TfLiteFusedActivation activation, TfLiteTensor* state,
+ TfLiteFusedActivation activation, TfLiteTensor* activation_state,
TfLiteTensor* scratch, TfLiteTensor* output) {
// Compute matmul(state, weights_time).
// The right most column is used to save temporary output (with the size of
- // num_filters). This is achieved by starting at state->data.f and having the
- // stride equal to memory_size.
+ // num_filters). This is achieved by starting at activation_state->data.f,
+ // and having the stride equal to memory_size.
for (int b = 0; b < batch_size; ++b) {
- float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ float* state_ptr_batch =
+ activation_state->data.f + b * memory_size * num_filters;
float* scratch_ptr_batch = scratch->data.f + b * num_filters;
tensor_utils::BatchVectorBatchVectorDotProduct(
weights_time->data.f, state_ptr_batch, memory_size, num_filters,
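A plain-loop equivalent of the dot product above, for one batch (an illustrative sketch; the kernel uses the optimized tensor_utils routine, whose trailing result arguments are elided by this hunk and are assumed to write one result per filter into scratch_ptr_batch):

    // State holds num_filters rows of memory_size entries; each row is
    // dotted with the matching row of weights_time.
    inline void DotProductSketch(const float* weights_time, const float* state,
                                 int memory_size, int num_filters,
                                 float* result) {
      for (int f = 0; f < num_filters; ++f) {
        float dot = 0.0f;
        for (int i = 0; i < memory_size; ++i) {
          dot += weights_time[f * memory_size + i] * state[f * memory_size + i];
        }
        result[f] = dot;  // one result per filter (assumed)
      }
    }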
@@ -82,13 +85,14 @@ static inline void ApplyTimeWeightsBiasAndActivation(
activation, output_ptr_batch);
}
- // Left shift the state to make room for next cycle's activation.
+ // Left shift the activation_state to make room for next cycle's activation.
// TODO(alanchiao): explore collapsing this into a single loop.
for (int b = 0; b < batch_size; ++b) {
- float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ float* state_ptr_batch =
+ activation_state->data.f + b * memory_size * num_filters;
for (int f = 0; f < num_filters; ++f) {
tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
- /*shift_value=*/0.0);
+ /*shift_value=*/0.0f);
state_ptr_batch += memory_size;
}
}
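What VectorShiftLeft does to each filter's window, as a plain loop (a sketch; the shipped routine may be vectorized): the oldest activation is dropped and shift_value=0.0f fills the freed slot at the end for the next cycle.

    inline void ShiftLeftSketch(float* v, int memory_size) {
      for (int i = 0; i < memory_size - 1; ++i) v[i] = v[i + 1];
      v[memory_size - 1] = 0.0f;  // shift_value
    }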
@@ -96,10 +100,19 @@ static inline void ApplyTimeWeightsBiasAndActivation(
} // namespace
+// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
+
+// * If the node has 5 inputs, the following tensor is used as the state
+// tensor. It is defined to be a variable tensor, and will be modified by
+// this op.
+// This is defined to be a variable tensor, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 4;
+
+// Output tensors.
+// * If the node has 4 inputs, kStateTensor is used as the state tensor.
+// * If the node has 5 inputs, kStateTensor is ignored.
constexpr int kStateTensor = 0;
constexpr int kOutputTensor = 1;
@@ -121,8 +134,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ bool use_input_variable_states;
+ if (node->inputs->size == 5) {
+ use_input_variable_states = true;
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ } else if (node->inputs->size == 4) {
+ use_input_variable_states = false;
+ op_data->activation_state_tensor_index = node->outputs->data[kStateTensor];
+ } else {
+ context->ReportError(context,
+ "The SVDF kernel expects 4 or 5 inputs. Got %d inputs",
+ node->inputs->size);
+ return kTfLiteError;
+ }
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
@@ -148,22 +174,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units);
}
- TfLiteTensor* state = GetOutput(context, node, kStateTensor);
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- // For each batch, the state is a 2-D tensor: memory_size * num_filters
- // The left most column is used to save current cycle activation.
- // The right most column is used to save temporary output which will be
- // reduced to num_units outputs.
- TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
- state_size_array->data[0] = batch_size;
- state_size_array->data[1] = memory_size * num_filters;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, state, state_size_array));
-
- // Mark state as a persistent tensor.
- state->allocation_type = kTfLiteArenaRwPersistent;
+ if (use_input_variable_states) {
+ // Check the shape of input state tensors.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0),
+ batch_size);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1),
+ memory_size * num_filters);
+ } else {
+ // Resize activation_state.
+ // For each batch, the state is a 2-D tensor: memory_size * num_filters
+ // The left most column is used to save current cycle activation.
+ // The right most column is used to save temporary output which will be
+ // reduced to num_units outputs.
+ TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
+ state_size_array->data[0] = batch_size;
+ state_size_array->data[1] = memory_size * num_filters;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
+ state_size_array));
+
+ // Mark state as a persistent tensor.
+ activation_state->allocation_type = kTfLiteArenaRwPersistent;
+ }
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
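Put concretely, the 2-D state tensor checked or allocated above flattens to one memory window per (batch, filter) pair. An indexing sketch matching the pointer arithmetic used throughout the kernel (the helper name is illustrative, not part of the kernel):

    // activation_state has shape [batch_size, memory_size * num_filters].
    inline float* FilterWindow(float* state_data, int b, int f,
                               int memory_size, int num_filters) {
      return state_data + (b * num_filters + f) * memory_size;
    }
    // FilterWindow(...)[0]               : oldest entry in the window
    // FilterWindow(...)[memory_size - 1] : current cycle's activation slot /
    //                                      temporary output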
@@ -220,8 +256,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scaling_factors_size));
}
- // Used to store dequantized weights_time matrix for hybrid computation
- // of matmul(state, weights_time), which occurs in floating point.
+ // Used to store dequantized weights_time matrix for hybrid computation of
+ // matmul(activation_state, weights_time), which occurs in floating point.
node->temporaries->data[3] = scratch_tensor_index + 3;
TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3);
float_weights_time->type = kTfLiteFloat32;
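A hedged sketch of the dequantization this temporary enables (the actual expansion happens in Eval, outside this hunk; symmetric quantization with a single scale, taken from weights_time->params.scale, is assumed):

    #include <cstdint>

    // Expand quantized weights_time into float once, so the recurrent
    // matmul(activation_state, weights_time) can run in floating point.
    inline void DequantizeWeightsSketch(const int8_t* quantized,
                                        int num_elements, float scale,
                                        float* float_out) {
      for (int i = 0; i < num_elements; ++i) {
        float_out[i] = quantized[i] * scale;
      }
    }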
@@ -253,13 +289,13 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
const int memory_size = weights_time->dims->data[1];
// Clear the activation (state left most column).
- // TODO(ghodrat): Add a test which initialize state with invalid values in
- // left most column and make sure it passes.
+ // TODO(ghodrat): Add a test which initializes activation_state with invalid
+ // values in the left most column and make sure it passes.
for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
for (int c = 0; c < num_filters; ++c) {
float* state_ptr = state_ptr_batch + c * memory_size;
- state_ptr[memory_size - 1] = 0.0;
+ state_ptr[memory_size - 1] = 0.0f;
}
}
@@ -307,7 +343,7 @@ TfLiteStatus EvalHybrid(
// Clear the activation (state left most column).
// TODO(ghodrat): Add a test which initialize state with invalid values in
- // left most column and make sure it passes.
+ // the left most column and make sure it passes.
for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
for (int c = 0; c < num_filters; ++c) {
@@ -329,9 +365,10 @@ TfLiteStatus EvalHybrid(
}
// Compute conv1d(inputs, weights_feature).
- // The state right most column is used to save current cycle activation.
- // This is achieved by starting at state->data.f[memory_size - 1] and having
- // the stride equal to memory_size.
+ // The rightmost column of state is used to save the current cycle
+ // activation.
+ // This is achieved by starting at state->data.f[memory_size - 1]
+ // and having the stride equal to memory_size.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch,
scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1],
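Destination addressing for the strided write above, as a sketch (the call's trailing result_stride argument is cut off by this hunk; a stride of memory_size is assumed, matching the comment): filter f of batch b accumulates its scaled result into the rightmost slot of its own memory window rather than into a contiguous buffer.

    inline float* Conv1dResultSlot(float* state_data, int b, int f,
                                   int memory_size, int num_filters) {
      return state_data + (b * num_filters + f) * memory_size +
             (memory_size - 1);
    }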
@@ -359,13 +396,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* state = GetOutput(context, node, kStateTensor);
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (weights_feature->type) {
case kTfLiteFloat32: {
return EvalFloat(context, node, input, weights_feature, weights_time,
- bias, params, scratch, state, output);
+ bias, params, scratch, activation_state, output);
break;
}
case kTfLiteUInt8: {
@@ -392,7 +430,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
return EvalHybrid(context, node, input, weights_feature,
float_weights_time, bias, params, scratch,
- scaling_factors, input_quantized, state, output);
+ scaling_factors, input_quantized, activation_state,
+ output);
break;
}
default: