author     A. Unique TensorFlower <gardener@tensorflow.org>   2017-05-24 14:28:35 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-05-24 14:32:48 -0700
commit     625ce1ac462292fb4bc76a06343f170871cb428c (patch)
tree       bed59540605273477e07e79f6088db8224399743
parent     63aa126486b7edd48a3c9b52e8d2047c17004c9f (diff)
Implement quantized addition op, with NEON-acceleration for ARM devices
PiperOrigin-RevId: 157037658
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt               |    1
-rw-r--r--  tensorflow/core/BUILD                                     |    1
-rw-r--r--  tensorflow/core/kernels/BUILD                             |  116
-rw-r--r--  tensorflow/core/kernels/quantization_utils.h              |  282
-rw-r--r--  tensorflow/core/kernels/quantization_utils_test.cc        |  605
-rw-r--r--  tensorflow/core/kernels/quantized_add_op.cc               |  581
-rw-r--r--  tensorflow/core/kernels/quantized_add_op_test.cc          |  311
-rw-r--r--  tensorflow/core/ops/math_ops.cc                           |   29
-rw-r--r--  tensorflow/core/platform/test.cc                          |    8
-rw-r--r--  tensorflow/tools/graph_transforms/quantize_nodes.cc       |    7
-rw-r--r--  tensorflow/tools/graph_transforms/quantize_nodes_test.cc  |   43
11 files changed, 1754 insertions, 230 deletions
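
At its core, the new op requantizes both eight-bit inputs into a shared 32-bit output range (a per-tensor zero code plus a step multiplier, precomputed once) and then adds the results as plain integers. Below is a minimal, self-contained C++ sketch of that reference path. The float/quantized mapping is a simplified affine one rather than the exact QuantizedToFloat / FloatToQuantizedUnclamped helpers from quantization_utils.h, and the symmetric output range is an assumption made so that adding codes corresponds to adding floats; treat it as an illustration, not the kernel itself.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

// Simplified affine mapping between floats in [min, max] and quantized codes.
float DequantizeU8(uint8_t code, float min, float max) {
  return min + (max - min) * (static_cast<float>(code) / 255.0f);
}

// 32-bit code for a float in [min, max] (unclamped).
int64_t QuantizeS32(float value, float min, float max) {
  const double scale = 4294967296.0 / static_cast<double>(max - min);
  return static_cast<int64_t>((value - min) * scale) +
         std::numeric_limits<int32_t>::min();
}

// Float represented by a 32-bit code in [min, max].
float DequantizeS32(int64_t code, float min, float max) {
  const double scale = static_cast<double>(max - min) / 4294967296.0;
  return static_cast<float>(
      (code - static_cast<double>(std::numeric_limits<int32_t>::min())) * scale +
      min);
}

int main() {
  // Per-input quantization ranges.
  const float min_x = 0.0f, max_x = 6.0f;
  const float min_y = -1.0f, max_y = 1.0f;
  // A symmetric output range keeps the 32-bit mapping linear through zero,
  // so adding output codes corresponds to adding the floats they represent.
  const float out_max = 8.0f, out_min = -out_max;

  const std::vector<uint8_t> x = {0, 64, 128, 255};
  const std::vector<uint8_t> y = {10, 20, 200, 255};

  // Precompute the linear transform from eight-bit input codes to 32-bit
  // output codes, in the spirit of the input_0_int64 / input_mult_int32
  // constants computed in quantized_add_op.cc.
  auto make_transform = [&](float in_min, float in_max) {
    const int64_t code0 =
        QuantizeS32(DequantizeU8(0, in_min, in_max), out_min, out_max);
    const int64_t code1 =
        QuantizeS32(DequantizeU8(1, in_min, in_max), out_min, out_max);
    return std::make_pair(code0, code1 - code0);  // {zero code, step}
  };
  const std::pair<int64_t, int64_t> xt = make_transform(min_x, max_x);
  const std::pair<int64_t, int64_t> yt = make_transform(min_y, max_y);

  for (size_t i = 0; i < x.size(); ++i) {
    // Requantize each input into the output range, then add as integers.
    const int64_t xs = xt.first + x[i] * xt.second;
    const int64_t ys = yt.first + y[i] * yt.second;
    const int64_t sum = std::min<int64_t>(
        std::max<int64_t>(xs + ys, std::numeric_limits<int32_t>::min()),
        std::numeric_limits<int32_t>::max());
    std::cout << DequantizeU8(x[i], min_x, max_x) << " + "
              << DequantizeU8(y[i], min_y, max_y) << " ~= "
              << DequantizeS32(sum, out_min, out_max) << "\n";
  }
  return 0;
}
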
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index b9cd91e519..a2c9b4c5bd 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -189,6 +189,7 @@ tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
tensorflow/core/kernels/quantize_op.cc
tensorflow/core/kernels/quantized_activation_ops.cc
+tensorflow/core/kernels/quantized_add_op.cc
tensorflow/core/kernels/quantized_batch_norm_op.cc
tensorflow/core/kernels/quantized_bias_add_op.cc
tensorflow/core/kernels/quantized_concat_op.cc
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ed3eeb4b21..e5ab8570d8 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1072,6 +1072,7 @@ filegroup(
":framework/shape_inference_testutil.h",
":framework/tensor_testutil.cc",
":framework/tensor_testutil.h",
+ ":platform/test.cc",
":platform/test.h",
":util/reporter.cc",
":util/reporter.h",
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index aaa2543183..fe7e1f46f4 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4259,6 +4259,7 @@ filegroup(
"quantize_down_and_shrink_range.cc",
"quantize_op.cc",
"quantized_activation_ops.cc",
+ "quantized_add_op.cc",
"quantized_batch_norm_op.cc",
"quantized_bias_add_op.cc",
"quantized_concat_op.cc",
@@ -4365,6 +4366,7 @@ tf_kernel_library(
"quantize_down_and_shrink_range.cc",
"quantize_op.cc",
"quantized_activation_ops.cc",
+ "quantized_add_op.cc",
"quantized_batch_norm_op.cc",
"quantized_bias_add_op.cc",
"quantized_concat_op.cc",
@@ -4468,6 +4470,48 @@ tf_cc_test(
],
)
+# Android-only test for quantization utilities.
+cc_binary(
+ name = "quantization_utils_test_android_only",
+ testonly = 1,
+ srcs = ["quantization_utils_test.cc"],
+ copts = tf_copts(),
+ linkopts = select({
+ "//tensorflow:android": [
+ "-lm",
+ "-llog",
+ "-pie",
+ "-std=c++11",
+ ],
+ "//conditions:default": [],
+ }),
+ linkstatic = 1,
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ] + select({
+ "//tensorflow:android": [
+ ":android_tensorflow_kernels",
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ ":quantized_ops",
+ "//third_party/eigen3",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test_main",
+ ],
+ }),
+)
+
tf_cc_test(
name = "quantized_activation_ops_test",
srcs = ["quantized_activation_ops_test.cc"],
@@ -4486,6 +4530,72 @@ tf_cc_test(
],
)
+# Android-only test for quantized addition.
+cc_binary(
+ name = "quantized_add_op_test_android_only",
+ testonly = 1,
+ srcs = ["quantized_add_op_test.cc"],
+ copts = tf_copts(),
+ linkopts = select({
+ "//tensorflow:android": [
+ "-lm",
+ "-llog",
+ "-pie",
+ "-std=c++11",
+ ],
+ "//conditions:default": [],
+ }),
+ linkstatic = 1,
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ ] + select({
+ "//tensorflow:android": [
+ ":android_tensorflow_kernels",
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ ":ops_util",
+ ":quantized_ops",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+ }),
+)
+
+tf_cc_test(
+ name = "quantized_add_op_test",
+ size = "small",
+ srcs = ["quantized_add_op_test.cc"],
+ deps = [
+ ":math",
+ ":ops_testutil",
+ ":ops_util",
+ ":quantized_ops",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/core:array_ops_op_lib",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:direct_session",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:math_ops_op_lib",
+ "//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_cc_test(
name = "quantized_bias_add_op_test",
size = "small",
@@ -4564,7 +4674,7 @@ tf_cc_test(
],
)
-# Android-only test for quantized instance norm.
+# Android-only test for quantized multiply.
cc_binary(
name = "quantized_mul_op_test_android_only",
testonly = 1,
@@ -4590,9 +4700,13 @@ cc_binary(
"//tensorflow/core:android_tensorflow_test_lib",
],
"//conditions:default": [
+ ":ops_util",
+ ":quantized_ops",
"//tensorflow/core:framework",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test_main",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
],
}),
)
diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h
index be67dfd112..e258efd545 100644
--- a/tensorflow/core/kernels/quantization_utils.h
+++ b/tensorflow/core/kernels/quantization_utils.h
@@ -24,6 +24,13 @@ limitations under the License.
// optimized. They should be implementable using fixed point representations
// to avoid a dependency on floating-point hardware.
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define QUANTIZATION_UTILS_USE_NEON
+#include <arm_neon.h>
+#endif
+
+#include <array>
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
@@ -203,7 +210,7 @@ inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input,
}
template <class T1, class T2>
-inline void RequantizeManyInNewRange(const T1* input, size_t count,
+inline void RequantizeManyInNewRange(const T1* input, int64 count,
float min_input, float max_input,
float min_output, float max_output,
T2* output) {
@@ -217,10 +224,11 @@ inline void RequantizeManyInNewRange(const T1* input, size_t count,
// Because converting 32-bit accumulated results down to eight bit is a common
// case, we have a specialized code path to handle it as efficiently as
// possible using only fixed-point math for the inner loop.
-template <>
-inline void RequantizeManyInNewRange<qint32, quint8>(
- const qint32* input, size_t count, float min_input, float max_input,
- float min_output, float max_output, quint8* output) {
+inline void RequantizeManyInNewRangeReference(const qint32* input, int64 count,
+ float min_input, float max_input,
+ float min_output,
+ float max_output,
+ quint8* output) {
// Initially we calculate all the constants we need once, before we go into
// the inner loop. If this is updated, also update the Eigen version.
const int fp_shift = 16;
@@ -236,9 +244,10 @@ inline void RequantizeManyInNewRange<qint32, quint8>(
const int64 input_offset_fp =
static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
const int64 output_offset_fp =
- output_range == 0.0 ? 0 : static_cast<int64>((1 << fp_shift) *
- (min_output * 255.0) /
- output_range);
+ output_range == 0.0
+ ? 0
+ : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
+ output_range);
const int64 rounding_delta = 1 << (fp_shift - 1);
// Inside this loop we just do minimal adds, multiplies, and shifts, in a way
@@ -258,6 +267,256 @@ inline void RequantizeManyInNewRange<qint32, quint8>(
}
}
+// Another common case is converting eight bit inputs up to thirty two bits, so
+// we have specialized fixed-point code to accelerate that. There is also a NEON
+// version for ARM devices below.
+inline void RequantizeManyInNewRange8To32BitReference(
+ const quint8* input, int64 count, float min_input, float max_input,
+ float min_output, float max_output, qint32* output) {
+ const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
+ const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
+ const int64 code_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
+ const int64 code_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
+ const int32 mult_int32 = code_1_int64 - code_0_int64;
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+ for (int64 i = 0; i < count; ++i) {
+ const int64 input_value = static_cast<int64>(input[i]);
+ int64 output_value = code_0_int64 + (input_value * mult_int32);
+ output_value = std::max(output_value, lowest_quantized);
+ output_value = std::min(output_value, highest_quantized);
+ output[i] = static_cast<int32>(output_value);
+ }
+}
+
+#ifdef QUANTIZATION_UTILS_USE_NEON
+// Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD
+// intrinsics for ARM platforms.
+inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count,
+ float min_input, float max_input,
+ float min_output, float max_output,
+ quint8* output) {
+ // Initially we calculate all the constants we need once, before we go into
+ // the inner loop. If this is updated, also update the Eigen version.
+ const int fp_shift = 16;
+
+ // Calculate range variables in advance.
+ // Input range.
+ const float input_range = max_input - min_input;
+ // Output range.
+ const float output_range = max_output - min_output;
+ // Ratio of output range.
+ const float recip_output_range =
+ output_range == 0.0 ? 0.0 : (255.0 / output_range);
+ // Average of input range as zero position of input.
+ const float input_rezero = (min_input + max_input) / 2.0;
+ // In-out range scale.
+ const int32 range_scale_fp =
+ output_range == 0.0 ? 0.0
+ : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) *
+ input_range / output_range);
+ // Input zero position offset to output.
+ const int32 input_offset_fp =
+ static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift));
+ // Output min offset.
+ const int32 output_offset_fp =
+ output_range == 0.0
+ ? 0
+ : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) /
+ output_range);
+ const int32 rounding_delta = 1 << (fp_shift - 1);
+
+ // broadcast range to each lane
+ const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp);
+ const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp);
+ const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp);
+ const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta);
+
+ int64 index = 0;
+ // Use SIMD to requantize.
+ for (; index < (count - 7); index += 8) {
+ const int32* input_ptr = &(input->value) + index;
+ const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr);
+ const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4);
+ const int32x4_t fp_value_low_32x4 = vaddq_s32(
+ input_offset_fp_32x4,
+ vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4));
+ const int32x4_t fp_value_high_32x4 = vaddq_s32(
+ input_offset_fp_32x4,
+ vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4));
+ const int32x4_t offset_intermediate_low_32x4 =
+ vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4);
+ const int32x4_t offset_intermediate_high_32x4 =
+ vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4);
+ const int32x4_t round_intermediate_low_32x4 =
+ vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4);
+ const int32x4_t round_intermediate_high_32x4 =
+ vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4);
+ const int16x4_t quantized_low_16x4 =
+ vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift));
+ const int16x4_t quantized_high_16x4 =
+ vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift));
+ const uint8x8_t quantized_8x8 =
+ vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4));
+ uint8* output_ptr = &(output->value) + index;
+ vst1_u8(output_ptr, quantized_8x8);
+ }
+
+ // Requantize remaining elements in array without SIMD.
+ for (; index < count; ++index) {
+ const int32 input_value = static_cast<int32>(input[index]);
+ const int32 fp_value =
+ static_cast<int32>(
+ (static_cast<int32>(input_value >> 16) * (range_scale_fp))) +
+ input_offset_fp;
+ const int32 offset_intermediate = fp_value - output_offset_fp;
+ const int32 round_intermediate = offset_intermediate + rounding_delta;
+ int32 quantized_int32 = round_intermediate >> fp_shift;
+ quantized_int32 = std::max(quantized_int32, 0);
+ quantized_int32 = std::min(quantized_int32, 255);
+ output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32));
+ }
+}
+
+template <>
+inline void RequantizeManyInNewRange<qint32, quint8>(
+ const qint32* input, int64 count, float min_input, float max_input,
+ float min_output, float max_output, quint8* output) {
+ const float input_range = max_input - min_input;
+ const float output_range = max_output - min_output;
+ if ((input_range / output_range) > 16384.0f) {
+ // Our NEON implementation uses 32-bit math and can't handle very
+ // large ranges, so fall back to the reference implementation. We don't
+ // expect these to be common in models, so this shouldn't be a performance
+ // problem in practice.
+ RequantizeManyInNewRangeReference(input, count, min_input, max_input,
+ min_output, max_output, output);
+ } else {
+ RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output,
+ max_output, output);
+ }
+}
+
+// Requantizes eight quint8 values to eight qint32 values in parallel using NEON.
+// Returns a std::array instead of a pointer to leverage return value optimization.
+inline std::array<int32x4_t, 2> Requantize8x8To32Neon(
+ const uint8* input_ptr, const int64x2_t input_0_64x2,
+ const int32x2_t input_mult_32x2) {
+ const uint8x8_t input_value_8x8 = vld1_u8(input_ptr);
+ const int16x8_t input_value_16x8 =
+ vreinterpretq_s16_u16(vmovl_u8(input_value_8x8));
+ const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8);
+ const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8);
+ const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4);
+ const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4);
+ const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4);
+ const int32x2_t input_value_low_high_32x2 =
+ vget_high_s32(input_value_low_32x4);
+ const int32x2_t input_value_high_low_32x2 =
+ vget_low_s32(input_value_high_32x4);
+ const int32x2_t input_value_high_high_32x2 =
+ vget_high_s32(input_value_high_32x4);
+ const int64x2_t mult_result_low_low_64x2 =
+ vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2);
+ const int64x2_t mult_result_low_high_64x2 =
+ vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2);
+ const int64x2_t mult_result_high_low_64x2 =
+ vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2);
+ const int64x2_t mult_result_high_high_64x2 =
+ vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2);
+ const int32x2_t output_value_low_low_32x2 =
+ vqmovn_s64(mult_result_low_low_64x2);
+ const int32x2_t output_value_low_high_32x2 =
+ vqmovn_s64(mult_result_low_high_64x2);
+ const int32x2_t output_value_high_low_32x2 =
+ vqmovn_s64(mult_result_high_low_64x2);
+ const int32x2_t output_value_high_high_32x2 =
+ vqmovn_s64(mult_result_high_high_64x2);
+ const int32x4_t output_value_low_32x4 =
+ vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2);
+ const int32x4_t output_value_high_32x4 =
+ vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2);
+ return std::array<int32x4_t, 2>{
+ {output_value_low_32x4, output_value_high_32x4}};
+}
+
+// Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD
+// intrinsics for ARM platforms.
+template <>
+inline void RequantizeManyInNewRange<quint8, qint32>(
+ const quint8* input, int64 count, float min_input, float max_input,
+ float min_output, float max_output, qint32* output) {
+ // Pre-calculate zero position and multiplier.
+ // Calculate 0 and 1 value in float.
+ const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
+ const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
+
+ // Cast 0 and 1 value in int64.
+ const int64 code_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
+ const int64 code_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
+
+ // Calculate multiplier.
+ const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64);
+
+ // Broadcast 0 position and multiplier to lanes
+ const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64);
+ const int32x2_t mult_32x2 = vmov_n_s32(mult_int32);
+
+ int64 i = 0;
+
+ // Use SIMD to requantize array.
+ for (; i < (count - 7); i += 8) {
+ const uint8* input_ptr = &(input->value) + i;
+ int32* output_ptr = &(output->value) + i;
+ const std::array<int32x4_t, 2> output_value =
+ Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2);
+ vst1q_s32(output_ptr + 0, output_value[0]);
+ vst1q_s32(output_ptr + 4, output_value[1]);
+ }
+
+ // Requantize remaining elements in array without SIMD.
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+
+ for (; i < count; ++i) {
+ const int64 input_value = static_cast<int64>(input[i]);
+ int64 output_value = code_0_int64 + (input_value * mult_int32);
+ output_value = std::max(output_value, lowest_quantized);
+ output_value = std::min(output_value, highest_quantized);
+ output[i] = static_cast<int32>(output_value);
+ }
+}
+
+#else
+
+// If SIMD implementations aren't available, then use these default reference
+// versions.
+template <>
+inline void RequantizeManyInNewRange<qint32, quint8>(
+ const qint32* input, int64 count, float min_input, float max_input,
+ float min_output, float max_output, quint8* output) {
+ RequantizeManyInNewRangeReference(input, count, min_input, max_input,
+ min_output, max_output, output);
+}
+
+template <>
+inline void RequantizeManyInNewRange<quint8, qint32>(
+ const quint8* input, int64 count, float min_input, float max_input,
+ float min_output, float max_output, qint32* output) {
+ RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input,
+ min_output, max_output, output);
+}
+
+#endif
+
template <int shift>
struct int64_right_shift_op {
EIGEN_EMPTY_STRUCT_CTOR(int64_right_shift_op)
@@ -305,9 +564,10 @@ inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
const int64 input_offset_fp =
static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
const int64 output_offset_fp =
- output_range == 0.0 ? 0 : static_cast<int64>((1 << fp_shift) *
- (min_output * 255.0) /
- output_range);
+ output_range == 0.0
+ ? 0
+ : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
+ output_range);
const int64 rounding_delta = 1 << (fp_shift - 1);
// Inside this eigen expression we just do minimal adds, multiplies, and
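
The scalar arithmetic behind this 32-to-8-bit path (the constant setup plus the non-SIMD tail of RequantizeManyInNewRangeNeon above) can be pulled out into a standalone sketch. It mirrors the formulas in the hunk but operates on plain int32_t/uint8_t buffers rather than the qint32/quint8 tensor types:

#include <algorithm>
#include <cstdint>

// Requantizes 32-bit codes in [min_input, max_input] down to 8-bit codes in
// [min_output, max_output] using only fixed-point math in the inner loop.
void Requantize32To8Scalar(const int32_t* input, int64_t count,
                           float min_input, float max_input,
                           float min_output, float max_output,
                           uint8_t* output) {
  const int fp_shift = 16;
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  const float recip_output_range =
      output_range == 0.0f ? 0.0f : (255.0f / output_range);
  const float input_rezero = (min_input + max_input) / 2.0f;
  // Scale factor; the >> 16 applied to each input below supplies the rest of
  // the 32-bit scaling, matching the NEON version's 32-bit math.
  const int32_t range_scale_fp =
      output_range == 0.0f
          ? 0
          : static_cast<int32_t>(255.0 * (1 << (fp_shift - 16)) * input_range /
                                 output_range);
  const int32_t input_offset_fp = static_cast<int32_t>(
      input_rezero * recip_output_range * (1 << fp_shift));
  const int32_t output_offset_fp =
      output_range == 0.0f
          ? 0
          : static_cast<int32_t>((1 << fp_shift) * (min_output * 255.0) /
                                 output_range);
  const int32_t rounding_delta = 1 << (fp_shift - 1);

  for (int64_t index = 0; index < count; ++index) {
    // Scale, re-center, round, shift back down, and clamp to [0, 255].
    const int32_t fp_value =
        (input[index] >> 16) * range_scale_fp + input_offset_fp;
    const int32_t rounded = fp_value - output_offset_fp + rounding_delta;
    int32_t quantized = rounded >> fp_shift;
    quantized = std::max<int32_t>(quantized, 0);
    quantized = std::min<int32_t>(quantized, 255);
    output[index] = static_cast<uint8_t>(quantized);
  }
}
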
diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc
index 0c23c0586c..c547b166ee 100644
--- a/tensorflow/core/kernels/quantization_utils_test.cc
+++ b/tensorflow/core/kernels/quantization_utils_test.cc
@@ -18,66 +18,99 @@ limitations under the License.
#include <limits>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace {
+
+void TestRequantizeMany(Eigen::ThreadPoolDevice* eigen_device, float input_min,
+ float input_max, float output_min, float output_max,
+ const std::vector<qint32>& values_quantized,
+ int tolerance = 1) {
+ const int values_count = values_quantized.size();
+ std::vector<quint8> expected_values;
+ for (int value_index = 0; value_index < values_count; ++value_index) {
+ expected_values.push_back(FloatToQuantized<quint8>(
+ QuantizedToFloat(values_quantized[value_index], input_min, input_max),
+ output_min, output_max));
+ }
-class QuantizationUtilsTest : public ::testing::Test {
- protected:
- void TestRequantizeMany(Eigen::ThreadPoolDevice* eigen_device,
- float input_min, float input_max, float output_min,
- float output_max,
- const std::vector<qint32>& values_quantized,
- int tolerance = 1) {
- const int values_count = values_quantized.size();
- std::vector<quint8> expected_values;
- for (int value_index = 0; value_index < values_count; ++value_index) {
- expected_values.push_back(FloatToQuantized<quint8>(
- QuantizedToFloat(values_quantized[value_index], input_min, input_max),
- output_min, output_max));
- }
+ Tensor i_tensor =
+ tensorflow::test::AsTensor(gtl::ArraySlice<qint32>(values_quantized));
+ Tensor o_tensor(DT_QUINT8, TensorShape{values_count});
+ auto output_values = o_tensor.flat<quint8>();
+
+ if (eigen_device == nullptr) {
+ auto input_array = i_tensor.flat<qint32>();
+ RequantizeManyInNewRange(input_array.data(), input_array.size(), input_min,
+ input_max, output_min, output_max,
+ output_values.data());
+ } else {
+ RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
+ *eigen_device, i_tensor, input_min, input_max, output_min, output_max,
+ &o_tensor);
+ }
- Tensor i_tensor =
- tensorflow::test::AsTensor(gtl::ArraySlice<qint32>(values_quantized));
- Tensor o_tensor(DT_QUINT8, TensorShape{values_count});
- auto output_values = o_tensor.flat<quint8>();
+ const string tolerance_str = strings::StrCat("+-", tolerance);
+ for (size_t value_index = 0; value_index < values_count; ++value_index) {
+ int e = expected_values[value_index];
+ int v = output_values(value_index);
+ ASSERT_TRUE(std::abs(e - v) <= tolerance)
+ << "actual=" << v << ", expected=" << e << tolerance_str
+ << ", values_quantized[" << value_index
+ << "]=" << values_quantized[value_index] << ", input_min=" << input_min
+ << ", input_max=" << input_max << ", output_min=" << output_min
+ << ", output_max=" << output_max << ", value_index=" << value_index;
+ }
+}
- if (eigen_device == nullptr) {
- auto input_array = i_tensor.flat<qint32>();
- RequantizeManyInNewRange(input_array.data(), input_array.size(),
- input_min, input_max, output_min, output_max,
- output_values.data());
- } else {
- RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
- *eigen_device, i_tensor, input_min, input_max, output_min, output_max,
- &o_tensor);
- }
+void TestRequantizeMany8To32Bit(float input_min, float input_max,
+ float output_min, float output_max,
+ const std::vector<quint8>& values_quantized,
+ int tolerance = 256) {
+ const int values_count = values_quantized.size();
+ std::vector<qint32> expected_values;
+ for (int value_index = 0; value_index < values_count; ++value_index) {
+ expected_values.push_back(FloatToQuantized<qint32>(
+ QuantizedToFloat(values_quantized[value_index], input_min, input_max),
+ output_min, output_max));
+ }
- const string tolerance_str = strings::StrCat("+-", tolerance);
- for (size_t value_index = 0; value_index < values_count; ++value_index) {
- int e = expected_values[value_index];
- int v = output_values(value_index);
- ASSERT_TRUE(std::abs(e - v) <= tolerance)
- << "actual=" << v << ", expected=" << e << tolerance_str
- << ", values_quantized[" << value_index
- << "]=" << values_quantized[value_index]
- << ", input_min=" << input_min << ", input_max=" << input_max
- << ", output_min=" << output_min << ", output_max=" << output_max
- << ", value_index=" << value_index;
- }
+ Tensor i_tensor =
+ tensorflow::test::AsTensor(gtl::ArraySlice<quint8>(values_quantized));
+ Tensor o_tensor(DT_QINT32, TensorShape{values_count});
+ auto output_values = o_tensor.flat<qint32>();
+
+ auto input_array = i_tensor.flat<quint8>();
+ RequantizeManyInNewRange(input_array.data(), input_array.size(), input_min,
+ input_max, output_min, output_max,
+ output_values.data());
+
+ const string tolerance_str = strings::StrCat("+-", tolerance);
+ for (size_t value_index = 0; value_index < values_count; ++value_index) {
+ int e = expected_values[value_index];
+ int v = output_values(value_index);
+ ASSERT_TRUE(std::abs(e - v) <= tolerance)
+ << "actual=" << v << ", expected=" << e << tolerance_str
+ << ", values_quantized[" << value_index
+ << "]=" << values_quantized[value_index] << ", input_min=" << input_min
+ << ", input_max=" << input_max << ", output_min=" << output_min
+ << ", output_max=" << output_max << ", value_index=" << value_index;
}
+}
- // If eigen_device is NULL, then the reference implementation is tested.
- void TestRequantizeManyInNewRange32To8Bit(
- Eigen::ThreadPoolDevice* eigen_device) {
+// If eigen_device is NULL, then the reference implementation is tested.
+void TestRequantizeManyInNewRange32To8Bit(
+ Eigen::ThreadPoolDevice* eigen_device) {
+ if (true) {
// These are the float values we're going to test the conversions on.
const size_t values_count = 6;
const float values[values_count] = {0.0f, 0.45f, 1.0f,
@@ -108,7 +141,7 @@ class QuantizationUtilsTest : public ::testing::Test {
qint32 high = Eigen::NumTraits<qint32>::highest();
std::vector<qint32> vals{low, high};
int num_steps = 14419;
- qint32 step = static_cast<int32>((1L << 32) / num_steps);
+ qint32 step = static_cast<int32>((1LL << 32) / num_steps);
qint32 v = low + static_cast<qint32>(1);
for (int i = 0; i < num_steps; ++i) {
vals.push_back(v);
@@ -120,181 +153,278 @@ class QuantizationUtilsTest : public ::testing::Test {
vals);
TestRequantizeMany(eigen_device, -1.0f, 12345678.0f, -12345678.0f,
12345678.0f, vals);
+ }
+ // Test when the input range is large and output range is small.
+ // Use all quantized values where the float is in the output range.
+ const float out_min = -29.1234;
+ const float out_max = 23.1234;
+ const float in_min = -1e6;
+ const float in_max = 1e6;
+
+ qint32 low = FloatToQuantized<qint32>(out_min, in_min, in_max);
+ qint32 high = FloatToQuantized<qint32>(out_max, in_min, in_max);
+ std::vector<qint32> vals;
+ vals.clear();
+ for (int32 i = low; i <= high; ++i) vals.push_back(i);
+ TestRequantizeMany(eigen_device, in_min, in_max, out_min, out_max, vals);
+}
- // Test when the input range is large and output range is small.
- // Use all quantized values where the float is in the output range.
- const float out_min = -29.1234;
- const float out_max = 23.1234;
- const float in_min = -1e6;
- const float in_max = 1e6;
-
- low = FloatToQuantized<qint32>(out_min, in_min, in_max);
- high = FloatToQuantized<qint32>(out_max, in_min, in_max);
- vals.clear();
- for (int32 i = low; i <= high; ++i) vals.push_back(i);
- TestRequantizeMany(eigen_device, in_min, in_max, out_min, out_max, vals);
+void TestRequantizeManyInNewRange8To32Bit() {
+ // These are the float values we're going to test the conversions on.
+ const size_t values_count = 6;
+ const float values[values_count] = {0.0f, 0.45f, 1.0f, -1.0f, 127.0f, 255.0f};
+ // These are the input and output ranges we'll test.
+ const size_t ranges_count = 6;
+ const float ranges[ranges_count][4] = {
+ {0.0f, 255.0f, 0.0f, 255.0f}, //
+ {0.0f, 1.0f, 0.0f, 1.0f}, //
+ {-1.0f, 1.0f, -1.0f, 1.0f}, //
+ {-1.0f, 1.0f, -255.0f, 255.0f}, //
+ {3.0f, 3.0f, 0.0f, 255.0f}, // input min == max
+ {0.0f, 255.0f, 5.0f, 5.0f}, // output min == max
+ };
+ for (int i = 0; i < ranges_count; ++i) {
+ const auto& r = ranges[i];
+ std::vector<quint8> values_quantized;
+ for (int value_index = 0; value_index < values_count; ++value_index) {
+ const float v = values[value_index];
+ values_quantized.push_back(FloatToQuantized<quint8>(v, r[0], r[1]));
+ }
+ TestRequantizeMany8To32Bit(r[0], r[1], r[2], r[3], values_quantized);
}
- template <typename InputType, typename OutputType>
- void TestRequantizeManyInNewRangeEigenVsNonEigen() {
- thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
- EigenThreadPoolWrapper wrapper(&threadpool);
- Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
+ // Test with many different values in the input quantized range.
+ int low = Eigen::NumTraits<quint8>::lowest();
+ int high = Eigen::NumTraits<quint8>::highest();
+ std::vector<quint8> vals;
+ for (int val = low; val <= high; ++val) {
+ vals.push_back(val);
+ }
+ TestRequantizeMany8To32Bit(-1.0f, 1.0f, -1.0f, 1.0f, vals);
+ TestRequantizeMany8To32Bit(-255.0f, 255.0f, -255.0f, 255.0f, vals);
+ TestRequantizeMany8To32Bit(-1.0f, 1.0f, -12345678.0f, 12345678.0f, vals);
+ TestRequantizeMany8To32Bit(-1.0f, 12345678.0f, -12345678.0f, 12345678.0f,
+ vals);
+}
- const size_t ranges_count = 6;
- const float ranges[ranges_count][4] = {
- {0.0f, 255.0f, 0.0f, 255.0f}, //
- {0.0f, 1.0f, 0.0f, 1.0f}, //
- {-1.0f, 1.0f, -1.0f, 1.0f}, //
- {-1.0f, 1.0f, -255.0f, 255.0f}, //
- {3.0f, 3.0f, 0.0f, 255.0f}, // input min == max
- {0.0f, 255.0f, 5.0f, 5.0f}, // output min == max
- };
+template <typename InputType, typename OutputType>
+void TestRequantizeManyInNewRangeEigenVsNonEigen() {
+ thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
- // Random values.
- for (size_t range_index = 0; range_index < ranges_count; ++range_index) {
- const float input_min = ranges[range_index][0];
- const float input_max = ranges[range_index][1];
- const float output_min = ranges[range_index][2];
- const float output_max = ranges[range_index][3];
- const int values_count = 10000;
- random::PhiloxRandom philox(testing::RandomSeed(), 17);
- random::SimplePhilox rnd(&philox);
- std::vector<InputType> values_quantized;
- for (int i = 0; i < values_count; ++i) {
- float v = (rnd.RandFloat() * (input_max - input_min)) + input_min;
- values_quantized.push_back(
- FloatToQuantized<InputType>(v, input_min, input_max));
- }
+ const size_t ranges_count = 6;
+ const float ranges[ranges_count][4] = {
+ {0.0f, 255.0f, 0.0f, 255.0f}, //
+ {0.0f, 1.0f, 0.0f, 1.0f}, //
+ {-1.0f, 1.0f, -1.0f, 1.0f}, //
+ {-1.0f, 1.0f, -255.0f, 255.0f}, //
+ {3.0f, 3.0f, 0.0f, 255.0f}, // input min == max
+ {0.0f, 255.0f, 5.0f, 5.0f}, // output min == max
+ };
+
+ // Random values.
+ for (size_t range_index = 0; range_index < ranges_count; ++range_index) {
+ const float input_min = ranges[range_index][0];
+ const float input_max = ranges[range_index][1];
+ const float output_min = ranges[range_index][2];
+ const float output_max = ranges[range_index][3];
+ const int values_count = 10000;
+ random::PhiloxRandom philox(testing::RandomSeed(), 17);
+ random::SimplePhilox rnd(&philox);
+ std::vector<InputType> values_quantized;
+ for (int i = 0; i < values_count; ++i) {
+ float v = (rnd.RandFloat() * (input_max - input_min)) + input_min;
+ values_quantized.push_back(
+ FloatToQuantized<InputType>(v, input_min, input_max));
+ }
- Tensor i_tensor = tensorflow::test::AsTensor(
- gtl::ArraySlice<InputType>(values_quantized));
- const auto i_array = i_tensor.flat<InputType>();
- Tensor o_tensor_eigen(DataTypeToEnum<OutputType>::v(),
- TensorShape{values_count});
- auto output_values_eigen = o_tensor_eigen.flat<OutputType>();
- Tensor o_tensor_ref(DataTypeToEnum<OutputType>::v(),
+ Tensor i_tensor = tensorflow::test::AsTensor(
+ gtl::ArraySlice<InputType>(values_quantized));
+ const auto i_array = i_tensor.flat<InputType>();
+ Tensor o_tensor_eigen(DataTypeToEnum<OutputType>::v(),
TensorShape{values_count});
- auto output_values_ref = o_tensor_ref.flat<OutputType>();
+ auto output_values_eigen = o_tensor_eigen.flat<OutputType>();
+ Tensor o_tensor_ref(DataTypeToEnum<OutputType>::v(),
+ TensorShape{values_count});
+ auto output_values_ref = o_tensor_ref.flat<OutputType>();
+
+ RequantizeManyInNewRange(i_array.data(), i_array.size(), input_min,
+ input_max, output_min, output_max,
+ output_values_ref.data());
+ RequantizeManyInNewRangeUsingEigen<InputType, OutputType>(
+ eigen_device, i_tensor, input_min, input_max, output_min, output_max,
+ &o_tensor_eigen);
+
+ const int tolerance = 1;
+ for (int i = 0; i < values_quantized.size(); ++i) {
+ auto expected = output_values_ref(i);
+ auto actual = output_values_eigen(i);
+ // The eigen computation uses float for constants and computation
+ // instead of doubles, so can be different by 1 or 2 in some cases
+ // (e.g., input value 144.062744140625, min -1, max 255, type quint8).
+ ASSERT_TRUE(std::abs(expected - actual) <= tolerance)
+ << "expected=" << expected << " actual=" << actual
+ << " tolerance=" << tolerance << " v=" << values_quantized[i]
+ << " i=" << i << " input_min=" << input_min
+ << " input_max=" << input_max
+ << " input_type=" << DataTypeString(DataTypeToEnum<InputType>::v())
+ << " output_type=" << DataTypeString(DataTypeToEnum<OutputType>::v());
+ }
+ }
+}
+
+template <typename InputType, typename OutputType>
+void TimeRequantizeManyInNewRange(int64 num_elements, int64 iterations,
+ bool use_eigen) {
+ const float input_min = -100.0f;
+ const float input_max = 100.0f;
+ const float output_min = -1000000.0f;
+ const float output_max = 1000000.0f;
+
+ random::PhiloxRandom philox(testing::RandomSeed(), 17);
+ random::SimplePhilox rnd(&philox);
+ std::vector<InputType> values_quantized;
+ for (int i = 0; i < num_elements; ++i) {
+ float v = (rnd.RandFloat() * (input_max - input_min)) + input_min;
+ values_quantized.push_back(
+ FloatToQuantized<InputType>(v, input_min, input_max));
+ }
- RequantizeManyInNewRange(i_array.data(), i_array.size(), input_min,
- input_max, output_min, output_max,
- output_values_ref.data());
+ thread::ThreadPool threadpool(Env::Default(), "test", 4 /* num_threads */);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_device(&wrapper, 4 /* num_threads */);
+
+ Tensor i_tensor =
+ tensorflow::test::AsTensor(gtl::ArraySlice<InputType>(values_quantized));
+ const auto i_array = i_tensor.flat<InputType>();
+ Tensor o_tensor_eigen(DataTypeToEnum<OutputType>::v(),
+ TensorShape{num_elements});
+ Tensor o_tensor_ref(DataTypeToEnum<OutputType>::v(),
+ TensorShape{num_elements});
+ auto output_values_ref = o_tensor_ref.flat<OutputType>();
+
+ int64 total_duration = 0;
+ for (int i = 0; i < iterations; ++i) {
+ const int64 start_time = Env::Default()->NowMicros();
+ if (use_eigen) {
RequantizeManyInNewRangeUsingEigen<InputType, OutputType>(
eigen_device, i_tensor, input_min, input_max, output_min, output_max,
&o_tensor_eigen);
-
- const int tolerance = 1;
- for (int i = 0; i < values_quantized.size(); ++i) {
- auto expected = output_values_ref(i);
- auto actual = output_values_eigen(i);
- // The eigen computation uses float for constants and computation
- // instead of doubles, so can be different by 1 or 2 in some cases
- // (e.g., input value 144.062744140625, min -1, max 255, type quint8).
- ASSERT_TRUE(std::abs(expected - actual) <= tolerance)
- << "expected=" << expected << " actual=" << actual
- << " tolerance=" << tolerance << " v=" << values_quantized[i]
- << " i=" << i << " input_min=" << input_min
- << " input_max=" << input_max
- << " input_type=" << DataTypeString(DataTypeToEnum<InputType>::v())
- << " output_type="
- << DataTypeString(DataTypeToEnum<OutputType>::v());
- }
+ } else {
+ RequantizeManyInNewRange<InputType, OutputType>(
+ i_array.data(), i_array.size(), input_min, input_max, output_min,
+ output_max, output_values_ref.data());
}
+ const int64 end_time = Env::Default()->NowMicros();
+ total_duration += end_time - start_time;
}
+ const int64 one_run_duration = total_duration / iterations;
- template <typename T>
- void TestFloatToQuantizedInPlaceUsingEigen(
- Eigen::ThreadPoolDevice* eigen_device) {
- // These are the float values we're going to test the conversions on.
- typedef std::pair<float, float> FPair;
- for (FPair min_and_max : std::vector<FPair>{FPair(-255.0f, 255.0f), //
- FPair(-1.0f, 1.0f), //
- FPair(-1.0f, 255.0f), //
- FPair(0.0f, 1e6), //
- FPair(0.0f, 1.0f), //
- FPair(-31.0f, 13.0f)}) {
- const float f_min = min_and_max.first;
- const float f_max = min_and_max.second;
- const float f_range = f_max - f_min;
- const int values_count = 50000;
- Tensor input(DT_FLOAT, TensorShape{values_count});
- auto input_array = input.flat<float>();
- for (int i = 0; i < values_count; ++i) {
- input_array(i) = f_min + f_range * i / (values_count - 1);
- }
+ const int64 num_ops = num_elements;
- Tensor output(DataTypeToEnum<T>::v(), TensorShape{values_count});
- FloatTensorToQuantizedInPlaceUsingEigen<T>(*eigen_device, input, f_min,
- f_max, &output);
- auto output_array = output.flat<T>();
-
- const int tolerance = 1;
- for (int i = 0; i < values_count; ++i) {
- int32 expected = FloatToQuantized<T>(input_array(i), f_min, f_max);
- int32 actual = output_array(i);
-
- // The eigen computation uses float for constants and computation
- // instead
- // of doubles, so can be different by 1 or 2 in some cases (e.g., input
- // value 144.062744140625, min -1, max 255, type quint8).
- ASSERT_TRUE(std::abs(expected - actual) <= tolerance)
- << "expected=" << expected << " actual=" << actual
- << " tolerance=" << tolerance << " v=" << input_array(i)
- << " i=" << i << " f_min=" << f_min << " f_max=" << f_max
- << " type=" << DataTypeString(DataTypeToEnum<T>::v());
- }
+ const double million_ops_per_second =
+ (iterations * num_ops) / static_cast<double>(total_duration);
+
+ LOG(INFO) << "TimeRequantizeManyInNewRange: " << num_elements
+ << (use_eigen ? " eigen" : " ref") << ": iterations=" << iterations
+ << ", MOps/s=" << million_ops_per_second
+ << ", one_run_duration=" << one_run_duration
+ << ", total_duration=" << total_duration;
+}
+
+template <typename T>
+void TestFloatToQuantizedInPlaceUsingEigen(
+ Eigen::ThreadPoolDevice* eigen_device) {
+ // These are the float values we're going to test the conversions on.
+ typedef std::pair<float, float> FPair;
+ for (FPair min_and_max : std::vector<FPair>{FPair(-255.0f, 255.0f), //
+ FPair(-1.0f, 1.0f), //
+ FPair(-1.0f, 255.0f), //
+ FPair(0.0f, 1e6), //
+ FPair(0.0f, 1.0f), //
+ FPair(-31.0f, 13.0f)}) {
+ const float f_min = min_and_max.first;
+ const float f_max = min_and_max.second;
+ const float f_range = f_max - f_min;
+ const int values_count = 50000;
+ Tensor input(DT_FLOAT, TensorShape{values_count});
+ auto input_array = input.flat<float>();
+ for (int i = 0; i < values_count; ++i) {
+ input_array(i) = f_min + f_range * i / (values_count - 1);
+ }
+
+ Tensor output(DataTypeToEnum<T>::v(), TensorShape{values_count});
+ FloatTensorToQuantizedInPlaceUsingEigen<T>(*eigen_device, input, f_min,
+ f_max, &output);
+ auto output_array = output.flat<T>();
+
+ const int tolerance = 1;
+ for (int i = 0; i < values_count; ++i) {
+ int32 expected = FloatToQuantized<T>(input_array(i), f_min, f_max);
+ int32 actual = output_array(i);
+
+ // The eigen computation uses float for constants and computation
+ // instead
+ // of doubles, so can be different by 1 or 2 in some cases (e.g., input
+ // value 144.062744140625, min -1, max 255, type quint8).
+ ASSERT_TRUE(std::abs(expected - actual) <= tolerance)
+ << "expected=" << expected << " actual=" << actual
+ << " tolerance=" << tolerance << " v=" << input_array(i) << " i=" << i
+ << " f_min=" << f_min << " f_max=" << f_max
+ << " type=" << DataTypeString(DataTypeToEnum<T>::v());
}
}
+}
- template <typename T>
- void TestQuantizedToFloatInPlaceUsingEigen(
- Eigen::ThreadPoolDevice* eigen_device) {
- // These are the float values we're going to test the conversions on.
- typedef std::pair<float, float> FPair;
- for (FPair min_and_max : std::vector<FPair>{
- FPair(-255.0f, 255.0f), FPair(-1.0f, 1.0f), FPair(-1.0f, 255.0f),
- FPair(0.0f, 1e6), FPair(0.0f, 1.0f), FPair(-31.0f, 13.0f),
- FPair(-5.89505e+08, 5.89505e+08),
- }) {
- const float f_min = min_and_max.first;
- const float f_max = min_and_max.second;
- const int values_count = sizeof(T) == 1 ? 256 : 50000;
- Tensor input(DataTypeToEnum<T>::v(), TensorShape{values_count});
- auto input_array = input.flat<T>();
- const double q_range =
- static_cast<double>(Eigen::NumTraits<T>::highest()) -
- Eigen::NumTraits<T>::lowest();
- for (int i = 0; i < values_count; ++i) {
- if (sizeof(T) == 1) {
- input_array(i) = Eigen::NumTraits<T>::lowest() + i;
- } else {
- int64 offset = static_cast<int64>(q_range / values_count * i);
- input_array(i) = static_cast<int32>(
- std::min<int64>(Eigen::NumTraits<T>::lowest() + offset,
- Eigen::NumTraits<T>::highest()));
- }
+template <typename T>
+void TestQuantizedToFloatInPlaceUsingEigen(
+ Eigen::ThreadPoolDevice* eigen_device) {
+ // These are the float values we're going to test the conversions on.
+ typedef std::pair<float, float> FPair;
+ for (FPair min_and_max : std::vector<FPair>{
+ FPair(-255.0f, 255.0f), FPair(-1.0f, 1.0f), FPair(-1.0f, 255.0f),
+ FPair(0.0f, 1e6), FPair(0.0f, 1.0f), FPair(-31.0f, 13.0f),
+ FPair(-5.89505e+08, 5.89505e+08),
+ }) {
+ const float f_min = min_and_max.first;
+ const float f_max = min_and_max.second;
+ const int values_count = sizeof(T) == 1 ? 256 : 50000;
+ Tensor input(DataTypeToEnum<T>::v(), TensorShape{values_count});
+ auto input_array = input.flat<T>();
+ const double q_range = static_cast<double>(Eigen::NumTraits<T>::highest()) -
+ Eigen::NumTraits<T>::lowest();
+ for (int i = 0; i < values_count; ++i) {
+ if (sizeof(T) == 1) {
+ input_array(i) = Eigen::NumTraits<T>::lowest() + i;
+ } else {
+ int64 offset = static_cast<int64>(q_range / values_count * i);
+ input_array(i) = static_cast<int32>(
+ std::min<int64>(Eigen::NumTraits<T>::lowest() + offset,
+ Eigen::NumTraits<T>::highest()));
}
+ }
- Tensor output(DT_FLOAT, TensorShape{values_count});
- QuantizedTensorToFloatInPlaceUsingEigen<T>(*eigen_device, input, f_min,
- f_max, &output);
- auto output_array = output.flat<float>();
- const double range = static_cast<double>(f_max) - f_min;
- for (int i = 0; i < values_count; ++i) {
- float expected = QuantizedToFloat<T>(input_array(i), f_min, f_max);
- float actual = output_array(i);
- ASSERT_NEAR(expected, actual, range * 1.1e-7)
- << "expected=" << expected << " actual=" << actual
- << " v=" << input_array(i) << " i=" << i << " f_min=" << f_min
- << " f_max=" << f_max
- << " type=" << DataTypeString(DataTypeToEnum<T>::v());
- }
+ Tensor output(DT_FLOAT, TensorShape{values_count});
+ QuantizedTensorToFloatInPlaceUsingEigen<T>(*eigen_device, input, f_min,
+ f_max, &output);
+ auto output_array = output.flat<float>();
+ const double range = static_cast<double>(f_max) - f_min;
+ for (int i = 0; i < values_count; ++i) {
+ float expected = QuantizedToFloat<T>(input_array(i), f_min, f_max);
+ float actual = output_array(i);
+ ASSERT_NEAR(expected, actual, range * 1.1e-7)
+ << "expected=" << expected << " actual=" << actual
+ << " v=" << input_array(i) << " i=" << i << " f_min=" << f_min
+ << " f_max=" << f_max
+ << " type=" << DataTypeString(DataTypeToEnum<T>::v());
}
}
-};
+}
-TEST_F(QuantizationUtilsTest, FloatToQuantized) {
+} // namespace
+
+void TestFloatToQuantized() {
EXPECT_EQ(quint8(0), FloatToQuantized<quint8>(0.0f, 0.0f, 1.0f));
EXPECT_EQ(quint8(0), FloatToQuantized<quint8>(0.0f, 0.0f, 2.0f));
EXPECT_EQ(quint8(128), FloatToQuantized<quint8>(0.5f, 0.0f, 1.0f));
@@ -318,7 +448,7 @@ TEST_F(QuantizationUtilsTest, FloatToQuantized) {
FloatToQuantized<qint32>(128.0f, -128.0f, 128.0f));
}
-TEST_F(QuantizationUtilsTest, QuantizedToFloat) {
+void TestQuantizedToFloat() {
EXPECT_LT(fabsf(0.0f - QuantizedToFloat<quint8>(0, 0.0f, 1.0f)), 1 / 255.0f);
EXPECT_LT(fabsf(0.0f - QuantizedToFloat<quint8>(0, 0.0f, 2.0f)), 1 / 255.0f);
EXPECT_LT(fabsf(0.5f - QuantizedToFloat<quint8>(127, 0.0f, 1.0f)),
@@ -349,7 +479,7 @@ TEST_F(QuantizationUtilsTest, QuantizedToFloat) {
1.0);
}
-TEST_F(QuantizationUtilsTest, AvoidBias) {
+void TestAvoidBias() {
for (int i = 0; i < 256; ++i) {
const float as_float = QuantizedToFloat<quint8>(i, 0.0f, 2.0f);
const int back_to_int = FloatToQuantized<quint8>(as_float, 0.0f, 2.0f);
@@ -371,7 +501,7 @@ TEST_F(QuantizationUtilsTest, AvoidBias) {
}
}
-TEST_F(QuantizationUtilsTest, RequantizeInNewRange) {
+void TestRequantizeInNewRange() {
// These are the float values we're going to test the conversions on.
const size_t values_count = 6;
const float values[values_count] = {0.0f, 0.5f, 1.0f, -1.0f, 127.0f, 255.0f};
@@ -407,7 +537,7 @@ TEST_F(QuantizationUtilsTest, RequantizeInNewRange) {
}
}
-TEST_F(QuantizationUtilsTest, RequantizeInNewRangeRealData) {
+void TestRequantizeInNewRangeRealData() {
const float input_min = -0.739539f;
const float input_max = 0.641057f;
const float output_min = -2381.49f;
@@ -428,7 +558,7 @@ TEST_F(QuantizationUtilsTest, RequantizeInNewRangeRealData) {
EXPECT_LT(std::abs(value_as_qint32 - actual_output), 10);
}
-TEST_F(QuantizationUtilsTest, RequantizeInNewRange32To8Bit) {
+void TestRequantizeInNewRange32To8Bit() {
// These are the float values we're going to test the conversions on.
const size_t values_count = 6;
const float values[values_count] = {0.0f, 0.45f, 1.0f, -1.0f, 127.0f, 255.0f};
@@ -464,27 +594,26 @@ TEST_F(QuantizationUtilsTest, RequantizeInNewRange32To8Bit) {
}
}
-TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8Bit) {
+void TestRequantizeManyInNewRange32To8Bit() {
TestRequantizeManyInNewRange32To8Bit(nullptr /* eigen_device */);
}
-TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8BitUsingEigen) {
+void TestRequantizeManyInNewRange32To8BitUsingEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
TestRequantizeManyInNewRange32To8Bit(&eigen_device);
}
-TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8BitEigenVsNonEigen) {
+void TestRequantizeManyInNewRange32To8BitEigenVsNonEigen() {
TestRequantizeManyInNewRangeEigenVsNonEigen<qint32, quint8>();
}
-TEST_F(QuantizationUtilsTest,
- RequantizeManyInNewRange32To8BitSignedEigenVsNonEigen) {
+void TestRequantizeManyInNewRange32To8BitSignedEigenVsNonEigen() {
TestRequantizeManyInNewRangeEigenVsNonEigen<qint32, qint8>();
}
-TEST_F(QuantizationUtilsTest, FloatTensorToQuantized) {
+void TestFloatTensorToQuantized() {
const int input_width = 3;
const int input_height = 3;
const float input_min = 0.0f;
@@ -500,7 +629,7 @@ TEST_F(QuantizationUtilsTest, FloatTensorToQuantized) {
// Verify that FloatToQuantizedInPlaceUsingEigen is same result as
// FloatToQuantized.
-TEST_F(QuantizationUtilsTest, FloatToQuantizedInPlaceUsingEigen) {
+void TestFloatToQuantizedInPlaceUsingEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
@@ -511,7 +640,7 @@ TEST_F(QuantizationUtilsTest, FloatToQuantizedInPlaceUsingEigen) {
TestFloatToQuantizedInPlaceUsingEigen<qint16>(&eigen_device);
}
-TEST_F(QuantizationUtilsTest, OverflowWithEigen) {
+void TestOverflowWithEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
@@ -537,7 +666,7 @@ TEST_F(QuantizationUtilsTest, OverflowWithEigen) {
test::ExpectTensorEqual<qint32>(expected, output);
}
-TEST_F(QuantizationUtilsTest, QuantizedTensorToFloat) {
+void TestQuantizedTensorToFloat() {
const int input_width = 3;
const int input_height = 3;
const float input_min = -128.0f;
@@ -579,7 +708,7 @@ TEST_F(QuantizationUtilsTest, QuantizedTensorToFloat) {
// Verify that QuantizedToFloatInPlaceUsingEigen is same result as
// QuantizedToFloat.
-TEST_F(QuantizationUtilsTest, QuantizedToFloatInPlaceUsingEigen) {
+void TestQuantizedToFloatInPlaceUsingEigen() {
thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */);
EigenThreadPoolWrapper wrapper(&threadpool);
Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */);
@@ -591,4 +720,56 @@ TEST_F(QuantizationUtilsTest, QuantizedToFloatInPlaceUsingEigen) {
TestQuantizedToFloatInPlaceUsingEigen<qint32>(&eigen_device);
}
+void BenchmarkRequantizeManyInNewRange() {
+ TimeRequantizeManyInNewRange<qint32, quint8>(1000, 1000, false);
+ TimeRequantizeManyInNewRange<qint32, quint8>(1000, 1000, true);
+ TimeRequantizeManyInNewRange<qint32, quint8>(100000, 100, false);
+ TimeRequantizeManyInNewRange<qint32, quint8>(100000, 100, true);
+ TimeRequantizeManyInNewRange<qint32, quint8>(1000000, 10, false);
+ TimeRequantizeManyInNewRange<qint32, quint8>(1000000, 10, true);
+
+ TimeRequantizeManyInNewRange<quint8, qint32>(1000, 1000, false);
+ TimeRequantizeManyInNewRange<quint8, qint32>(1000, 1000, true);
+ TimeRequantizeManyInNewRange<quint8, qint32>(100000, 100, false);
+ TimeRequantizeManyInNewRange<quint8, qint32>(100000, 100, true);
+ TimeRequantizeManyInNewRange<quint8, qint32>(1000000, 10, false);
+ TimeRequantizeManyInNewRange<quint8, qint32>(1000000, 10, true);
+}
+
} // namespace tensorflow
+
+#if defined(__ANDROID__)
+int main(int argc, char** argv) {
+#define RUN_TEST(t) \
+ LOG(INFO) << "Test: " << #t; \
+ tensorflow::t();
+#else
+#define RUN_TEST(t) \
+ TEST(QuantizationUtilsTest, t) { tensorflow::t(); }
+#endif
+
+ RUN_TEST(TestFloatToQuantized);
+ RUN_TEST(TestQuantizedToFloat);
+ RUN_TEST(TestAvoidBias);
+ RUN_TEST(TestRequantizeInNewRange);
+ RUN_TEST(TestRequantizeInNewRangeRealData);
+ RUN_TEST(TestRequantizeInNewRange32To8Bit);
+ RUN_TEST(TestRequantizeManyInNewRange32To8Bit);
+ RUN_TEST(TestRequantizeManyInNewRange32To8BitUsingEigen);
+ RUN_TEST(TestRequantizeManyInNewRange32To8BitEigenVsNonEigen);
+ RUN_TEST(TestRequantizeManyInNewRange32To8BitSignedEigenVsNonEigen);
+ RUN_TEST(TestFloatTensorToQuantized);
+ RUN_TEST(TestRequantizeManyInNewRange8To32Bit);
+ RUN_TEST(TestFloatToQuantizedInPlaceUsingEigen);
+ RUN_TEST(TestOverflowWithEigen);
+ RUN_TEST(TestQuantizedTensorToFloat);
+ RUN_TEST(TestQuantizedToFloatInPlaceUsingEigen);
+
+#if defined(__ANDROID__)
+
+ tensorflow::BenchmarkRequantizeManyInNewRange();
+
+ LOG(INFO) << "All tests complete.";
+ return 0;
+}
+#endif
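
For context, the new quantized_add_op_test.cc drives the kernel through the C++ client API, which is why the cc_ops and client_session dependencies appear in the BUILD rules above. The following is a rough, hypothetical driver sketch in that style; it assumes the generated ops::QuantizedAdd wrapper mirrors the op registered in math_ops.cc (inputs x and y plus their float ranges, outputs z, min_z, max_z, with Toutput defaulting to qint32), which is not shown verbatim in this diff.

#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void RunQuantizedAddExample() {
  Scope root = Scope::NewRootScope();

  // Two quint8 inputs together with the float ranges they were quantized over:
  // x covers [0, 6] and y covers [-1, 1].
  Tensor x_tensor(DT_QUINT8, TensorShape({4}));
  Tensor y_tensor(DT_QUINT8, TensorShape({4}));
  const uint8 x_vals[4] = {0, 64, 128, 255};
  const uint8 y_vals[4] = {10, 20, 200, 255};
  for (int i = 0; i < 4; ++i) {
    x_tensor.flat<quint8>()(i) = quint8(x_vals[i]);
    y_tensor.flat<quint8>()(i) = quint8(y_vals[i]);
  }

  auto x = ops::Const(root.WithOpName("x"), Input::Initializer(x_tensor));
  auto y = ops::Const(root.WithOpName("y"), Input::Initializer(y_tensor));
  auto add =
      ops::QuantizedAdd(root.WithOpName("add"), x, y, 0.0f, 6.0f, -1.0f, 1.0f);

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run({add.z, add.min_z, add.max_z}, &outputs));

  // outputs[0] holds the qint32 sums; outputs[1] and outputs[2] hold the
  // float range those 32-bit codes represent.
  LOG(INFO) << "min_z=" << outputs[1].scalar<float>()()
            << " max_z=" << outputs[2].scalar<float>()();
}

}  // namespace tensorflow
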
diff --git a/tensorflow/core/kernels/quantized_add_op.cc b/tensorflow/core/kernels/quantized_add_op.cc
new file mode 100644
index 0000000000..8be0c56798
--- /dev/null
+++ b/tensorflow/core/kernels/quantized_add_op.cc
@@ -0,0 +1,581 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implements a quantized eight-bit version of the addition operation.
+
+#define EIGEN_USE_THREADS
+
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#define USE_NEON
+#define QUANTIZED_ADD_USE_NEON
+#include <arm_neon.h>
+#endif
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/meta_support.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/bcast.h"
+
+// There are implementations for three broadcast patterns for add:
+// - Scalar + Array
+// - Array + Array
+// - Array + Shorter Array (repeated to match first)
+//
+// These handle a lot of common broadcast patterns, and we have NEON SIMD
+// versions to accelerate performance on ARM platforms.
+
+namespace tensorflow {
+namespace {
+
+template <class T, class Toutput>
+void ScalarAddition(OpKernelContext* context, const T* full_input,
+ float full_input_min, float full_input_max,
+ int64 num_elements, T scalar_input, float scalar_input_min,
+ float scalar_input_max, float output_min, float output_max,
+ Toutput* output) {
+ const Toutput scalar_in_output_range = RequantizeInNewRange<T, Toutput>(
+ scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);
+ for (int i = 0; i < num_elements; ++i) {
+ const Toutput full_input_in_output_range = RequantizeInNewRange<T, Toutput>(
+ full_input[i], full_input_min, full_input_max, output_min, output_max);
+ output[i] = full_input_in_output_range + scalar_in_output_range;
+ }
+}
+
+#ifdef QUANTIZED_ADD_USE_NEON
+
+template <>
+void ScalarAddition(OpKernelContext* context, const quint8* full_input,
+ float full_input_min, float full_input_max,
+ int64 num_elements, quint8 scalar_input,
+ float scalar_input_min, float scalar_input_max,
+ float output_min, float output_max, qint32* output) {
+ const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
+ scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);
+
+ const float input_0_float =
+ QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
+ const float input_1_float =
+ QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
+ const int64 input_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
+ const int64 input_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
+ const int32 input_mult_int32 = input_1_int64 - input_0_int64;
+
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+
+ const int64x2_t input_0_64x2 = vmovq_n_s64(input_0_int64);
+ const int32x2_t input_mult_32x2 = vmov_n_s32(input_mult_int32);
+ const int32x4_t scalar_in_output_range_32x4 =
+ vmovq_n_s32(scalar_in_output_range);
+ int64 i = 0;
+ for (; i < (num_elements - 7); i += 8) {
+ const uint8* full_input_ptr = &(full_input->value) + i;
+ const std::array<int32x4_t, 2> output_value =
+ Requantize8x8To32Neon(full_input_ptr, input_0_64x2, input_mult_32x2);
+ const int32x4_t result_low_32x4 =
+ vaddq_s32(output_value[0], scalar_in_output_range_32x4);
+ const int32x4_t result_high_32x4 =
+ vaddq_s32(output_value[1], scalar_in_output_range_32x4);
+ int32* output_ptr = &(output->value) + i;
+ vst1q_s32(output_ptr + 0, result_low_32x4);
+ vst1q_s32(output_ptr + 4, result_high_32x4);
+ }
+ for (; i < num_elements; ++i) {
+ const int64 full_input_value = static_cast<int64>(full_input[i]);
+ int64 full_input_in_output_range_64 =
+ input_0_int64 + (full_input_value * input_mult_int32);
+ full_input_in_output_range_64 =
+ std::max(full_input_in_output_range_64, lowest_quantized);
+ full_input_in_output_range_64 =
+ std::min(full_input_in_output_range_64, highest_quantized);
+ const int32 full_input_in_output_range =
+ static_cast<int32>(full_input_in_output_range_64);
+ output[i] = full_input_in_output_range + scalar_in_output_range;
+ }
+}
+
+#else // QUANTIZED_ADD_USE_NEON
+
+template <>
+void ScalarAddition(OpKernelContext* context, const quint8* full_input,
+ float full_input_min, float full_input_max,
+ int64 num_elements, quint8 scalar_input,
+ float scalar_input_min, float scalar_input_max,
+ float output_min, float output_max, qint32* output) {
+ const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
+ scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);
+
+ const float input_0_float =
+ QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
+ const float input_1_float =
+ QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
+ const int64 input_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
+ const int64 input_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
+ const int32 input_mult_int32 = input_1_int64 - input_0_int64;
+
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+
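+  // Apply the affine requantization with 64-bit intermediates so it cannot
+  // overflow, then clamp into the representable qint32 range before adding.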
+ for (int i = 0; i < num_elements; ++i) {
+ const int64 full_input_value = static_cast<int64>(full_input[i]);
+ int64 full_input_in_output_range_64 =
+ input_0_int64 + (full_input_value * input_mult_int32);
+ full_input_in_output_range_64 =
+ std::max(full_input_in_output_range_64, lowest_quantized);
+ full_input_in_output_range_64 =
+ std::min(full_input_in_output_range_64, highest_quantized);
+ const int32 full_input_in_output_range =
+ static_cast<int32>(full_input_in_output_range_64);
+ output[i] = full_input_in_output_range + scalar_in_output_range;
+ }
+}
+
+#endif // QUANTIZED_ADD_USE_NEON
+
+template <class T, class Toutput>
+void VectorAddition(OpKernelContext* context, const T* x_data, float min_x,
+ float max_x, const T* y_data, float min_y, float max_y,
+ int64 num_elements, float output_min, float output_max,
+ Toutput* output) {
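+  // Requantize both inputs element-wise into the shared output range and
+  // add them there, so differently-scaled inputs combine correctly.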
+ for (int i = 0; i < num_elements; ++i) {
+ const Toutput x_in_output_range = RequantizeInNewRange<T, Toutput>(
+ x_data[i], min_x, max_x, output_min, output_max);
+ const Toutput y_in_output_range = RequantizeInNewRange<T, Toutput>(
+ y_data[i], min_y, max_y, output_min, output_max);
+ output[i] = x_in_output_range + y_in_output_range;
+ }
+}
+
+#ifdef QUANTIZED_ADD_USE_NEON
+
+template <>
+void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
+ float max_x, const quint8* y_data, float min_y, float max_y,
+ int64 num_elements, float output_min, float output_max,
+ qint32* output) {
+ const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
+ const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
+ const int64 x_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
+ const int64 x_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
+ const int32 x_mult_int32 = x_1_int64 - x_0_int64;
+
+ const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
+ const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
+ const int64 y_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
+ const int64 y_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
+ const int32 y_mult_int32 = y_1_int64 - y_0_int64;
+
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+
+ const int64x2_t x_0_64x2 = vmovq_n_s64(x_0_int64);
+ const int32x2_t x_mult_32x2 = vmov_n_s32(x_mult_int32);
+
+ const int64x2_t y_0_64x2 = vmovq_n_s64(y_0_int64);
+ const int32x2_t y_mult_32x2 = vmov_n_s32(y_mult_int32);
+
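+  // Requantize and add eight elements at a time with NEON; the scalar loop
+  // below finishes any remaining tail elements.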
+ int64 i = 0;
+ for (; i < (num_elements - 7); i += 8) {
+ const uint8* x_ptr = &(x_data->value) + i;
+ const std::array<int32x4_t, 2> x_output_value =
+ Requantize8x8To32Neon(x_ptr, x_0_64x2, x_mult_32x2);
+ const uint8* y_ptr = &(y_data->value) + i;
+ const std::array<int32x4_t, 2> y_output_value =
+ Requantize8x8To32Neon(y_ptr, y_0_64x2, y_mult_32x2);
+
+ const int32x4_t result_low_32x4 =
+ vaddq_s32(x_output_value[0], y_output_value[0]);
+ const int32x4_t result_high_32x4 =
+ vaddq_s32(x_output_value[1], y_output_value[1]);
+ int32* output_ptr = &(output->value) + i;
+ vst1q_s32(output_ptr + 0, result_low_32x4);
+ vst1q_s32(output_ptr + 4, result_high_32x4);
+ }
+
+ for (; i < num_elements; ++i) {
+ const int64 x_value = static_cast<int64>(x_data[i]);
+ int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
+ x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
+ x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
+ const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);
+
+ const int64 y_value = static_cast<int64>(y_data[i]);
+ int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
+ y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
+ y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
+ const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);
+
+ output[i] = x_in_output_range + y_in_output_range;
+ }
+}
+
+#else // QUANTIZED_ADD_USE_NEON
+
+template <>
+void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
+ float max_x, const quint8* y_data, float min_y, float max_y,
+ int64 num_elements, float output_min, float output_max,
+ qint32* output) {
+ const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
+ const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
+ const int64 x_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
+ const int64 x_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
+ const int32 x_mult_int32 = x_1_int64 - x_0_int64;
+
+ const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
+ const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
+ const int64 y_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
+ const int64 y_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
+ const int32 y_mult_int32 = y_1_int64 - y_0_int64;
+
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+
+ for (int i = 0; i < num_elements; ++i) {
+ const int64 x_value = static_cast<int64>(x_data[i]);
+ int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
+ x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
+ x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
+ const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);
+
+ const int64 y_value = static_cast<int64>(y_data[i]);
+ int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
+ y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
+ y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
+ const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);
+
+ output[i] = x_in_output_range + y_in_output_range;
+ }
+}
+
+#endif // QUANTIZED_ADD_USE_NEON
+
+template <class T, class Toutput>
+void VectorTensorAddition(const T* vector_data, float min_vector,
+ float max_vector, int64 vector_num_elements,
+ const T* tensor_data, float min_tensor,
+ float max_tensor, int64 tensor_num_elements,
+ float output_min, float output_max, Toutput* output) {
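+  // Broadcast the smaller vector across the tensor by cycling through its
+  // elements with a modulo index, requantizing both into the output range.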
+ for (int i = 0; i < tensor_num_elements; ++i) {
+ const int64 vector_i = i % vector_num_elements;
+ const Toutput vector_in_output_range = RequantizeInNewRange<T, Toutput>(
+ vector_data[vector_i], min_vector, max_vector, output_min, output_max);
+ const Toutput tensor_in_output_range = RequantizeInNewRange<T, Toutput>(
+ tensor_data[i], min_tensor, max_tensor, output_min, output_max);
+ output[i] = vector_in_output_range + tensor_in_output_range;
+ }
+}
+
+#ifdef QUANTIZED_ADD_USE_NEON
+
+template <>
+void VectorTensorAddition(const quint8* vector_data, float min_vector,
+ float max_vector, int64 vector_num_elements,
+ const quint8* tensor_data, float min_tensor,
+ float max_tensor, int64 tensor_num_elements,
+ float output_min, float output_max, qint32* output) {
+ const float vector_0_float =
+ QuantizedToFloat<quint8>(0, min_vector, max_vector);
+ const float vector_1_float =
+ QuantizedToFloat<quint8>(1, min_vector, max_vector);
+ const int64 vector_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
+ const int64 vector_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
+ const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;
+
+ const float tensor_0_float =
+ QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
+ const float tensor_1_float =
+ QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
+ const int64 tensor_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
+ const int64 tensor_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
+ const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;
+
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+
+ const int64x2_t vector_0_64x2 = vmovq_n_s64(vector_0_int64);
+ const int32x2_t vector_mult_32x2 = vmov_n_s32(vector_mult_int32);
+
+ const int64x2_t tensor_0_64x2 = vmovq_n_s64(tensor_0_int64);
+ const int32x2_t tensor_mult_32x2 = vmov_n_s32(tensor_mult_int32);
+
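+  // Walk the tensor one vector-length row at a time, using NEON for eight
+  // elements per step and a scalar loop for each row's tail.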
+ for (int64 base_i = 0; base_i < tensor_num_elements;
+ base_i += vector_num_elements) {
+ int64 i = base_i;
+ int64 vector_i = 0;
+ for (; vector_i < (vector_num_elements - 7); vector_i += 8, i += 8) {
+ const uint8* vector_ptr = &(vector_data->value) + vector_i;
+ const std::array<int32x4_t, 2> vector_output_value =
+ Requantize8x8To32Neon(vector_ptr, vector_0_64x2, vector_mult_32x2);
+ const uint8* tensor_ptr = &(tensor_data->value) + i;
+ const std::array<int32x4_t, 2> tensor_output_value =
+ Requantize8x8To32Neon(tensor_ptr, tensor_0_64x2, tensor_mult_32x2);
+
+ const int32x4_t result_low_32x4 =
+ vaddq_s32(vector_output_value[0], tensor_output_value[0]);
+ const int32x4_t result_high_32x4 =
+ vaddq_s32(vector_output_value[1], tensor_output_value[1]);
+ int32* output_ptr = &(output->value) + i;
+ vst1q_s32(output_ptr + 0, result_low_32x4);
+ vst1q_s32(output_ptr + 4, result_high_32x4);
+ }
+ for (; vector_i < vector_num_elements; ++vector_i, ++i) {
+ const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
+ int64 vector_in_output_range_64 =
+ vector_0_int64 + (vector_value * vector_mult_int32);
+ vector_in_output_range_64 =
+ std::max(vector_in_output_range_64, lowest_quantized);
+ vector_in_output_range_64 =
+ std::min(vector_in_output_range_64, highest_quantized);
+ const int32 vector_in_output_range =
+ static_cast<int32>(vector_in_output_range_64);
+
+ const int64 tensor_value = static_cast<int64>(tensor_data[i]);
+ int64 tensor_in_output_range_64 =
+ tensor_0_int64 + (tensor_value * tensor_mult_int32);
+ tensor_in_output_range_64 =
+ std::max(tensor_in_output_range_64, lowest_quantized);
+ tensor_in_output_range_64 =
+ std::min(tensor_in_output_range_64, highest_quantized);
+ const int32 tensor_in_output_range =
+ static_cast<int32>(tensor_in_output_range_64);
+
+ output[i] = vector_in_output_range + tensor_in_output_range;
+ }
+ }
+}
+
+#else // QUANTIZED_ADD_USE_NEON
+
+template <>
+void VectorTensorAddition(const quint8* vector_data, float min_vector,
+ float max_vector, int64 vector_num_elements,
+ const quint8* tensor_data, float min_tensor,
+ float max_tensor, int64 tensor_num_elements,
+ float output_min, float output_max, qint32* output) {
+ const float vector_0_float =
+ QuantizedToFloat<quint8>(0, min_vector, max_vector);
+ const float vector_1_float =
+ QuantizedToFloat<quint8>(1, min_vector, max_vector);
+ const int64 vector_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
+ const int64 vector_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
+ const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;
+
+ const float tensor_0_float =
+ QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
+ const float tensor_1_float =
+ QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
+ const int64 tensor_0_int64 =
+ FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
+ const int64 tensor_1_int64 =
+ FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
+ const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;
+
+ const int64 lowest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
+ const int64 highest_quantized =
+ static_cast<int64>(Eigen::NumTraits<qint32>::highest());
+
+ for (int i = 0; i < tensor_num_elements; ++i) {
+ const int64 vector_i = i % vector_num_elements;
+ const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
+ int64 vector_in_output_range_64 =
+ vector_0_int64 + (vector_value * vector_mult_int32);
+ vector_in_output_range_64 =
+ std::max(vector_in_output_range_64, lowest_quantized);
+ vector_in_output_range_64 =
+ std::min(vector_in_output_range_64, highest_quantized);
+ const int32 vector_in_output_range =
+ static_cast<int32>(vector_in_output_range_64);
+
+ const int64 tensor_value = static_cast<int64>(tensor_data[i]);
+ int64 tensor_in_output_range_64 =
+ tensor_0_int64 + (tensor_value * tensor_mult_int32);
+ tensor_in_output_range_64 =
+ std::max(tensor_in_output_range_64, lowest_quantized);
+ tensor_in_output_range_64 =
+ std::min(tensor_in_output_range_64, highest_quantized);
+ const int32 tensor_in_output_range =
+ static_cast<int32>(tensor_in_output_range_64);
+
+ output[i] = vector_in_output_range + tensor_in_output_range;
+ }
+}
+
+#endif // QUANTIZED_ADD_USE_NEON
+
+} // namespace
+
+template <class T, class Toutput>
+class QuantizedAddOp : public OpKernel {
+ public:
+ explicit QuantizedAddOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& x = context->input(0);
+ const Tensor& y = context->input(1);
+ const float min_x = context->input(2).flat<float>()(0);
+ const float max_x = context->input(3).flat<float>()(0);
+ const float min_y = context->input(4).flat<float>()(0);
+ const float max_y = context->input(5).flat<float>()(0);
+
+ BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape()));
+ if (!bcast.IsValid()) {
+ context->SetStatus(errors::InvalidArgument(
+ "Incompatible shapes: ", x.shape().DebugString(), " vs. ",
+ y.shape().DebugString()));
+ return;
+ }
+ Tensor* z;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, BCast::ToShape(bcast.output_shape()), &z));
+
+ // Make sure that we have valid quantization ranges for the input buffers.
+ // If the difference between the min and max is negative or zero, it makes
+ // it hard to do meaningful intermediate operations on the values.
+ OP_REQUIRES(context, (max_x > min_x),
+ errors::InvalidArgument("max_x must be larger than min_x."));
+ OP_REQUIRES(context, (max_y > min_y),
+ errors::InvalidArgument("max_y must be larger than min_y."));
+ const T* x_data = x.flat<T>().data();
+ const T* y_data = y.flat<T>().data();
+ Toutput* z_data = z->flat<Toutput>().data();
+
+ // We want the range of the output to be symmetrical around zero so that
+ // adding zero leaves the result unchanged, and to contain the largest of
+ // the two input values with some room to spare.
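+    // Scaling the largest input magnitude by 2^14 leaves plenty of headroom
+    // in the qint32 output while keeping precision for the 8-bit inputs.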
+ const float smallest_min = std::min(min_x, min_y);
+    const float largest_max = std::max(max_x, max_y);
+ const float biggest_range =
+ std::max(std::abs(smallest_min), std::abs(largest_max));
+ const float output_range = (biggest_range * (1 << 14));
+ const float min_z_value = -output_range;
+ const float max_z_value = output_range;
+
+ const int ndims = bcast.x_reshape().size();
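+    // Dispatch on the broadcast pattern: one input is a scalar, both are
+    // equal-length vectors, or a vector is broadcast against a larger
+    // tensor. Anything else is reported as unimplemented below.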
+ if (ndims <= 1) {
+ if (x.NumElements() == 1) {
+ ScalarAddition<T, Toutput>(context, y_data, min_y, max_y,
+ y.NumElements(), x_data[0], min_x, max_x,
+ min_z_value, max_z_value, z_data);
+ } else if (y.NumElements() == 1) {
+ ScalarAddition<T, Toutput>(context, x_data, min_x, max_x,
+ x.NumElements(), y_data[0], min_y, max_y,
+ min_z_value, max_z_value, z_data);
+ } else {
+ VectorAddition<T, Toutput>(context, x_data, min_x, max_x, y_data, min_y,
+ max_y, x.NumElements(), min_z_value,
+ max_z_value, z_data);
+ }
+ } else if (ndims == 2) {
+ const T* vector_data;
+ int64 vector_num_elements;
+ float vector_min;
+ float vector_max;
+ const T* tensor_data;
+ int64 tensor_num_elements;
+ float tensor_min;
+ float tensor_max;
+ if (x.NumElements() < y.NumElements()) {
+ vector_data = x_data;
+ vector_num_elements = x.NumElements();
+ vector_min = min_x;
+ vector_max = max_x;
+ tensor_data = y_data;
+ tensor_num_elements = y.NumElements();
+ tensor_min = min_y;
+ tensor_max = max_y;
+ } else {
+ vector_data = y_data;
+ vector_num_elements = y.NumElements();
+ vector_min = min_y;
+ vector_max = max_y;
+ tensor_data = x_data;
+ tensor_num_elements = x.NumElements();
+ tensor_min = min_x;
+ tensor_max = max_x;
+ }
+ VectorTensorAddition<T, Toutput>(
+ vector_data, vector_min, vector_max, vector_num_elements, tensor_data,
+ tensor_min, tensor_max, tensor_num_elements, min_z_value, max_z_value,
+ z_data);
+ } else {
+ LOG(INFO) << "ndims=" << ndims;
+ LOG(INFO) << "bcast.x_reshape()="
+ << TensorShape(bcast.x_reshape()).DebugString();
+ LOG(INFO) << "bcast.y_reshape()="
+ << TensorShape(bcast.y_reshape()).DebugString();
+ LOG(INFO) << "bcast.x_bcast()="
+ << TensorShape(bcast.x_bcast()).DebugString();
+ LOG(INFO) << "bcast.y_bcast()="
+ << TensorShape(bcast.y_bcast()).DebugString();
+
+ context->SetStatus(errors::Unimplemented(
+ "Broadcast between ", context->input(0).shape().DebugString(),
+ " and ", context->input(1).shape().DebugString(),
+ " is not supported yet."));
+ return;
+ }
+
+ Tensor* z_min = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, {}, &z_min));
+ z_min->flat<float>()(0) = min_z_value;
+
+ Tensor* z_max = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(2, {}, &z_max));
+ z_max->flat<float>()(0) = max_z_value;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("QuantizedAdd")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<quint8>("T1")
+ .TypeConstraint<quint8>("T2")
+ .TypeConstraint<qint32>("Toutput"),
+ QuantizedAddOp<quint8, qint32>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/quantized_add_op_test.cc b/tensorflow/core/kernels/quantized_add_op_test.cc
new file mode 100644
index 0000000000..74d16b282d
--- /dev/null
+++ b/tensorflow/core/kernels/quantized_add_op_test.cc
@@ -0,0 +1,311 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+using namespace ops; // NOLINT(build/namespaces)
+
+namespace {
+
+void TestAdd(const std::vector<int64>& x_shape,
+ const std::vector<float>& x_values, float x_min_value,
+ float x_max_value, const std::vector<int64>& y_shape,
+ const std::vector<float>& y_values, float y_min_value,
+ float y_max_value, const std::vector<int64>& expected_shape,
+ const std::vector<float>& expected_values, double tolerance) {
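+  // Quantizes the float inputs, runs QuantizedAdd, dequantizes the qint32
+  // result, and compares it against the expected float values.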
+ Scope root = Scope::NewRootScope();
+
+ Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape));
+ test::FillValues<float>(&x_float_tensor, x_values);
+ Tensor x_quantized_tensor(DT_QUINT8, x_float_tensor.shape());
+ FloatTensorToQuantizedInPlace<quint8>(x_float_tensor, x_min_value,
+ x_max_value, &x_quantized_tensor);
+ Output x =
+ Const(root.WithOpName("x"), Input::Initializer(x_quantized_tensor));
+ Output x_min = Const(root.WithOpName("x_min"), x_min_value);
+ Output x_max = Const(root.WithOpName("x_max"), x_max_value);
+
+ Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape));
+ test::FillValues<float>(&y_float_tensor, y_values);
+ Tensor y_quantized_tensor(DT_QUINT8, y_float_tensor.shape());
+ FloatTensorToQuantizedInPlace<quint8>(y_float_tensor, y_min_value,
+ y_max_value, &y_quantized_tensor);
+ Output y =
+ Const(root.WithOpName("y"), Input::Initializer(y_quantized_tensor));
+ Output y_min = Const(root.WithOpName("y_min"), y_min_value);
+ Output y_max = Const(root.WithOpName("y_max"), y_max_value);
+
+ ops::QuantizedAdd add = ops::QuantizedAdd(root.WithOpName("add"), x, y, x_min,
+ x_max, y_min, y_max);
+
+ TF_EXPECT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ TF_EXPECT_OK(session.Run(ClientSession::FeedType(),
+ {add.z, add.min_z, add.max_z}, &outputs));
+
+ const Tensor& z_quantized = outputs[0];
+ const float z_min = outputs[1].flat<float>()(0);
+ const float z_max = outputs[2].flat<float>()(0);
+
+ Tensor z_float = QuantizedTensorToFloat<qint32>(z_quantized, z_min, z_max);
+ Tensor expected_z_float(DT_FLOAT, TensorShape(expected_shape));
+ test::FillValues<float>(&expected_z_float, expected_values);
+ test::ExpectTensorNear<float>(expected_z_float, z_float, tolerance);
+}
+
+void TestAddShape(const std::vector<int64>& x_shape,
+ const std::vector<int64>& y_shape) {
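+  // Generates deterministic inputs for the given shapes, computes reference
+  // results with the float Add op, and checks QuantizedAdd against them.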
+ const size_t x_num_elements = TensorShape(x_shape).num_elements();
+ std::vector<float> x_values(x_num_elements);
+ for (int i = 0; i < x_num_elements; ++i) {
+ x_values[i] = i % 256;
+ }
+ const float x_min_value = 0.0f;
+ const float x_max_value = 256.0f;
+
+ const size_t y_num_elements = TensorShape(y_shape).num_elements();
+ std::vector<float> y_values(y_num_elements);
+ for (int i = 0; i < y_num_elements; ++i) {
+ y_values[i] = ((i + 23) % 123) - 50;
+ }
+ const float y_min_value = -150.0f;
+ const float y_max_value = 150.0f;
+
+ Scope root = Scope::NewRootScope();
+
+ Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape));
+ test::FillValues<float>(&x_float_tensor, x_values);
+ Output x = Const(root.WithOpName("x"), Input::Initializer(x_float_tensor));
+
+ Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape));
+ test::FillValues<float>(&y_float_tensor, y_values);
+ Output y = Const(root.WithOpName("y"), Input::Initializer(y_float_tensor));
+
+ Add add = Add(root.WithOpName("add"), x, y);
+
+ TF_EXPECT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {add.z}, &outputs));
+
+ const Tensor& expected_values_tensor = outputs[0];
+ const float* expected_values_data =
+ expected_values_tensor.flat<float>().data();
+ std::vector<float> expected_values(
+ expected_values_data,
+ expected_values_data + expected_values_tensor.NumElements());
+ std::vector<int64> expected_shape;
+ for (const int64 dim : expected_values_tensor.shape().dim_sizes()) {
+ expected_shape.push_back(dim);
+ }
+ TestAdd(x_shape, x_values, x_min_value, x_max_value, y_shape, y_values,
+ y_min_value, y_max_value, expected_shape, expected_values, 256.0);
+}
+
+void TimeAdd(const std::vector<int64>& x_shape,
+ const std::vector<int64>& y_shape, int64 iterations) {
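+  // Verifies correctness for these shapes first, then times repeated runs
+  // of QuantizedAdd through a quint8 placeholder; the tensor contents are
+  // left uninitialized since only throughput is measured.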
+ TestAddShape(x_shape, y_shape);
+
+ Scope root = Scope::NewRootScope();
+
+ Tensor x_quantized_tensor(DT_QUINT8, TensorShape(x_shape));
+ Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_QUINT8);
+ Output x_min = Const(root.WithOpName("x_min"), 0.0f);
+ Output x_max = Const(root.WithOpName("x_max"), 1.0f);
+
+ Tensor y_quantized_tensor(DT_QUINT8, TensorShape(y_shape));
+ Output y =
+ Const(root.WithOpName("y"), Input::Initializer(y_quantized_tensor));
+ Output y_min = Const(root.WithOpName("y_min"), 0.0f);
+ Output y_max = Const(root.WithOpName("y_max"), 1.0f);
+
+ ops::QuantizedAdd add = ops::QuantizedAdd(root.WithOpName("add"), placeholder,
+ y, x_min, x_max, y_min, y_max);
+
+ TF_EXPECT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ int64 total_duration = 0;
+ for (int i = 0; i < iterations; ++i) {
+ const int64 start_time = Env::Default()->NowMicros();
+ TF_EXPECT_OK(session.Run({{placeholder, x_quantized_tensor}},
+ {add.z, add.min_z, add.max_z}, &outputs));
+ const int64 end_time = Env::Default()->NowMicros();
+ total_duration += end_time - start_time;
+ }
+ const int64 one_run_duration = total_duration / iterations;
+
+ const int64 num_ops = outputs[0].NumElements();
+
+ const double million_ops_per_second =
+ (iterations * num_ops) / static_cast<double>(total_duration);
+
+ LOG(INFO) << "TimeAdd: " << TensorShape(x_shape).DebugString() << " * "
+ << TensorShape(y_shape).DebugString()
+ << ": iterations=" << iterations
+ << ", MOps/s=" << million_ops_per_second
+ << ", one_run_duration=" << one_run_duration
+ << ", total_duration=" << total_duration;
+}
+
+} // namespace
+
+void TestManualScalar() {
+ TestAdd(
+ {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f,
+ 10.0f, {1}, {10.0f}, -100.0f, 100.0f, {10},
+ {11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f},
+ 1.0f);
+ TestAdd(
+ {1}, {10.0f}, -100.0f, 100.0f, {10},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f,
+ 10.0f, {10},
+ {11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f},
+ 1.0f);
+}
+
+void TestScalar() {
+ TestAddShape({100}, {1});
+ TestAddShape({1}, {100});
+}
+
+void TestManualVector() {
+ TestAdd({10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f},
+ 0.0f, 10.0f, {10},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f,
+ 10.0f, {10},
+ {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f},
+ 1.0f);
+}
+
+void TestVector() { TestAddShape({100}, {100}); }
+
+void TestManualVectorPlusTensor() {
+ TestAdd(
+ {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f,
+ 10.0f, {2, 10},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+ 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f},
+ 0.0f, 20.0f, {2, 10},
+ {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f,
+ 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f},
+ 1.0f);
+ TestAdd({2, 10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
+ 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f,
+ 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f},
+ 0.0f, 20.0f, {10},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f,
+ 10.0f, {2, 10}, {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f,
+ 16.0f, 18.0f, 20.0f, 12.0f, 14.0f, 16.0f, 18.0f,
+ 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f},
+ 1.0f);
+ TestAdd(
+ {5, 2}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f},
+ 0.0f, 10.0f, {2, 5, 2},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
+ 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f},
+ 0.0f, 20.0f, {2, 5, 2},
+ {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f,
+ 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f},
+ 1.0f);
+}
+
+void TestVectorPlusTensor() {
+ TestAddShape({100}, {2, 100});
+ TestAddShape({2, 100}, {100});
+ TestAddShape({5, 2}, {2, 5, 2});
+}
+
+void BenchmarkTensorScalar() {
+ TimeAdd({200}, {1}, 1000);
+ TimeAdd({10000}, {1}, 100);
+ TimeAdd({1000000}, {1}, 10);
+ TimeAdd({10000000}, {1}, 1);
+}
+
+void BenchmarkVector() {
+ TimeAdd({200}, {200}, 1000);
+ TimeAdd({10000}, {10000}, 100);
+ TimeAdd({1000000}, {1000000}, 10);
+ TimeAdd({10000000}, {10000000}, 1);
+}
+
+void BenchmarkVectorPlusTensor() {
+ TimeAdd({10, 20}, {20}, 100);
+ TimeAdd({10, 1000}, {1000}, 10);
+ TimeAdd({1000, 1000}, {1000}, 1);
+ TimeAdd({10000, 1000}, {1000}, 1);
+ TimeAdd({100, 100}, {100}, 10);
+ TimeAdd({10000, 100}, {100}, 1);
+ TimeAdd({100000, 100}, {100}, 1);
+}
+
+#if !defined(__ANDROID__)
+
+#define RUN_TEST(t) \
+ TEST(QuantizedAddOpTest, t) { t(); }
+
+RUN_TEST(TestManualScalar);
+RUN_TEST(TestManualVector);
+RUN_TEST(TestManualVectorPlusTensor);
+RUN_TEST(TestScalar);
+RUN_TEST(TestVector);
+RUN_TEST(TestVectorPlusTensor);
+
+#undef RUN_TEST
+
+#endif // __ANDROID__
+
+}  // namespace tensorflow
+
+#if defined(__ANDROID__)
+int main(int argc, char** argv) {
+ LOG(INFO) << "TestManualScalar:";
+ tensorflow::TestManualScalar();
+ LOG(INFO) << "TestManualVector:";
+ tensorflow::TestManualVector();
+ LOG(INFO) << "TestManualVectorPlusTensor:";
+ tensorflow::TestManualVectorPlusTensor();
+ tensorflow::BenchmarkTensorScalar();
+ tensorflow::BenchmarkVector();
+ tensorflow::BenchmarkVectorPlusTensor();
+ LOG(INFO) << "All tests complete";
+ return 0;
+}
+#endif // __ANDROID__
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 28c4ec643e..3cf8a98193 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -2252,6 +2252,35 @@ max_z: The float value that the highest quantized output value represents.
broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("QuantizedAdd")
+ .Input("x: T1")
+ .Input("y: T2")
+ .Input("min_x: float")
+ .Input("max_x: float")
+ .Input("min_y: float")
+ .Input("max_y: float")
+ .Output("z: Toutput")
+ .Output("min_z: float")
+ .Output("max_z: float")
+ .Attr("T1: quantizedtype")
+ .Attr("T2: quantizedtype")
+ .Attr("Toutput: quantizedtype = DT_QINT32")
+ .SetIsCommutative()
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+ .Doc(R"doc(
+Returns x + y element-wise, working on quantized buffers.
+
+min_x: The float value that the lowest quantized `x` value represents.
+max_x: The float value that the highest quantized `x` value represents.
+min_y: The float value that the lowest quantized `y` value represents.
+max_y: The float value that the highest quantized `y` value represents.
+min_z: The float value that the lowest quantized output value represents.
+max_z: The float value that the highest quantized output value represents.
+
+*NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about
+broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+)doc");
+
REGISTER_OP("QuantizeDownAndShrinkRange")
.Input("input: Tinput")
.Input("input_min: float")
diff --git a/tensorflow/core/platform/test.cc b/tensorflow/core/platform/test.cc
index 920f983982..d71f72e322 100644
--- a/tensorflow/core/platform/test.cc
+++ b/tensorflow/core/platform/test.cc
@@ -15,15 +15,17 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
-#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \
- defined(PLATFORM_GOOGLE_ANDROID)
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
#include "tensorflow/core/platform/google/build_config/googletest.h"
#endif
+#include <cstdlib>
+#include <iostream>
+
namespace tensorflow {
namespace testing {
-#if defined(PLATFORM_GOOGLE) || defined(__ANDROID__)
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
string TmpDir() { return FLAGS_test_tmpdir; }
int RandomSeed() { return FLAGS_test_random_seed; }
#else
diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc
index 78078ab6ab..5497ad008b 100644
--- a/tensorflow/tools/graph_transforms/quantize_nodes.cc
+++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc
@@ -56,6 +56,13 @@ struct QuantizedOpInfo {
// conversion process can transform them.
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
static const std::vector<QuantizedOpInfo> op_list = {
+ {"Add",
+ {},
+ {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
+ DT_QUINT8,
+ DT_QINT32,
+ {},
+ QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"AvgPool",
{"ksize", "strides", "padding"},
{{"T", DT_QUINT8}},
diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc
index f8fe13ca13..d02655f3f9 100644
--- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc
+++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc
@@ -105,9 +105,9 @@ class QuantizeNodesTest : public ::testing::Test {
&quantized_graph_def);
// Reshape is not included here because it can be added as part of the
// quantization process.
- const std::set<string> quantizable_ops = {"BiasAdd", "Concat", "Conv2D",
- "MatMul", "Relu", "Relu6",
- "AvgPool", "MaxPool", "Mul"};
+ const std::set<string> quantizable_ops = {
+ "Add", "BiasAdd", "Concat", "Conv2D", "MatMul",
+ "Relu", "Relu6", "AvgPool", "MaxPool", "Mul"};
for (const NodeDef& node : quantized_graph_def.node()) {
EXPECT_EQ(0, quantizable_ops.count(node.op()))
<< "Found quantizable node " << node.op() << " for node named "
@@ -277,6 +277,41 @@ class QuantizeNodesTest : public ::testing::Test {
TestQuantizedVersusFloatGraph(float_graph_def, {}, {"mul"});
}
+ void TestQuantizeAdd() {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
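+    // Builds a float graph with a broadcast Add ({10, 100} + {100}) and
+    // checks that the quantized rewrite matches the float result.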
+
+ std::vector<int64> x_shape({10, 100});
+ const size_t x_num_elements = TensorShape(x_shape).num_elements();
+ std::vector<float> x_values(x_num_elements);
+ for (int i = 0; i < x_num_elements; ++i) {
+ x_values[i] = (i % 256) / 256.0f;
+ }
+
+ std::vector<int64> y_shape({100});
+ const size_t y_num_elements = TensorShape(y_shape).num_elements();
+ std::vector<float> y_values(y_num_elements);
+ for (int i = 0; i < y_num_elements; ++i) {
+ y_values[i] = ((i + 23) % 123) - 50;
+ }
+
+ Scope root = Scope::NewRootScope();
+
+ Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape));
+ test::FillValues<float>(&x_float_tensor, x_values);
+ Output x = Const(root.WithOpName("x"), Input::Initializer(x_float_tensor));
+
+ Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape));
+ test::FillValues<float>(&y_float_tensor, y_values);
+ Output y = Const(root.WithOpName("y"), Input::Initializer(y_float_tensor));
+
+ Add add = Add(root.WithOpName("add"), x, y);
+
+ GraphDef float_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&float_graph_def));
+
+ TestQuantizedVersusFloatGraph(float_graph_def, {}, {"add"});
+ }
+
void TestQuantizeConv2D(int depth, int input_width, int input_height,
int input_batch_count, int filter_size,
int filter_count, int stride, const string& padding,
@@ -1382,6 +1417,8 @@ TEST_F(QuantizeNodesTest, TestQuantizeMatMulSmall) {
TEST_F(QuantizeNodesTest, TestQuantizeMul) { TestQuantizeMul(); }
+TEST_F(QuantizeNodesTest, TestQuantizeAdd) { TestQuantizeAdd(); }
+
TEST_F(QuantizeNodesTest, TestOddPaddingProblem) {
// Tests one error case we ran into in a real graph.
TestQuantizeConv2D(1, 4, 4, 1, 3, 1, 2, "SAME",