aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-27 14:24:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-27 15:32:39 -0700
commit16cda320d92cfbfc6870140691ae2c5e6286688c (patch)
tree87a60a261560dd840f6be4c4fec89d15c5532da5
parent71b993a63c9f4c62d45303623f926219066902cc (diff)
Arm32/64 kernel optimizations:
- QuantizeV2 - Dequantize - QuantizedBiasAdd - QuantizeDownAndShrinkRange - QuantizedRelu - QuantizedRelu6 - QuantizedMatMul - QuantizedConv The optimizations are controled by three knobs: meta::SetEnabled(bool) -- turns codepath on/off, on by default meta::SetUseLocalContext(bool) -- true -- codepath will use it's own internal fine grain workers pool that offers performance improvement over the standard tensorflow worker pool. This workers pool is not compatible with other ops. Per use-case performance testing recommended. -- false (default) -- use the standard tf worker pool instance meta::SetNumThreads(int) -- no. of compute threads when the internal worker pool is used. If 0 use intra_parallelism_count, if x > 0 then x threads. Change: 137448955
-rw-r--r--tensorflow/contrib/cmake/external/gemmlowp.cmake4
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt1
-rw-r--r--tensorflow/core/kernels/BUILD4
-rw-r--r--tensorflow/core/kernels/dequantize_op.cc15
-rw-r--r--tensorflow/core/kernels/meta_support.cc373
-rw-r--r--tensorflow/core/kernels/meta_support.h112
-rw-r--r--tensorflow/core/kernels/quantize_down_and_shrink_range.cc17
-rw-r--r--tensorflow/core/kernels/quantize_op.cc15
-rw-r--r--tensorflow/core/kernels/quantized_activation_ops.cc34
-rw-r--r--tensorflow/core/kernels/quantized_bias_add_op.cc25
-rw-r--r--tensorflow/core/kernels/quantized_conv_ops.cc27
-rw-r--r--tensorflow/core/kernels/quantized_matmul_op.cc27
-rw-r--r--tensorflow/workspace.bzl6
13 files changed, 615 insertions, 45 deletions
diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake
index 11868d44dd..024c064cf4 100644
--- a/tensorflow/contrib/cmake/external/gemmlowp.cmake
+++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake
@@ -1,7 +1,7 @@
include (ExternalProject)
-set(gemmlowp_URL http://github.com/google/gemmlowp/archive/c0bacf11fb509a2cbe15a97362a2df067ffd57a2.tar.gz)
-set(gemmlowp_HASH SHA256=dc64a38f9927db18748d9024987c9b102115e25bc2be4b76aa8e422b8f83d882)
+set(gemmlowp_URL http://github.com/google/gemmlowp/archive/a6f29d8ac48d63293f845f2253eccbf86bc28321.tar.gz)
+set(gemmlowp_HASH SHA256=75d40ea8e68b0d1644f052fffe8f14a410b2a73d40ccb859a95c0578d194ec26)
set(gemmlowp_BUILD ${CMAKE_BINARY_DIR}/gemmlowp/src/gemmlowp)
set(gemmlowp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/gemmlowp/src/gemmlowp)
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 2633a3a939..ed5d6539b3 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -143,6 +143,7 @@ tensorflow/core/kernels/avgpooling_op.cc
tensorflow/core/kernels/argmax_op.cc
tensorflow/core/kernels/aggregate_ops.cc
tensorflow/core/kernels/dequantize_op.cc
+tensorflow/core/kernels/meta_support.cc
tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
tensorflow/core/kernels/quantize_op.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1a9001f99b..478d1bc332 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2430,6 +2430,8 @@ filegroup(
name = "android_quantized_ops",
srcs = [
"dequantize_op.cc",
+ "meta_support.cc",
+ "meta_support.h",
"quantization_utils.cc",
"quantization_utils.h",
"quantize_down_and_shrink_range.cc",
@@ -2531,6 +2533,7 @@ tf_kernel_library(
name = "quantized_ops",
srcs = [
"dequantize_op.cc",
+ "meta_support.cc",
"quantization_utils.cc",
"quantize_down_and_shrink_range.cc",
"quantize_op.cc",
@@ -2547,6 +2550,7 @@ tf_kernel_library(
"reshape_op.h",
],
hdrs = [
+ "meta_support.h",
"quantization_utils.h",
"reference_gemm.h",
],
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 375287000e..c28909e03b 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -17,11 +17,12 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/meta_support.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
namespace {
@@ -75,9 +76,15 @@ class DequantizeOp : public OpKernel {
scale_factor) +
min_range;
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
- QuantizedTensorToFloatInPlaceUsingEigen<T>(
- ctx->template eigen_device<Device>(), input, min_range, max_range,
- output);
+ if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
+ auto input_ui8_array = input.flat<quint8>();
+ meta::Dequantize(ctx, input_ui8_array.data(), input_ui8_array.size(),
+ min_range, max_range, output->flat<float>().data());
+ } else {
+ QuantizedTensorToFloatInPlaceUsingEigen<T>(
+ ctx->template eigen_device<Device>(), input, min_range, max_range,
+ output);
+ }
}
}
diff --git a/tensorflow/core/kernels/meta_support.cc b/tensorflow/core/kernels/meta_support.cc
new file mode 100644
index 0000000000..4ef56d1987
--- /dev/null
+++ b/tensorflow/core/kernels/meta_support.cc
@@ -0,0 +1,373 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/meta_support.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#if (defined(GEMMLOWP_NEON_32) || defined(GEMMLOWP_NEON_64)) && \
+ !defined(TENSORFLOW_DISABLE_META) && !defined(__APPLE__)
+#define TENSORFLOW_USE_META (1)
+#endif
+
+namespace tensorflow {
+namespace meta {
+
+namespace {
+
+int g_num_threads = 0;
+bool g_enabled = true;
+bool g_use_local_context = false;
+
+#ifdef TENSORFLOW_USE_META
+
+uint8_t* GetScratch() {
+ static uint8_t* scratch = new uint8_t[2048 * 1024];
+ return scratch;
+}
+
+gemmlowp::WorkersPool* GetWorkersPool() {
+ static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
+ return pool;
+}
+
+mutex& GetMutex() {
+ static mutex mu;
+ return mu;
+}
+
+int GetWorkersCount(OpKernelContext* tf_context) {
+ if (g_num_threads == 0) {
+ return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
+ }
+ return g_num_threads;
+}
+
+typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;
+
+template <typename Context, typename Params>
+void MultiThreadGemm(Context* context, const Params& params) {
+ if (params.m <= 4) {
+ gemmlowp::meta::Gemm<gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>,
+ Params, 1, 8, 8>(params);
+ } else {
+ if (params.m >= params.n) {
+ gemmlowp::meta::MultiThreadGemm<
+ Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
+ 2, 4, 8>(context, params);
+ } else {
+ gemmlowp::meta::MultiThreadGemm<
+ Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
+ 2, 4, 8>(context, params);
+ }
+ }
+}
+
+template <typename LeftStream, typename RightStream>
+void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
+ const quint8* b_data, qint32* c_data, int m, int n,
+ int k, int offset_a, int offset_b, int lda, int ldb,
+ int ldc) {
+ typedef gemmlowp::meta::GemmParams<
+ uint8_t, int32_t, LeftStream, RightStream,
+ gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
+ gemmlowp::meta::RowMajor>
+ Params;
+ Params params;
+
+ params.m = m;
+ params.n = n;
+ params.k = k;
+
+ params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
+ params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
+ params.result = reinterpret_cast<int32_t*>(&(c_data->value));
+ params.scratch = GetScratch();
+
+ params.left_stream.count = k;
+ params.left_stream.stride = lda;
+ params.left_stream.multiplicative_sum_offset = offset_b;
+ params.left_stream.additive_sum_offset = k * offset_a * offset_b;
+
+ params.right_stream.count = k;
+ params.right_stream.stride = ldb;
+ params.right_stream.multiplicative_sum_offset = offset_a;
+ params.right_stream.additive_sum_offset = 0;
+
+ params.fused_kernel.kernel.count = k;
+ params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);
+
+ if (g_use_local_context) {
+ LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
+ MultiThreadGemm<LocalContext, Params>(&local_context, params);
+ } else {
+ auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
+ TensorflowGemmContext context(workers.num_threads, workers.workers);
+ MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
+ }
+}
+
+template <typename Params, int kernel_size>
+void MultiThreadTransform1D(OpKernelContext* tf_context, const Params& params) {
+ if (g_use_local_context) {
+ LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
+ gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
+ &local_context, params);
+ } else {
+ auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
+ TensorflowGemmContext context(workers.num_threads, workers.workers);
+ gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
+ kernel_size>(&context, params);
+ }
+}
+
+template <typename QuantizedType>
+double CalculateRangeScale(float min, float max) {
+ const int bits = sizeof(QuantizedType) * 8;
+ return static_cast<double>(max - min) /
+ ((static_cast<int64_t>(1) << bits) - 1);
+}
+
+template <typename QuantizedType>
+double CalculateOneOverRangeScale(float min, float max) {
+ if (min == max) {
+ return 0.0;
+ }
+ const int bits = sizeof(QuantizedType) * 8;
+ return static_cast<double>((static_cast<int64_t>(1) << bits) - 1) /
+ (max - min);
+}
+
+#endif // TENSORFLOW_USE_META
+
+} // namespace
+
+void SetNumThreads(int num_threads) { g_num_threads = num_threads; }
+
+int GetNumThreads() { return g_num_threads; }
+
+void SetUseLocalContext(bool use_local_context) {
+ g_use_local_context = use_local_context;
+}
+
+bool GetUseLocalContext() { return g_use_local_context; }
+
+bool IsSupported() {
+#if defined(TENSORFLOW_USE_META)
+ return true;
+#else
+ return false;
+#endif
+}
+
+bool IsEnabled() { return g_enabled; }
+
+void SetEnabled(bool enabled) { g_enabled = enabled; }
+
+bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }
+
+void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
+ bool transpose_b, const quint8* a_data, const quint8* b_data,
+ qint32* c_data, int m, int n, int k, int offset_a,
+ int offset_b, int lda, int ldb, int ldc) {
+#ifdef TENSORFLOW_USE_META
+ mutex_lock library_lock(GetMutex());
+ if (transpose_a) {
+ if (transpose_b) {
+ QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
+ gemmlowp::meta::RowMajorWithSum>(
+ tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
+ ldb, ldc);
+ } else {
+ QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
+ gemmlowp::meta::ColumnMajorWithSum>(
+ tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
+ ldb, ldc);
+ }
+ } else {
+ if (transpose_b) {
+ QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
+ gemmlowp::meta::RowMajorWithSum>(
+ tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
+ ldb, ldc);
+ } else {
+ QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
+ gemmlowp::meta::ColumnMajorWithSum>(
+ tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
+ ldb, ldc);
+ }
+ }
+#else
+ LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
+#endif
+}
+
+void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
+ float input_min, float input_max, float output_min,
+ float output_max, quint8* output) {
+#ifdef TENSORFLOW_USE_META
+ mutex_lock library_lock(GetMutex());
+ typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
+ gemmlowp::meta::Requantize>
+ Params;
+
+ Params params;
+ params.input = reinterpret_cast<const int32_t*>(input);
+ params.output = reinterpret_cast<uint8_t*>(output);
+ params.kernel.count = count;
+ params.kernel.input_range_min = input_min;
+ params.kernel.output_range_min = output_min;
+ params.kernel.input_range_scale =
+ CalculateRangeScale<int32_t>(input_min, input_max);
+ params.kernel.one_over_output_range_scale =
+ CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
+ params.kernel.input_range_offset =
+ static_cast<float>(std::numeric_limits<int32_t>::lowest());
+
+ // After adding the output_range_offset the value is cast from float to uint.
+ // The float to int/uint cast in NEON uses round toward 0. To keep the
+ // rounding consistent with Eigen, which uses round toward closest, we can
+ // add 0.5f and exploit the fact that we only operate on non negative values.
+ // TODO(maciekc): fix the actual kernel in gemmlowp/meta
+ params.kernel.output_range_offset =
+ static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;
+
+ MultiThreadTransform1D<Params, 16>(tf_context, params);
+#else
+ LOG(FATAL) << "Requantize: Meta fastpath not supported.";
+#endif
+}
+
+void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
+ float range_min, float range_max, float* output) {
+#ifdef TENSORFLOW_USE_META
+ mutex_lock library_lock(GetMutex());
+ typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
+ gemmlowp::meta::Dequantize>
+ Params;
+
+ Params params;
+ params.input = reinterpret_cast<const uint8_t*>(input);
+ params.output = reinterpret_cast<float*>(output);
+ params.kernel.count = count;
+ params.kernel.range_min = range_min;
+ params.kernel.range_scale =
+ CalculateRangeScale<uint8_t>(range_min, range_max);
+ params.kernel.range_offset =
+ static_cast<float>(std::numeric_limits<uint8_t>::lowest());
+
+ MultiThreadTransform1D<Params, 16>(tf_context, params);
+#else
+ LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
+#endif
+}
+
+void Quantize(OpKernelContext* tf_context, const float* input, int count,
+ float range_min, float range_max, quint8* output) {
+#ifdef TENSORFLOW_USE_META
+ mutex_lock library_lock(GetMutex());
+ typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
+ gemmlowp::meta::Quantize>
+ Params;
+
+ Params params;
+ params.input = reinterpret_cast<const float*>(input);
+ params.output = reinterpret_cast<uint8_t*>(output);
+ params.kernel.count = count;
+ params.kernel.range_min = range_min;
+ params.kernel.range_scale =
+ CalculateOneOverRangeScale<uint8_t>(range_min, range_max);
+
+ // After adding the range_offset the value is cast from float to uint.
+ // The float to int/uint cast in NEON uses round toward 0. To keep the
+ // rounding consistent with Eigen, which uses round toward closest, we can
+ // add 0.5f and exploit the fact that we only operate on non negative values.
+ // TODO(maciekc): fix the the actual kernel in gemmlowp/meta
+ params.kernel.range_offset =
+ static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;
+
+ MultiThreadTransform1D<Params, 16>(tf_context, params);
+#else
+ LOG(FATAL) << "Quantize: Meta fastpath not supported.";
+#endif
+}
+
+void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
+ int input_count, const quint8* bias, int bias_count,
+ float input_min, float input_max, float bias_min,
+ float bias_max, float output_min, float output_max,
+ qint32* output) {
+#ifdef TENSORFLOW_USE_META
+ mutex_lock library_lock(GetMutex());
+ typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
+ gemmlowp::meta::BiasAdd<uint8_t>>
+ Params;
+
+ Params params;
+ params.input = reinterpret_cast<const uint8_t*>(input);
+ params.output = reinterpret_cast<int32_t*>(output);
+ params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
+ params.kernel.count = bias_count;
+ params.kernel.rows = input_count / bias_count;
+ params.kernel.input_range_min = input_min;
+ params.kernel.bias_range_min = bias_min;
+ params.kernel.input_range_scale =
+ CalculateRangeScale<uint8_t>(input_min, input_max);
+ params.kernel.bias_range_scale =
+ CalculateRangeScale<uint8_t>(bias_min, bias_max);
+ params.kernel.input_range_offset = 0;
+ params.kernel.bias_range_offset = 0;
+ params.kernel.output_range_min = output_min;
+ params.kernel.one_over_output_range_scale =
+ CalculateOneOverRangeScale<int32_t>(output_min, output_max);
+ params.kernel.output_range_offset =
+ static_cast<float>(std::numeric_limits<int32_t>::lowest());
+
+ // TODO(maciekc): add multithreading to bias add.
+ // Right now this kernel does not support multi threaded execution.
+ gemmlowp::meta::Transform1D<Params, 16>(params);
+#else
+ LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
+#endif
+}
+
+void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
+ quint8 clamp_min, quint8 clamp_max, quint8* output) {
+#ifdef TENSORFLOW_USE_META
+ mutex_lock library_lock(GetMutex());
+ typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
+ gemmlowp::meta::MinMax<uint8_t>>
+ Params;
+
+ Params params;
+ params.input = reinterpret_cast<const uint8_t*>(input);
+ params.output = reinterpret_cast<uint8_t*>(output);
+ params.kernel.count = count;
+ params.kernel.min = clamp_min;
+ params.kernel.max = clamp_max;
+
+ MultiThreadTransform1D<Params, 16>(tf_context, params);
+#else
+ LOG(FATAL) << "Clamp: Meta fastpath not supported.";
+#endif
+}
+
+} // namespace meta
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/meta_support.h b/tensorflow/core/kernels/meta_support.h
new file mode 100644
index 0000000000..0d87baf034
--- /dev/null
+++ b/tensorflow/core/kernels/meta_support.h
@@ -0,0 +1,112 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
+
+#include "meta/multi_thread_gemm.h"
+#include "meta/multi_thread_transform.h"
+#include "meta/quantized_mul_kernels.h"
+#include "meta/streams.h"
+#include "meta/transform_kernels.h"
+
+#include "tensorflow/core/framework/numeric_types.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace meta {
+
+// Gemmlowp/meta is a small library of optimized Arm32/64 kernels for quantized
+// matrix multiplication and other quantized computations.
+
+// Set the maximum number of threads of computation that the internal workers
+// pool can use. If num_threads is 0, then use intra_op_parallelism_threads.
+void SetNumThreads(int num_threads);
+
+int GetNumThreads();
+
+// Toggle the internal workers pool. If set to false, the computations will
+// use the worker pool passed each time in the OpKernelContext. If set to true
+// then the OpKernelContext will be ignored, and the internal optimized workers
+// pool will be used.
+//
+// The internal workers pool is disabled by default (false).
+void SetUseLocalContext(bool use_local_context);
+
+bool GetUseLocalContext();
+
+// Toggles the codepath. Enabled by default (true) on supported platforms.
+void SetEnabled(bool enabled);
+
+// Returns true if the codepath is supported and is enabled. Use this call
+// before calling the compute functions. If the codepath is not supported, and
+// any of the compute function is called, the library will log a FATAL error.
+bool IsSupportedAndEnabled();
+
+// Calculate the quantized matrix multiplication:
+//
+// for (i, j) in [0, m) x [0, n) do
+// c_data[i, j] :=
+// sum((a_data[i, l] + offset_a) * (b_data[l, j] + offset_b)) : l in [0, k)
+//
+// If transpose_a is false the lhs operand has row major layout, otherwise
+// column major. Similarily transpose_b describes the layout of the rhs operand.
+// lda, ldb, and ldc are the strides of the lhs operand, rhs operand and the
+// result arrays.
+void QuantizedGemm(OpKernelContext* context, bool transpose_a, bool transpose_b,
+ const quint8* a_data, const quint8* b_data, qint32* c_data,
+ int m, int n, int k, int offset_a, int offset_b, int lda,
+ int ldb, int ldc);
+
+// Take an array of numbers from the range [input_min, input_max] quantized
+// uniformly to int32 values, recover their float values, and then quantize
+// them back uniformly to the range [output_min, output_max] as uint8.
+// Saturate the uint8 values.
+void Requantize(OpKernelContext* context, const qint32* input, int count,
+ float input_min, float input_max, float output_min,
+ float output_max, quint8* output);
+
+// Take an array of numbers from the range [range_min, range_max] quantized
+// uniformly to uint8 values and recover their float values.
+void Dequantize(OpKernelContext* context, const quint8* input, int count,
+ float range_min, float range_max, float* output);
+
+// Take an array of float values and quantize them uniformly to the range
+// [range_min, range_max] expressed as uint8. Saturate the uint8 values.
+void Quantize(OpKernelContext*, const float* input, int count, float range_min,
+ float range_max, quint8* output);
+
+// Take two arrays: the inputs and the bias quantized uniformly in the ranges
+// [input_min, input_max], and [bias_min, bias_max] accordingly, as uint8
+// values. Recover their float values. Add the values. Quantize them back
+// uniformly to the range [output_min, output_max] as int32. Saturate the
+// int32 values.
+void QuantizedBiasAdd(OpKernelContext* context, const quint8* input,
+ int input_count, const quint8* bias, int bias_count,
+ float input_min, float input_max, float bias_min,
+ float bias_max, float output_min, float output_max,
+ qint32* output);
+
+// Take an array of uint8 values and clamp them to the range [clamp_min,
+// clamp_max].
+void Clamp(OpKernelContext* context, const quint8* input, int input_count,
+ quint8 clamp_min, quint8 clamp_max, quint8* output);
+
+} // namespace meta
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_META_SUPPORT_H_
diff --git a/tensorflow/core/kernels/quantize_down_and_shrink_range.cc b/tensorflow/core/kernels/quantize_down_and_shrink_range.cc
index 5806d68944..9893a85587 100644
--- a/tensorflow/core/kernels/quantize_down_and_shrink_range.cc
+++ b/tensorflow/core/kernels/quantize_down_and_shrink_range.cc
@@ -20,11 +20,12 @@ limitations under the License.
#include <math.h>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/meta_support.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -79,9 +80,17 @@ class QuantizeDownAndShrinkRangeOp : public OpKernel {
#endif
if (input_array.size() > 0) {
- RequantizeManyInNewRangeUsingEigen<T1, T2>(
- ctx->eigen_device<CPUDevice>(), input, input_min_float,
- input_max_float, actual_min_float, actual_max_float, output);
+ if (meta::IsSupportedAndEnabled() && std::is_same<T1, qint32>() &&
+ std::is_same<T2, quint8>()) {
+ auto input_i32_array = input.flat<qint32>();
+ meta::Requantize(ctx, input_i32_array.data(), input_i32_array.size(),
+ input_min_float, input_max_float, actual_min_float,
+ actual_max_float, output->flat<quint8>().data());
+ } else {
+ RequantizeManyInNewRangeUsingEigen<T1, T2>(
+ ctx->eigen_device<CPUDevice>(), input, input_min_float,
+ input_max_float, actual_min_float, actual_max_float, output);
+ }
}
output_min->flat<float>().setConstant(actual_min_float);
diff --git a/tensorflow/core/kernels/quantize_op.cc b/tensorflow/core/kernels/quantize_op.cc
index 003654c1b0..b8f0dd8642 100644
--- a/tensorflow/core/kernels/quantize_op.cc
+++ b/tensorflow/core/kernels/quantize_op.cc
@@ -17,11 +17,12 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/meta_support.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
namespace {
@@ -124,9 +125,15 @@ class QuantizeV2Op : public OpKernel {
.template cast<T>();
}
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
- FloatTensorToQuantizedInPlaceUsingEigen<T>(
- ctx->template eigen_device<Device>(), input, min_range, max_range,
- output);
+ if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
+ auto input_array = input.flat<float>();
+ meta::Quantize(ctx, input_array.data(), input_array.size(), min_range,
+ max_range, output->flat<quint8>().data());
+ } else {
+ FloatTensorToQuantizedInPlaceUsingEigen<T>(
+ ctx->template eigen_device<Device>(), input, min_range, max_range,
+ output);
+ }
}
Tensor* output_min_tensor = nullptr;
diff --git a/tensorflow/core/kernels/quantized_activation_ops.cc b/tensorflow/core/kernels/quantized_activation_ops.cc
index ea1cf15f7b..2896c3d45a 100644
--- a/tensorflow/core/kernels/quantized_activation_ops.cc
+++ b/tensorflow/core/kernels/quantized_activation_ops.cc
@@ -16,10 +16,11 @@ limitations under the License.
// Implements a quantized version of the Relu6 operation.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/meta_support.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -37,8 +38,16 @@ class QuantizedReluOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
const T min_as_quantized = FloatToQuantized<T>(0.0f, min_input, max_input);
- output->flat<T>().device(context->eigen_cpu_device()) =
- input.flat<T>().cwiseMax(min_as_quantized).template cast<T>();
+
+ if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
+ auto input_ui8_array = input.flat<quint8>();
+ meta::Clamp(context, input_ui8_array.data(), input_ui8_array.size(),
+ min_as_quantized, 255, output->flat<quint8>().data());
+ } else {
+ output->flat<T>().device(context->eigen_cpu_device()) =
+ input.flat<T>().cwiseMax(min_as_quantized).template cast<T>();
+ }
+
Tensor* output_min = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
output_min->flat<float>()(0) = min_input;
@@ -63,11 +72,20 @@ class QuantizedRelu6Op : public OpKernel {
context->allocate_output(0, input.shape(), &output));
const T min_as_quantized = FloatToQuantized<T>(0.0f, min_input, max_input);
const T max_as_quantized = FloatToQuantized<T>(6.0f, min_input, max_input);
- output->flat<T>().device(context->eigen_cpu_device()) =
- input.flat<T>()
- .cwiseMax(min_as_quantized)
- .cwiseMin(max_as_quantized)
- .template cast<T>();
+
+ if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
+ auto input_ui8_array = input.flat<quint8>();
+ meta::Clamp(context, input_ui8_array.data(), input_ui8_array.size(),
+ min_as_quantized, max_as_quantized,
+ output->flat<quint8>().data());
+ } else {
+ output->flat<T>().device(context->eigen_cpu_device()) =
+ input.flat<T>()
+ .cwiseMax(min_as_quantized)
+ .cwiseMin(max_as_quantized)
+ .template cast<T>();
+ }
+
Tensor* output_min = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
output_min->flat<float>()(0) = min_input;
diff --git a/tensorflow/core/kernels/quantized_bias_add_op.cc b/tensorflow/core/kernels/quantized_bias_add_op.cc
index 0b34bfcad8..5457d290c2 100644
--- a/tensorflow/core/kernels/quantized_bias_add_op.cc
+++ b/tensorflow/core/kernels/quantized_bias_add_op.cc
@@ -15,11 +15,14 @@ limitations under the License.
// Implements a quantized eight-bit version of the bias addition operation.
-#include "tensorflow/core/kernels/quantization_utils.h"
+#define EIGEN_USE_THREADS
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -60,9 +63,23 @@ class QuantizedBiasAddOp : public OpKernel {
float total_min;
float total_max;
- QuantizedAddUsingEigen<T1, T2, T3>(
- context->template eigen_device<CPUDevice>(), input, input_min,
- input_max, bias, bias_min, bias_max, output, &total_min, &total_max);
+
+ if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
+ std::is_same<T2, quint8>() && std::is_same<T3, qint32>()) {
+ auto input_ui8_array = input.flat<quint8>();
+ auto bias_ui8_array = bias.flat<quint8>();
+ GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, bias_min,
+ bias_max, &total_min, &total_max);
+ meta::QuantizedBiasAdd(context, input_ui8_array.data(),
+ input_ui8_array.size(), bias_ui8_array.data(),
+ bias_ui8_array.size(), input_min, input_max,
+ bias_min, bias_max, total_min, total_max,
+ output->flat<qint32>().data());
+ } else {
+ QuantizedAddUsingEigen<T1, T2, T3>(
+ context->template eigen_device<CPUDevice>(), input, input_min,
+ input_max, bias, bias_min, bias_max, output, &total_min, &total_max);
+ }
Tensor* output_min = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index fb69d770c0..2405c55c5b 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -18,12 +18,15 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#define EIGEN_USE_THREADS
+
#include "public/gemmlowp.h"
-#include "tensorflow/core/kernels/quantization_utils.h"
-#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
@@ -338,12 +341,20 @@ class Im2ColConvFunctor {
const int lda = filter_value_count;
const int ldb = filter_count;
const int ldc = filter_count;
- // The gemmlowp optimized library only works for a particular set of data
- // types, so check if we meet those requirements and
- // fall back to a slower reference implementation if not.
- if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
- std::is_same<T3, qint32>() && (output_offset == 0) &&
- (output_mult == 1) && (output_shift == 0)) {
+
+ if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
+ std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
+ (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
+ (transpose_c == false)) {
+ meta::QuantizedGemm(op_context, transpose_a, transpose_b,
+ im2col_buffer.get(), filter_data, output_data, m, n,
+ k, -input_offset, -filter_offset, lda, ldb, ldc);
+ } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
+ std::is_same<T3, qint32>() && (output_offset == 0) &&
+ (output_mult == 1) && (output_shift == 0)) {
+ // The gemmlowp optimized library only works for a particular set of data
+ // types, so check if we meet those requirements and
+ // fall back to a slower reference implementation if not.
const uint8* im2col_data_as_uint8 = &(im2col_buffer.get()->value);
const uint8* filter_data_as_uint8 = &(filter_data->value);
int32* output_data_as_int32 = &(output_data->value);
diff --git a/tensorflow/core/kernels/quantized_matmul_op.cc b/tensorflow/core/kernels/quantized_matmul_op.cc
index 0ce9e37642..4abcae0d35 100644
--- a/tensorflow/core/kernels/quantized_matmul_op.cc
+++ b/tensorflow/core/kernels/quantized_matmul_op.cc
@@ -15,11 +15,14 @@ limitations under the License.
// Implements a quantized eight-bit version of the matmul operation.
+#define EIGEN_USE_THREADS
+
#include "public/gemmlowp.h"
-#include "tensorflow/core/kernels/quantization_utils.h"
-#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/meta_support.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -125,12 +128,20 @@ class QuantizedMatMulOp : public OpKernel {
const size_t ldb = b.dim_size(1);
const size_t ldc = n;
- // The gemmlowp optimized library only works for a particular set of data
- // types, so check if we meet those requirements and
- // fall back to a slower reference implementation if not.
- if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
- std::is_same<Toutput, qint32>() && (offset_c == 0) && (mult_c == 1) &&
- (shift_c == 0) && (transpose_c == false)) {
+ if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
+ std::is_same<T2, quint8>() && std::is_same<Toutput, qint32>() &&
+ (offset_c == 0) && (mult_c == 1) && (shift_c == 0) &&
+ (transpose_c == false)) {
+ // Gemmlowp/meta code path works on 32 & 64 bit Arm with NEON Simd and
+ // allows optimized quantized 8bit to 32bit gemm.
+ meta::QuantizedGemm(context, transpose_a_, transpose_b_, a_data, b_data,
+ c_data, m, n, k, offset_a, offset_b, lda, ldb, ldc);
+ } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
+ std::is_same<Toutput, qint32>() && (offset_c == 0) &&
+ (mult_c == 1) && (shift_c == 0) && (transpose_c == false)) {
+ // The gemmlowp optimized library only works for a particular set of data
+ // types, so check if we meet those requirements and fall back to a slower
+ // reference implementation if not.
if (transpose_a_) {
if (transpose_b_) {
GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 0eeea1fee7..b13e6c7d88 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -34,9 +34,9 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.http_archive(
name = "gemmlowp",
- url = "http://github.com/google/gemmlowp/archive/c0bacf11fb509a2cbe15a97362a2df067ffd57a2.tar.gz",
- sha256 = "dc64a38f9927db18748d9024987c9b102115e25bc2be4b76aa8e422b8f83d882",
- strip_prefix = "gemmlowp-c0bacf11fb509a2cbe15a97362a2df067ffd57a2",
+ url = "http://github.com/google/gemmlowp/archive/a6f29d8ac48d63293f845f2253eccbf86bc28321.tar.gz",
+ sha256 = "75d40ea8e68b0d1644f052fffe8f14a410b2a73d40ccb859a95c0578d194ec26",
+ strip_prefix = "gemmlowp-a6f29d8ac48d63293f845f2253eccbf86bc28321",
)
native.new_http_archive(