Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc  174
1 file changed, 92 insertions(+), 82 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 38ad32c734..420bc68b43 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -55,83 +55,33 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
const int postamble_start =
m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1));
- // The arrays used to cache the vector.
- void* aligned_vector_cache_free = nullptr;
- float32x4_t* vector_cache_float32x4 =
- reinterpret_cast<float32x4_t*>(aligned_alloc(
- sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
- &aligned_vector_cache_free));
-
- const int kUnrollSize = 2;
for (int b = 0; b < n_batch; b++) {
float* result_in_batch = result + b * m_rows * result_stride;
const float* vector_in_batch = vector + b * m_cols;
+ const float* matrix_row = matrix;
- const float* matrix_ptr0 = matrix;
- // If there is only 1 row, we don't want to assign an illegal pointer.
- const float* matrix_ptr1 = nullptr;
- if (m_rows > 1) {
- matrix_ptr1 = matrix + m_cols;
- }
-
- // Cache the vector.
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
- }
-
- // Main matrix by vector multiplication loop, which handles two rows of
- // matrix by vector multiplication.
- for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
- float32x4_t acc1_32x4 = vmovq_n_f32(0.0);
+ // Main matrix-by-vector multiplication loop, one row at a time.
+ for (int r = 0; r < m_rows; r++) {
+ float32x4_t acc_32x4 = vmovq_n_f32(0.0);
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
- acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp);
+ // Load 4 float values from vector and matrix row.
+ float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c);
+ float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c);
+ // Multiply the vector and matrix row and add to accumulator.
+ acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
}
// Add the 4 intermediate sum values to get the final dot-product value for
// this row.
*result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
- *(result_in_batch + result_stride) +=
- (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) +
- vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3));
+ (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
+ vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
- *(result_in_batch + result_stride) +=
- matrix_ptr1[c] * vector_in_batch[c];
+ *result_in_batch += matrix_row[c] * vector_in_batch[c];
}
- matrix_ptr0 += kUnrollSize * m_cols;
- matrix_ptr1 += kUnrollSize * m_cols;
- result_in_batch += kUnrollSize * result_stride;
- }
- for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
- }
- // Add the 4 intermediate sum values to get the final dot-prod value for
- // this column.
- *result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
- for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
- }
- matrix_ptr0 += m_cols;
+ matrix_row += m_cols;
result_in_batch += result_stride;
}
}
- free(aligned_vector_cache_free);
}
void NeonMatrixBatchVectorMultiplyAccumulate(
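The rewritten hunk above drops the two-row unrolling and the aligned vector cache in favor of a single accumulator per matrix row. The final reduction sums the four NEON lanes by hand; a minimal sketch of that step (HorizontalSum is a hypothetical helper name, and on AArch64 the same reduction is a single vaddvq_f32 instruction):

    #include <arm_neon.h>

    // Sum the four partial dot products held in the accumulator's lanes.
    static inline float HorizontalSum(float32x4_t acc) {
      return vgetq_lane_f32(acc, 0) + vgetq_lane_f32(acc, 1) +
             vgetq_lane_f32(acc, 2) + vgetq_lane_f32(acc, 3);
    }
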
@@ -162,7 +112,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch) {
- const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ const float batch_scaling_factor = scaling_factors[batch];
// Copy the vector data to an aligned vector.
memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols);
// Compute dot-product for every column.
@@ -232,7 +182,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int32 neon_sum =
vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
- *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv);
+ *result += ((neon_sum + postable_sum) * batch_scaling_factor);
} // for row
} // for batch
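Note the semantic flip in this hunk: scaling_factors[batch] is now applied by a direct multiply rather than being inverted first. Together with the NeonSymmetricQuantizeFloats change below, the stored factor becomes the dequantization step size itself. A scalar sketch of one row's accumulation under that assumption (matrix_row stands in for the per-row pointer, which is not shown in this hunk; aligned_vec and batch_scaling_factor are from the surrounding function):

    int32 dot = 0;
    for (int col = 0; col < m_cols; ++col) {
      // int8 * int8 products fit comfortably in an int32 accumulator.
      dot += matrix_row[col] * aligned_vec[col];
    }
    // One multiply per row instead of a divide per batch.
    *result += dot * batch_scaling_factor;
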
@@ -296,17 +246,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
const int postamble_start =
v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
- // The arrays used to cache the vector.
- void* aligned_vector_cache_free = nullptr;
- float32x4_t* vector_cache_float32x4 =
- reinterpret_cast<float32x4_t*>(aligned_alloc(
- sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
- &aligned_vector_cache_free));
-
- for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
- vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v);
- }
-
float* result_ptr = result;
const float* batch_vector_ptr = batch_vector;
for (int b = 0; b < n_batch; b++) {
@@ -314,9 +253,9 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
// Load from memory to vectors.
float32x4_t result_f32x4 = vld1q_f32(result_ptr + v);
float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v);
+ float32x4_t vector_f32x4 = vld1q_f32(vector + v);
// Multiply-accumulate.
- result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4,
- vector_cache_float32x4[v >> 2]);
+ result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, vector_f32x4);
// Store.
vst1q_f32(result_ptr + v, result_f32x4);
}
@@ -328,7 +267,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
result_ptr += v_size;
batch_vector_ptr += v_size;
}
- free(aligned_vector_cache_free);
}
void NeonSub1Vector(const float* vector, int v_size, float* result) {
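As in the first hunk, the aligned cache of vector is removed and its values are reloaded with vld1q_f32 on every pass. For reference, the scalar equivalent of one batch iteration of NeonVectorBatchVectorCwiseProductAccumulate is simply:

    // Element-wise product of vector and the batch slice, accumulated
    // into result; pointer names follow the function above.
    for (int v = 0; v < v_size; ++v) {
      result_ptr[v] += vector[v] * batch_vector_ptr[v];
    }
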
@@ -404,6 +342,77 @@ void NeonClipVector(const float* vector, int v_size, float abs_limit,
}
}
+void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
+ const float scale, float* result) {
+ // Here the assumption is that each buffer is 4-byte aligned.
+ const int kWeightsPerUint32 = 4;
+ TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kWeightsPerUint32 - 1), 0);
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int kWeightsPerNeonLane = 16;
+ const int postamble_start = v_size - (v_size & (kWeightsPerNeonLane - 1));
+
+ // Create a vector of 4 floats with the scale value.
+ const float32x4_t scale_f32x4 = vdupq_n_f32(scale);
+ int v = 0;
+ for (; v < postamble_start; v += kWeightsPerNeonLane) {
+ // Load int8 values, sixteen at a time.
+ const int8x16_t v_i8x16 = vld1q_s8(vector + v);
+ // Split it into two components of size eight.
+ const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16);
+ const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16);
+ // Convert both components to int16 first.
+ const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8);
+ const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8);
+ // Split each of them into two components each.
+ const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8);
+ const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8);
+ const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8);
+ const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8);
+ // Convert these to int32 and then to float.
+ float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
+ float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
+ float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4));
+ float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4));
+ // Vector multiply four floats at a time.
+ v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
+ v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
+ v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4);
+ v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4);
+ // Store the results.
+ vst1q_f32(result + v, v0_f32x4);
+ vst1q_f32(result + v + 4, v1_f32x4);
+ vst1q_f32(result + v + 8, v2_f32x4);
+ vst1q_f32(result + v + 12, v3_f32x4);
+ }
+
+ if (v_size - postamble_start >= (kWeightsPerNeonLane >> 1)) {
+ // Load eight int8 values, if there are at least eight remaining.
+ const int8x8_t v_i8x8 = vld1_s8(vector + v);
+ // Convert them to int16 first.
+ const int16x8_t v_i16x8 = vmovl_s8(v_i8x8);
+ // Split it into two components.
+ const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8);
+ const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8);
+ // Convert the components to floats.
+ float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4));
+ float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4));
+ // Vector multiply four floats at a time.
+ v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4);
+ v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4);
+ // Store the results.
+ vst1q_f32(result + v, v0_f32x4);
+ vst1q_f32(result + v + 4, v1_f32x4);
+ v += (kWeightsPerNeonLane >> 1);
+ }
+
+ // Postamble loop.
+ for (; v < v_size; v++) {
+ result[v] = scale * vector[v];
+ }
+}
+
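A hedged usage sketch of the new routine (sizes and the scale value are illustrative; the input must pass the 4-byte alignment check at the top of the function):

    #include <cstdint>

    alignas(4) int8_t weights[20] = {};  // quantized values filled elsewhere
    float dequantized[20];
    // 16 of the 20 values go through the 16-wide NEON loop; the remaining
    // 4 are fewer than 8, so they fall through to the scalar postamble.
    NeonVectorScalarMultiply(weights, 20, 0.5f, dequantized);
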
void NeonSymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min,
float* max, float* scaling_factor) {
@@ -418,13 +427,14 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
*scaling_factor = 1;
return;
}
- *scaling_factor = kScale / range;
+ *scaling_factor = range / kScale;
+ const float scaling_factor_inv = 1.0f / *scaling_factor;
const int postamble_start =
size - (size & (2 * kFloatWeightsPerNeonLane - 1));
// Vectorized constants.
- const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor);
+ const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
@@ -476,7 +486,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
for (int i = postamble_start; i < size; ++i) {
const int32 quantized_value =
- static_cast<int32>(TfLiteRound(*scaling_factor * values[i]));
+ static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
}
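
With this hunk, *scaling_factor stores the step size range / kScale rather than its inverse, so callers dequantize by multiplying. A worked example with illustrative numbers:

    const int kScale = 127;
    const float range = 2.54f;                    // assumed max(|min|, |max|)
    const float scaling_factor = range / kScale;  // 0.02f, the stored value
    const float inv = 1.0f / scaling_factor;      // 50.0f, used to quantize
    // A value of 1.0f quantizes to round(1.0f * 50.0f) == 50 and
    // dequantizes back as 50 * 0.02f == 1.0f.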