author    | 2017-05-24 14:28:35 -0700
committer | 2017-05-24 14:32:48 -0700
commit    | 625ce1ac462292fb4bc76a06343f170871cb428c
tree      | bed59540605273477e07e79f6088db8224399743
parent    | 63aa126486b7edd48a3c9b52e8d2047c17004c9f
Implement quantized addition op, with NEON acceleration for ARM devices
PiperOrigin-RevId: 157037658
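For orientation before the file-by-file diff: the op this patch registers takes two quantized tensors together with the float ranges their codes represent, and returns a qint32 sum plus a freshly computed output range. Below is a minimal sketch of driving it through the generated C++ API, in the style of the quantized_add_op_test.cc added in this commit. The Scope/ClientSession scaffolding is illustrative, and the generated wrapper is assumed to follow the op's input order (x, y, min_x, max_x, min_y, max_y).

#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

void RunQuantizedAddSketch() {
  using namespace tensorflow;       // NOLINT
  using namespace tensorflow::ops;  // NOLINT
  Scope root = Scope::NewRootScope();

  // quint8 codes plus the float ranges they stand for.
  Tensor x_quantized(DT_QUINT8, TensorShape({4}));
  x_quantized.flat<quint8>().setConstant(quint8(128));  // ~0.5 in [0, 1]
  Tensor y_quantized(DT_QUINT8, TensorShape({4}));
  y_quantized.flat<quint8>().setConstant(quint8(64));   // ~0.5 in [0, 2]

  auto add = QuantizedAdd(root, Const(root, x_quantized),
                          Const(root, y_quantized),
                          Const(root, 0.0f), Const(root, 1.0f),   // x range
                          Const(root, 0.0f), Const(root, 2.0f));  // y range

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session.Run({add.z, add.min_z, add.max_z}, &outputs));
  // outputs[0]: qint32 sums; outputs[1], outputs[2]: output min/max.
}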
-rw-r--r-- | tensorflow/contrib/makefile/tf_op_files.txt | 1
-rw-r--r-- | tensorflow/core/BUILD | 1
-rw-r--r-- | tensorflow/core/kernels/BUILD | 116
-rw-r--r-- | tensorflow/core/kernels/quantization_utils.h | 282
-rw-r--r-- | tensorflow/core/kernels/quantization_utils_test.cc | 605
-rw-r--r-- | tensorflow/core/kernels/quantized_add_op.cc | 581
-rw-r--r-- | tensorflow/core/kernels/quantized_add_op_test.cc | 311
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 29
-rw-r--r-- | tensorflow/core/platform/test.cc | 8
-rw-r--r-- | tensorflow/tools/graph_transforms/quantize_nodes.cc | 7
-rw-r--r-- | tensorflow/tools/graph_transforms/quantize_nodes_test.cc | 43
11 files changed, 1754 insertions, 230 deletions
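Every kernel in the patch reduces to the same linear identity: a quint8 code c stands for min + c * (max - min) / 255, so converting codes 0 and 1 into the output scale once yields code_0 and code_1, and any other code requantizes as code_0 + c * (code_1 - code_0), a single multiply-add per element. That is what the Requantize8x8To32Neon path below vectorizes. Here is a standalone sketch of the identity under simplified rounding; the helpers are ad hoc stand-ins, not the quantization_utils.h ones, whose exact rounding differs slightly.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Float value represented by quint8 code c in [min, max]: 255 equal steps.
float QuantizedToFloat8(int code, float min, float max) {
  return min + code * ((max - min) / 255.0f);
}

// Unclamped qint32 code for float v in [min, max]: the int32 range is spread
// linearly across [min, max], with min landing on INT32_MIN.
int64_t FloatToQuantized32(float v, float min, float max) {
  const double scale = 4294967295.0 / (static_cast<double>(max) - min);
  return static_cast<int64_t>((v - min) * scale) + INT32_MIN;
}

int main() {
  const float in_min = 0.0f, in_max = 1.0f;     // quint8 input range
  const float out_min = -1.0f, out_max = 1.0f;  // qint32 output range
  // Convert input codes 0 and 1 once; linearity gives every other code.
  const int64_t code_0 = FloatToQuantized32(
      QuantizedToFloat8(0, in_min, in_max), out_min, out_max);
  const int64_t code_1 = FloatToQuantized32(
      QuantizedToFloat8(1, in_min, in_max), out_min, out_max);
  const int64_t mult = code_1 - code_0;
  for (int code = 0; code <= 255; code += 51) {
    // One multiply-add per element, clamped into the qint32 range, just as
    // the scalar tails of the NEON kernels do.
    int64_t out = code_0 + code * mult;
    out = std::min<int64_t>(std::max<int64_t>(out, INT32_MIN), INT32_MAX);
    std::printf("quint8 %3d -> qint32 %lld\n", code,
                static_cast<long long>(out));
  }
  return 0;
}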
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index b9cd91e519..a2c9b4c5bd 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -189,6 +189,7 @@ tensorflow/core/kernels/quantization_utils.cc tensorflow/core/kernels/quantize_down_and_shrink_range.cc tensorflow/core/kernels/quantize_op.cc tensorflow/core/kernels/quantized_activation_ops.cc +tensorflow/core/kernels/quantized_add_op.cc tensorflow/core/kernels/quantized_batch_norm_op.cc tensorflow/core/kernels/quantized_bias_add_op.cc tensorflow/core/kernels/quantized_concat_op.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ed3eeb4b21..e5ab8570d8 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1072,6 +1072,7 @@ filegroup( ":framework/shape_inference_testutil.h", ":framework/tensor_testutil.cc", ":framework/tensor_testutil.h", + ":platform/test.cc", ":platform/test.h", ":util/reporter.cc", ":util/reporter.h", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index aaa2543183..fe7e1f46f4 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4259,6 +4259,7 @@ filegroup( "quantize_down_and_shrink_range.cc", "quantize_op.cc", "quantized_activation_ops.cc", + "quantized_add_op.cc", "quantized_batch_norm_op.cc", "quantized_bias_add_op.cc", "quantized_concat_op.cc", @@ -4365,6 +4366,7 @@ tf_kernel_library( "quantize_down_and_shrink_range.cc", "quantize_op.cc", "quantized_activation_ops.cc", + "quantized_add_op.cc", "quantized_batch_norm_op.cc", "quantized_bias_add_op.cc", "quantized_concat_op.cc", @@ -4468,6 +4470,48 @@ tf_cc_test( ], ) +# Android-only test for quantization utilities. +cc_binary( + name = "quantization_utils_test_android_only", + testonly = 1, + srcs = ["quantization_utils_test.cc"], + copts = tf_copts(), + linkopts = select({ + "//tensorflow:android": [ + "-lm", + "-llog", + "-pie", + "-std=c++11", + ], + "//conditions:default": [], + }), + linkstatic = 1, + tags = [ + "manual", + "notap", + ], + deps = [ + ] + select({ + "//tensorflow:android": [ + ":android_tensorflow_kernels", + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + ":quantized_ops", + "//third_party/eigen3", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/core:framework", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test_main", + ], + }), +) + tf_cc_test( name = "quantized_activation_ops_test", srcs = ["quantized_activation_ops_test.cc"], @@ -4486,6 +4530,72 @@ tf_cc_test( ], ) +# Android-only test for quantized addition. 
+cc_binary( + name = "quantized_add_op_test_android_only", + testonly = 1, + srcs = ["quantized_add_op_test.cc"], + copts = tf_copts(), + linkopts = select({ + "//tensorflow:android": [ + "-lm", + "-llog", + "-pie", + "-std=c++11", + ], + "//conditions:default": [], + }), + linkstatic = 1, + tags = [ + "manual", + "notap", + ], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + ] + select({ + "//tensorflow:android": [ + ":android_tensorflow_kernels", + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + ":ops_util", + ":quantized_ops", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], + }), +) + +tf_cc_test( + name = "quantized_add_op_test", + size = "small", + srcs = ["quantized_add_op_test.cc"], + deps = [ + ":math", + ":ops_testutil", + ":ops_util", + ":quantized_ops", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "quantized_bias_add_op_test", size = "small", @@ -4564,7 +4674,7 @@ tf_cc_test( ], ) -# Android-only test for quantized instance norm. +# Android-only test for quantized multiply. cc_binary( name = "quantized_mul_op_test_android_only", testonly = 1, @@ -4590,9 +4700,13 @@ cc_binary( "//tensorflow/core:android_tensorflow_test_lib", ], "//conditions:default": [ + ":ops_util", + ":quantized_ops", "//tensorflow/core:framework", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", ], }), ) diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h index be67dfd112..e258efd545 100644 --- a/tensorflow/core/kernels/quantization_utils.h +++ b/tensorflow/core/kernels/quantization_utils.h @@ -24,6 +24,13 @@ limitations under the License. // optimized. They should be implementable using fixed point representations // to avoid a dependency on floating-point hardware. +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define QUANTIZATION_UTILS_USE_NEON +#include <arm_neon.h> +#endif + +#include <array> + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK #include "public/gemmlowp.h" @@ -203,7 +210,7 @@ inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input, } template <class T1, class T2> -inline void RequantizeManyInNewRange(const T1* input, size_t count, +inline void RequantizeManyInNewRange(const T1* input, int64 count, float min_input, float max_input, float min_output, float max_output, T2* output) { @@ -217,10 +224,11 @@ inline void RequantizeManyInNewRange(const T1* input, size_t count, // Because converting 32-bit accumulated results down to eight bit is a common // case, we have a specialized code path to handle it as efficiently as // possible using only fixed-point math for the inner loop. 
-template <> -inline void RequantizeManyInNewRange<qint32, quint8>( - const qint32* input, size_t count, float min_input, float max_input, - float min_output, float max_output, quint8* output) { +inline void RequantizeManyInNewRangeReference(const qint32* input, int64 count, + float min_input, float max_input, + float min_output, + float max_output, + quint8* output) { // Initially we calculate all the constants we need once, before we go into // the inner loop. If this is updated, also update the Eigen version. const int fp_shift = 16; @@ -236,9 +244,10 @@ inline void RequantizeManyInNewRange<qint32, quint8>( const int64 input_offset_fp = static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift)); const int64 output_offset_fp = - output_range == 0.0 ? 0 : static_cast<int64>((1 << fp_shift) * - (min_output * 255.0) / - output_range); + output_range == 0.0 + ? 0 + : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) / + output_range); const int64 rounding_delta = 1 << (fp_shift - 1); // Inside this loop we just do minimal adds, multiplies, and shifts, in a way @@ -258,6 +267,256 @@ inline void RequantizeManyInNewRange<qint32, quint8>( } } +// Another common case is converting eight bit inputs up to thirty two bits, so +// we have specialized fixed-point code to accelerate that. There is also a NEON +// version for ARM devices below. +inline void RequantizeManyInNewRange8To32BitReference( + const quint8* input, int64 count, float min_input, float max_input, + float min_output, float max_output, qint32* output) { + const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input); + const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input); + const int64 code_0_int64 = + FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output); + const int64 code_1_int64 = + FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output); + const int32 mult_int32 = code_1_int64 - code_0_int64; + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + for (int64 i = 0; i < count; ++i) { + const int64 input_value = static_cast<int64>(input[i]); + int64 output_value = code_0_int64 + (input_value * mult_int32); + output_value = std::max(output_value, lowest_quantized); + output_value = std::min(output_value, highest_quantized); + output[i] = static_cast<int32>(output_value); + } +} + +#ifdef QUANTIZATION_UTILS_USE_NEON +// Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD +// intrinsics for ARM platforms. +inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count, + float min_input, float max_input, + float min_output, float max_output, + quint8* output) { + // Initially we calculate all the constants we need once, before we go into + // the inner loop. If this is updated, also update the Eigen version. + const int fp_shift = 16; + + // Calculate range variables in advance. + // Input range. + const float input_range = max_input - min_input; + // Output range. + const float output_range = max_output - min_output; + // Ratio of output range. + const float recip_output_range = + output_range == 0.0 ? 0.0 : (255.0 / output_range); + // Average of input range as zero position of input. + const float input_rezero = (min_input + max_input) / 2.0; + // In-out range scale. + const int32 range_scale_fp = + output_range == 0.0 ? 
0.0 + : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) * + input_range / output_range); + // Input zero position offset to output. + const int32 input_offset_fp = + static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift)); + // Output min offset. + const int32 output_offset_fp = + output_range == 0.0 + ? 0 + : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) / + output_range); + const int32 rounding_delta = 1 << (fp_shift - 1); + + // broadcast range to each lane + const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp); + const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp); + const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp); + const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta); + + int64 index = 0; + // Use SIMD to requantize. + for (; index < (count - 7); index += 8) { + const int32* input_ptr = &(input->value) + index; + const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr); + const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4); + const int32x4_t fp_value_low_32x4 = vaddq_s32( + input_offset_fp_32x4, + vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4)); + const int32x4_t fp_value_high_32x4 = vaddq_s32( + input_offset_fp_32x4, + vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4)); + const int32x4_t offset_intermediate_low_32x4 = + vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4); + const int32x4_t offset_intermediate_high_32x4 = + vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4); + const int32x4_t round_intermediate_low_32x4 = + vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4); + const int32x4_t round_intermediate_high_32x4 = + vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4); + const int16x4_t quantized_low_16x4 = + vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift)); + const int16x4_t quantized_high_16x4 = + vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift)); + const uint8x8_t quantized_8x8 = + vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4)); + uint8* output_ptr = &(output->value) + index; + vst1_u8(output_ptr, quantized_8x8); + } + + // Requantize remaining elements in array without SIMD. + for (; index < count; ++index) { + const int32 input_value = static_cast<int32>(input[index]); + const int32 fp_value = + static_cast<int32>( + (static_cast<int32>(input_value >> 16) * (range_scale_fp))) + + input_offset_fp; + const int32 offset_intermediate = fp_value - output_offset_fp; + const int32 round_intermediate = offset_intermediate + rounding_delta; + int32 quantized_int32 = round_intermediate >> fp_shift; + quantized_int32 = std::max(quantized_int32, 0); + quantized_int32 = std::min(quantized_int32, 255); + output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32)); + } +} + +template <> +inline void RequantizeManyInNewRange<qint32, quint8>( + const qint32* input, int64 count, float min_input, float max_input, + float min_output, float max_output, quint8* output) { + const float input_range = max_input - min_input; + const float output_range = max_output - min_output; + if ((input_range / output_range) > 16384.0f) { + // Our NEON implementation uses 32-bit math and can't handle very + // large ranges, so fall back to the reference implementation. We don't + // expect these to be common in models, so this shouldn't be a performance + // problem in practice. 
+ RequantizeManyInNewRangeReference(input, count, min_input, max_input, + min_output, max_output, output); + } else { + RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output, + max_output, output); + } +} + +// Requantize 8 x 8 quints to 8 x 32 qints in parallel by neon +// Return std::array instead of pointer to leverage return value optimization +inline std::array<int32x4_t, 2> Requantize8x8To32Neon( + const uint8* input_ptr, const int64x2_t input_0_64x2, + const int32x2_t input_mult_32x2) { + const uint8x8_t input_value_8x8 = vld1_u8(input_ptr); + const int16x8_t input_value_16x8 = + vreinterpretq_s16_u16(vmovl_u8(input_value_8x8)); + const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8); + const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8); + const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4); + const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4); + const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4); + const int32x2_t input_value_low_high_32x2 = + vget_high_s32(input_value_low_32x4); + const int32x2_t input_value_high_low_32x2 = + vget_low_s32(input_value_high_32x4); + const int32x2_t input_value_high_high_32x2 = + vget_high_s32(input_value_high_32x4); + const int64x2_t mult_result_low_low_64x2 = + vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2); + const int64x2_t mult_result_low_high_64x2 = + vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2); + const int64x2_t mult_result_high_low_64x2 = + vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2); + const int64x2_t mult_result_high_high_64x2 = + vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2); + const int32x2_t output_value_low_low_32x2 = + vqmovn_s64(mult_result_low_low_64x2); + const int32x2_t output_value_low_high_32x2 = + vqmovn_s64(mult_result_low_high_64x2); + const int32x2_t output_value_high_low_32x2 = + vqmovn_s64(mult_result_high_low_64x2); + const int32x2_t output_value_high_high_32x2 = + vqmovn_s64(mult_result_high_high_64x2); + const int32x4_t output_value_low_32x4 = + vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2); + const int32x4_t output_value_high_32x4 = + vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2); + return std::array<int32x4_t, 2>{ + {output_value_low_32x4, output_value_high_32x4}}; +} + +// Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD +// intrinsics for ARM platforms. +template <> +inline void RequantizeManyInNewRange<quint8, qint32>( + const quint8* input, int64 count, float min_input, float max_input, + float min_output, float max_output, qint32* output) { + // Pre-calculate zero position and multiplier. + // Calculate 0 and 1 value in float. + const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input); + const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input); + + // Cast 0 and 1 value in int64. + const int64 code_0_int64 = + FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output); + const int64 code_1_int64 = + FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output); + + // Calculate multiplier. 
+ const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64); + + // Broadcast 0 position and multiplier to lanes + const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64); + const int32x2_t mult_32x2 = vmov_n_s32(mult_int32); + + int64 i = 0; + + // Use SIMD to requantize array. + for (; i < (count - 7); i += 8) { + const uint8* input_ptr = &(input->value) + i; + int32* output_ptr = &(output->value) + i; + const std::array<int32x4_t, 2> output_value = + Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2); + vst1q_s32(output_ptr + 0, output_value[0]); + vst1q_s32(output_ptr + 4, output_value[1]); + } + + // Requantize remaining elements in array without SIMD. + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + + for (; i < count; ++i) { + const int64 input_value = static_cast<int64>(input[i]); + int64 output_value = code_0_int64 + (input_value * mult_int32); + output_value = std::max(output_value, lowest_quantized); + output_value = std::min(output_value, highest_quantized); + output[i] = static_cast<int32>(output_value); + } +} + +#else + +// If SIMD implementations aren't available, then use these default reference +// versions. +template <> +inline void RequantizeManyInNewRange<qint32, quint8>( + const qint32* input, int64 count, float min_input, float max_input, + float min_output, float max_output, quint8* output) { + RequantizeManyInNewRangeReference(input, count, min_input, max_input, + min_output, max_output, output); +} + +template <> +inline void RequantizeManyInNewRange<quint8, qint32>( + const quint8* input, int64 count, float min_input, float max_input, + float min_output, float max_output, qint32* output) { + RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input, + min_output, max_output, output); +} + +#endif + template <int shift> struct int64_right_shift_op { EIGEN_EMPTY_STRUCT_CTOR(int64_right_shift_op) @@ -305,9 +564,10 @@ inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>( const int64 input_offset_fp = static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift)); const int64 output_offset_fp = - output_range == 0.0 ? 0 : static_cast<int64>((1 << fp_shift) * - (min_output * 255.0) / - output_range); + output_range == 0.0 + ? 0 + : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) / + output_range); const int64 rounding_delta = 1 << (fp_shift - 1); // Inside this eigen expression we just do minimal adds, multiplies, and diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc index 0c23c0586c..c547b166ee 100644 --- a/tensorflow/core/kernels/quantization_utils_test.cc +++ b/tensorflow/core/kernels/quantization_utils_test.cc @@ -18,66 +18,99 @@ limitations under the License. 
#include <limits> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/kernels/quantization_utils.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/quantization_utils.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace { + +void TestRequantizeMany(Eigen::ThreadPoolDevice* eigen_device, float input_min, + float input_max, float output_min, float output_max, + const std::vector<qint32>& values_quantized, + int tolerance = 1) { + const int values_count = values_quantized.size(); + std::vector<quint8> expected_values; + for (int value_index = 0; value_index < values_count; ++value_index) { + expected_values.push_back(FloatToQuantized<quint8>( + QuantizedToFloat(values_quantized[value_index], input_min, input_max), + output_min, output_max)); + } -class QuantizationUtilsTest : public ::testing::Test { - protected: - void TestRequantizeMany(Eigen::ThreadPoolDevice* eigen_device, - float input_min, float input_max, float output_min, - float output_max, - const std::vector<qint32>& values_quantized, - int tolerance = 1) { - const int values_count = values_quantized.size(); - std::vector<quint8> expected_values; - for (int value_index = 0; value_index < values_count; ++value_index) { - expected_values.push_back(FloatToQuantized<quint8>( - QuantizedToFloat(values_quantized[value_index], input_min, input_max), - output_min, output_max)); - } + Tensor i_tensor = + tensorflow::test::AsTensor(gtl::ArraySlice<qint32>(values_quantized)); + Tensor o_tensor(DT_QUINT8, TensorShape{values_count}); + auto output_values = o_tensor.flat<quint8>(); + + if (eigen_device == nullptr) { + auto input_array = i_tensor.flat<qint32>(); + RequantizeManyInNewRange(input_array.data(), input_array.size(), input_min, + input_max, output_min, output_max, + output_values.data()); + } else { + RequantizeManyInNewRangeUsingEigen<qint32, quint8>( + *eigen_device, i_tensor, input_min, input_max, output_min, output_max, + &o_tensor); + } - Tensor i_tensor = - tensorflow::test::AsTensor(gtl::ArraySlice<qint32>(values_quantized)); - Tensor o_tensor(DT_QUINT8, TensorShape{values_count}); - auto output_values = o_tensor.flat<quint8>(); + const string tolerance_str = strings::StrCat("+-", tolerance); + for (size_t value_index = 0; value_index < values_count; ++value_index) { + int e = expected_values[value_index]; + int v = output_values(value_index); + ASSERT_TRUE(std::abs(e - v) <= tolerance) + << "actual=" << v << ", expected=" << e << tolerance_str + << ", values_quantized[" << value_index + << "]=" << values_quantized[value_index] << ", input_min=" << input_min + << ", input_max=" << input_max << ", output_min=" << output_min + << ", output_max=" << output_max << ", value_index=" << value_index; + } +} - if (eigen_device == nullptr) { - auto input_array = i_tensor.flat<qint32>(); - RequantizeManyInNewRange(input_array.data(), input_array.size(), - input_min, input_max, output_min, output_max, - output_values.data()); - } else { - RequantizeManyInNewRangeUsingEigen<qint32, quint8>( - *eigen_device, i_tensor, input_min, input_max, output_min, output_max, - &o_tensor); - } +void TestRequantizeMany8To32Bit(float input_min, float 
input_max, + float output_min, float output_max, + const std::vector<quint8>& values_quantized, + int tolerance = 256) { + const int values_count = values_quantized.size(); + std::vector<qint32> expected_values; + for (int value_index = 0; value_index < values_count; ++value_index) { + expected_values.push_back(FloatToQuantized<qint32>( + QuantizedToFloat(values_quantized[value_index], input_min, input_max), + output_min, output_max)); + } - const string tolerance_str = strings::StrCat("+-", tolerance); - for (size_t value_index = 0; value_index < values_count; ++value_index) { - int e = expected_values[value_index]; - int v = output_values(value_index); - ASSERT_TRUE(std::abs(e - v) <= tolerance) - << "actual=" << v << ", expected=" << e << tolerance_str - << ", values_quantized[" << value_index - << "]=" << values_quantized[value_index] - << ", input_min=" << input_min << ", input_max=" << input_max - << ", output_min=" << output_min << ", output_max=" << output_max - << ", value_index=" << value_index; - } + Tensor i_tensor = + tensorflow::test::AsTensor(gtl::ArraySlice<quint8>(values_quantized)); + Tensor o_tensor(DT_QINT32, TensorShape{values_count}); + auto output_values = o_tensor.flat<qint32>(); + + auto input_array = i_tensor.flat<quint8>(); + RequantizeManyInNewRange(input_array.data(), input_array.size(), input_min, + input_max, output_min, output_max, + output_values.data()); + + const string tolerance_str = strings::StrCat("+-", tolerance); + for (size_t value_index = 0; value_index < values_count; ++value_index) { + int e = expected_values[value_index]; + int v = output_values(value_index); + ASSERT_TRUE(std::abs(e - v) <= tolerance) + << "actual=" << v << ", expected=" << e << tolerance_str + << ", values_quantized[" << value_index + << "]=" << values_quantized[value_index] << ", input_min=" << input_min + << ", input_max=" << input_max << ", output_min=" << output_min + << ", output_max=" << output_max << ", value_index=" << value_index; } +} - // If eigen_device is NULL, then the reference implementation is tested. - void TestRequantizeManyInNewRange32To8Bit( - Eigen::ThreadPoolDevice* eigen_device) { +// If eigen_device is NULL, then the reference implementation is tested. +void TestRequantizeManyInNewRange32To8Bit( + Eigen::ThreadPoolDevice* eigen_device) { + if (true) { // These are the float values we're going to test the conversions on. const size_t values_count = 6; const float values[values_count] = {0.0f, 0.45f, 1.0f, @@ -108,7 +141,7 @@ class QuantizationUtilsTest : public ::testing::Test { qint32 high = Eigen::NumTraits<qint32>::highest(); std::vector<qint32> vals{low, high}; int num_steps = 14419; - qint32 step = static_cast<int32>((1L << 32) / num_steps); + qint32 step = static_cast<int32>((1LL << 32) / num_steps); qint32 v = low + static_cast<qint32>(1); for (int i = 0; i < num_steps; ++i) { vals.push_back(v); @@ -120,181 +153,278 @@ class QuantizationUtilsTest : public ::testing::Test { vals); TestRequantizeMany(eigen_device, -1.0f, 12345678.0f, -12345678.0f, 12345678.0f, vals); + } + // Test when the input range is large and output range is small. + // Use all quantized values where the float is in the output range. 
+ const float out_min = -29.1234; + const float out_max = 23.1234; + const float in_min = -1e6; + const float in_max = 1e6; + + qint32 low = FloatToQuantized<qint32>(out_min, in_min, in_max); + qint32 high = FloatToQuantized<qint32>(out_max, in_min, in_max); + std::vector<qint32> vals; + vals.clear(); + for (int32 i = low; i <= high; ++i) vals.push_back(i); + TestRequantizeMany(eigen_device, in_min, in_max, out_min, out_max, vals); +} - // Test when the input range is large and output range is small. - // Use all quantized values where the float is in the output range. - const float out_min = -29.1234; - const float out_max = 23.1234; - const float in_min = -1e6; - const float in_max = 1e6; - - low = FloatToQuantized<qint32>(out_min, in_min, in_max); - high = FloatToQuantized<qint32>(out_max, in_min, in_max); - vals.clear(); - for (int32 i = low; i <= high; ++i) vals.push_back(i); - TestRequantizeMany(eigen_device, in_min, in_max, out_min, out_max, vals); +void TestRequantizeManyInNewRange8To32Bit() { + // These are the float values we're going to test the conversions on. + const size_t values_count = 6; + const float values[values_count] = {0.0f, 0.45f, 1.0f, -1.0f, 127.0f, 255.0f}; + // These are the input and output ranges we'll test. + const size_t ranges_count = 6; + const float ranges[ranges_count][4] = { + {0.0f, 255.0f, 0.0f, 255.0f}, // + {0.0f, 1.0f, 0.0f, 1.0f}, // + {-1.0f, 1.0f, -1.0f, 1.0f}, // + {-1.0f, 1.0f, -255.0f, 255.0f}, // + {3.0f, 3.0f, 0.0f, 255.0f}, // input min == max + {0.0f, 255.0f, 5.0f, 5.0f}, // output min == max + }; + for (int i = 0; i < ranges_count; ++i) { + const auto& r = ranges[i]; + std::vector<quint8> values_quantized; + for (int value_index = 0; value_index < values_count; ++value_index) { + const float v = values[value_index]; + values_quantized.push_back(FloatToQuantized<quint8>(v, r[0], r[1])); + } + TestRequantizeMany8To32Bit(r[0], r[1], r[2], r[3], values_quantized); } - template <typename InputType, typename OutputType> - void TestRequantizeManyInNewRangeEigenVsNonEigen() { - thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); - EigenThreadPoolWrapper wrapper(&threadpool); - Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); + // Test with many different values in the input quantized range. + int low = Eigen::NumTraits<quint8>::lowest(); + int high = Eigen::NumTraits<quint8>::highest(); + std::vector<quint8> vals; + for (int val = low; val <= high; ++val) { + vals.push_back(val); + } + TestRequantizeMany8To32Bit(-1.0f, 1.0f, -1.0f, 1.0f, vals); + TestRequantizeMany8To32Bit(-255.0f, 255.0f, -255.0f, 255.0f, vals); + TestRequantizeMany8To32Bit(-1.0f, 1.0f, -12345678.0f, 12345678.0f, vals); + TestRequantizeMany8To32Bit(-1.0f, 12345678.0f, -12345678.0f, 12345678.0f, + vals); +} - const size_t ranges_count = 6; - const float ranges[ranges_count][4] = { - {0.0f, 255.0f, 0.0f, 255.0f}, // - {0.0f, 1.0f, 0.0f, 1.0f}, // - {-1.0f, 1.0f, -1.0f, 1.0f}, // - {-1.0f, 1.0f, -255.0f, 255.0f}, // - {3.0f, 3.0f, 0.0f, 255.0f}, // input min == max - {0.0f, 255.0f, 5.0f, 5.0f}, // output min == max - }; +template <typename InputType, typename OutputType> +void TestRequantizeManyInNewRangeEigenVsNonEigen() { + thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); - // Random values. 
- for (size_t range_index = 0; range_index < ranges_count; ++range_index) { - const float input_min = ranges[range_index][0]; - const float input_max = ranges[range_index][1]; - const float output_min = ranges[range_index][2]; - const float output_max = ranges[range_index][3]; - const int values_count = 10000; - random::PhiloxRandom philox(testing::RandomSeed(), 17); - random::SimplePhilox rnd(&philox); - std::vector<InputType> values_quantized; - for (int i = 0; i < values_count; ++i) { - float v = (rnd.RandFloat() * (input_max - input_min)) + input_min; - values_quantized.push_back( - FloatToQuantized<InputType>(v, input_min, input_max)); - } + const size_t ranges_count = 6; + const float ranges[ranges_count][4] = { + {0.0f, 255.0f, 0.0f, 255.0f}, // + {0.0f, 1.0f, 0.0f, 1.0f}, // + {-1.0f, 1.0f, -1.0f, 1.0f}, // + {-1.0f, 1.0f, -255.0f, 255.0f}, // + {3.0f, 3.0f, 0.0f, 255.0f}, // input min == max + {0.0f, 255.0f, 5.0f, 5.0f}, // output min == max + }; + + // Random values. + for (size_t range_index = 0; range_index < ranges_count; ++range_index) { + const float input_min = ranges[range_index][0]; + const float input_max = ranges[range_index][1]; + const float output_min = ranges[range_index][2]; + const float output_max = ranges[range_index][3]; + const int values_count = 10000; + random::PhiloxRandom philox(testing::RandomSeed(), 17); + random::SimplePhilox rnd(&philox); + std::vector<InputType> values_quantized; + for (int i = 0; i < values_count; ++i) { + float v = (rnd.RandFloat() * (input_max - input_min)) + input_min; + values_quantized.push_back( + FloatToQuantized<InputType>(v, input_min, input_max)); + } - Tensor i_tensor = tensorflow::test::AsTensor( - gtl::ArraySlice<InputType>(values_quantized)); - const auto i_array = i_tensor.flat<InputType>(); - Tensor o_tensor_eigen(DataTypeToEnum<OutputType>::v(), - TensorShape{values_count}); - auto output_values_eigen = o_tensor_eigen.flat<OutputType>(); - Tensor o_tensor_ref(DataTypeToEnum<OutputType>::v(), + Tensor i_tensor = tensorflow::test::AsTensor( + gtl::ArraySlice<InputType>(values_quantized)); + const auto i_array = i_tensor.flat<InputType>(); + Tensor o_tensor_eigen(DataTypeToEnum<OutputType>::v(), TensorShape{values_count}); - auto output_values_ref = o_tensor_ref.flat<OutputType>(); + auto output_values_eigen = o_tensor_eigen.flat<OutputType>(); + Tensor o_tensor_ref(DataTypeToEnum<OutputType>::v(), + TensorShape{values_count}); + auto output_values_ref = o_tensor_ref.flat<OutputType>(); + + RequantizeManyInNewRange(i_array.data(), i_array.size(), input_min, + input_max, output_min, output_max, + output_values_ref.data()); + RequantizeManyInNewRangeUsingEigen<InputType, OutputType>( + eigen_device, i_tensor, input_min, input_max, output_min, output_max, + &o_tensor_eigen); + + const int tolerance = 1; + for (int i = 0; i < values_quantized.size(); ++i) { + auto expected = output_values_ref(i); + auto actual = output_values_eigen(i); + // The eigen computation uses float for constants and computation + // instead of doubles, so can be different by 1 or 2 in some cases + // (e.g., input value 144.062744140625, min -1, max 255, type quint8). 
+ ASSERT_TRUE(std::abs(expected - actual) <= tolerance) + << "expected=" << expected << " actual=" << actual + << " tolerance=" << tolerance << " v=" << values_quantized[i] + << " i=" << i << " input_min=" << input_min + << " input_max=" << input_max + << " input_type=" << DataTypeString(DataTypeToEnum<InputType>::v()) + << " output_type=" << DataTypeString(DataTypeToEnum<OutputType>::v()); + } + } +} + +template <typename InputType, typename OutputType> +void TimeRequantizeManyInNewRange(int64 num_elements, int64 iterations, + bool use_eigen) { + const float input_min = -100.0f; + const float input_max = 100.0f; + const float output_min = -1000000.0f; + const float output_max = 1000000.0f; + + random::PhiloxRandom philox(testing::RandomSeed(), 17); + random::SimplePhilox rnd(&philox); + std::vector<InputType> values_quantized; + for (int i = 0; i < num_elements; ++i) { + float v = (rnd.RandFloat() * (input_max - input_min)) + input_min; + values_quantized.push_back( + FloatToQuantized<InputType>(v, input_min, input_max)); + } - RequantizeManyInNewRange(i_array.data(), i_array.size(), input_min, - input_max, output_min, output_max, - output_values_ref.data()); + thread::ThreadPool threadpool(Env::Default(), "test", 4 /* num_threads */); + EigenThreadPoolWrapper wrapper(&threadpool); + Eigen::ThreadPoolDevice eigen_device(&wrapper, 4 /* num_threads */); + + Tensor i_tensor = + tensorflow::test::AsTensor(gtl::ArraySlice<InputType>(values_quantized)); + const auto i_array = i_tensor.flat<InputType>(); + Tensor o_tensor_eigen(DataTypeToEnum<OutputType>::v(), + TensorShape{num_elements}); + Tensor o_tensor_ref(DataTypeToEnum<OutputType>::v(), + TensorShape{num_elements}); + auto output_values_ref = o_tensor_ref.flat<OutputType>(); + + int64 total_duration = 0; + for (int i = 0; i < iterations; ++i) { + const int64 start_time = Env::Default()->NowMicros(); + if (use_eigen) { RequantizeManyInNewRangeUsingEigen<InputType, OutputType>( eigen_device, i_tensor, input_min, input_max, output_min, output_max, &o_tensor_eigen); - - const int tolerance = 1; - for (int i = 0; i < values_quantized.size(); ++i) { - auto expected = output_values_ref(i); - auto actual = output_values_eigen(i); - // The eigen computation uses float for constants and computation - // instead of doubles, so can be different by 1 or 2 in some cases - // (e.g., input value 144.062744140625, min -1, max 255, type quint8). - ASSERT_TRUE(std::abs(expected - actual) <= tolerance) - << "expected=" << expected << " actual=" << actual - << " tolerance=" << tolerance << " v=" << values_quantized[i] - << " i=" << i << " input_min=" << input_min - << " input_max=" << input_max - << " input_type=" << DataTypeString(DataTypeToEnum<InputType>::v()) - << " output_type=" - << DataTypeString(DataTypeToEnum<OutputType>::v()); - } + } else { + RequantizeManyInNewRange<InputType, OutputType>( + i_array.data(), i_array.size(), input_min, input_max, output_min, + output_max, output_values_ref.data()); } + const int64 end_time = Env::Default()->NowMicros(); + total_duration += end_time - start_time; } + const int64 one_run_duration = total_duration / iterations; - template <typename T> - void TestFloatToQuantizedInPlaceUsingEigen( - Eigen::ThreadPoolDevice* eigen_device) { - // These are the float values we're going to test the conversions on. 
- typedef std::pair<float, float> FPair; - for (FPair min_and_max : std::vector<FPair>{FPair(-255.0f, 255.0f), // - FPair(-1.0f, 1.0f), // - FPair(-1.0f, 255.0f), // - FPair(0.0f, 1e6), // - FPair(0.0f, 1.0f), // - FPair(-31.0f, 13.0f)}) { - const float f_min = min_and_max.first; - const float f_max = min_and_max.second; - const float f_range = f_max - f_min; - const int values_count = 50000; - Tensor input(DT_FLOAT, TensorShape{values_count}); - auto input_array = input.flat<float>(); - for (int i = 0; i < values_count; ++i) { - input_array(i) = f_min + f_range * i / (values_count - 1); - } + const int64 num_ops = num_elements; - Tensor output(DataTypeToEnum<T>::v(), TensorShape{values_count}); - FloatTensorToQuantizedInPlaceUsingEigen<T>(*eigen_device, input, f_min, - f_max, &output); - auto output_array = output.flat<T>(); - - const int tolerance = 1; - for (int i = 0; i < values_count; ++i) { - int32 expected = FloatToQuantized<T>(input_array(i), f_min, f_max); - int32 actual = output_array(i); - - // The eigen computation uses float for constants and computation - // instead - // of doubles, so can be different by 1 or 2 in some cases (e.g., input - // value 144.062744140625, min -1, max 255, type quint8). - ASSERT_TRUE(std::abs(expected - actual) <= tolerance) - << "expected=" << expected << " actual=" << actual - << " tolerance=" << tolerance << " v=" << input_array(i) - << " i=" << i << " f_min=" << f_min << " f_max=" << f_max - << " type=" << DataTypeString(DataTypeToEnum<T>::v()); - } + const double million_ops_per_second = + (iterations * num_ops) / static_cast<double>(total_duration); + + LOG(INFO) << "TimeRequantizeManyInNewRange: " << num_elements + << (use_eigen ? " eigen" : " ref") << ": iterations=" << iterations + << ", MOps/s=" << million_ops_per_second + << ", one_run_duration=" << one_run_duration + << ", total_duration=" << total_duration; +} + +template <typename T> +void TestFloatToQuantizedInPlaceUsingEigen( + Eigen::ThreadPoolDevice* eigen_device) { + // These are the float values we're going to test the conversions on. + typedef std::pair<float, float> FPair; + for (FPair min_and_max : std::vector<FPair>{FPair(-255.0f, 255.0f), // + FPair(-1.0f, 1.0f), // + FPair(-1.0f, 255.0f), // + FPair(0.0f, 1e6), // + FPair(0.0f, 1.0f), // + FPair(-31.0f, 13.0f)}) { + const float f_min = min_and_max.first; + const float f_max = min_and_max.second; + const float f_range = f_max - f_min; + const int values_count = 50000; + Tensor input(DT_FLOAT, TensorShape{values_count}); + auto input_array = input.flat<float>(); + for (int i = 0; i < values_count; ++i) { + input_array(i) = f_min + f_range * i / (values_count - 1); + } + + Tensor output(DataTypeToEnum<T>::v(), TensorShape{values_count}); + FloatTensorToQuantizedInPlaceUsingEigen<T>(*eigen_device, input, f_min, + f_max, &output); + auto output_array = output.flat<T>(); + + const int tolerance = 1; + for (int i = 0; i < values_count; ++i) { + int32 expected = FloatToQuantized<T>(input_array(i), f_min, f_max); + int32 actual = output_array(i); + + // The eigen computation uses float for constants and computation + // instead + // of doubles, so can be different by 1 or 2 in some cases (e.g., input + // value 144.062744140625, min -1, max 255, type quint8). 
+ ASSERT_TRUE(std::abs(expected - actual) <= tolerance) + << "expected=" << expected << " actual=" << actual + << " tolerance=" << tolerance << " v=" << input_array(i) << " i=" << i + << " f_min=" << f_min << " f_max=" << f_max + << " type=" << DataTypeString(DataTypeToEnum<T>::v()); } } +} - template <typename T> - void TestQuantizedToFloatInPlaceUsingEigen( - Eigen::ThreadPoolDevice* eigen_device) { - // These are the float values we're going to test the conversions on. - typedef std::pair<float, float> FPair; - for (FPair min_and_max : std::vector<FPair>{ - FPair(-255.0f, 255.0f), FPair(-1.0f, 1.0f), FPair(-1.0f, 255.0f), - FPair(0.0f, 1e6), FPair(0.0f, 1.0f), FPair(-31.0f, 13.0f), - FPair(-5.89505e+08, 5.89505e+08), - }) { - const float f_min = min_and_max.first; - const float f_max = min_and_max.second; - const int values_count = sizeof(T) == 1 ? 256 : 50000; - Tensor input(DataTypeToEnum<T>::v(), TensorShape{values_count}); - auto input_array = input.flat<T>(); - const double q_range = - static_cast<double>(Eigen::NumTraits<T>::highest()) - - Eigen::NumTraits<T>::lowest(); - for (int i = 0; i < values_count; ++i) { - if (sizeof(T) == 1) { - input_array(i) = Eigen::NumTraits<T>::lowest() + i; - } else { - int64 offset = static_cast<int64>(q_range / values_count * i); - input_array(i) = static_cast<int32>( - std::min<int64>(Eigen::NumTraits<T>::lowest() + offset, - Eigen::NumTraits<T>::highest())); - } +template <typename T> +void TestQuantizedToFloatInPlaceUsingEigen( + Eigen::ThreadPoolDevice* eigen_device) { + // These are the float values we're going to test the conversions on. + typedef std::pair<float, float> FPair; + for (FPair min_and_max : std::vector<FPair>{ + FPair(-255.0f, 255.0f), FPair(-1.0f, 1.0f), FPair(-1.0f, 255.0f), + FPair(0.0f, 1e6), FPair(0.0f, 1.0f), FPair(-31.0f, 13.0f), + FPair(-5.89505e+08, 5.89505e+08), + }) { + const float f_min = min_and_max.first; + const float f_max = min_and_max.second; + const int values_count = sizeof(T) == 1 ? 
256 : 50000; + Tensor input(DataTypeToEnum<T>::v(), TensorShape{values_count}); + auto input_array = input.flat<T>(); + const double q_range = static_cast<double>(Eigen::NumTraits<T>::highest()) - + Eigen::NumTraits<T>::lowest(); + for (int i = 0; i < values_count; ++i) { + if (sizeof(T) == 1) { + input_array(i) = Eigen::NumTraits<T>::lowest() + i; + } else { + int64 offset = static_cast<int64>(q_range / values_count * i); + input_array(i) = static_cast<int32>( + std::min<int64>(Eigen::NumTraits<T>::lowest() + offset, + Eigen::NumTraits<T>::highest())); } + } - Tensor output(DT_FLOAT, TensorShape{values_count}); - QuantizedTensorToFloatInPlaceUsingEigen<T>(*eigen_device, input, f_min, - f_max, &output); - auto output_array = output.flat<float>(); - const double range = static_cast<double>(f_max) - f_min; - for (int i = 0; i < values_count; ++i) { - float expected = QuantizedToFloat<T>(input_array(i), f_min, f_max); - float actual = output_array(i); - ASSERT_NEAR(expected, actual, range * 1.1e-7) - << "expected=" << expected << " actual=" << actual - << " v=" << input_array(i) << " i=" << i << " f_min=" << f_min - << " f_max=" << f_max - << " type=" << DataTypeString(DataTypeToEnum<T>::v()); - } + Tensor output(DT_FLOAT, TensorShape{values_count}); + QuantizedTensorToFloatInPlaceUsingEigen<T>(*eigen_device, input, f_min, + f_max, &output); + auto output_array = output.flat<float>(); + const double range = static_cast<double>(f_max) - f_min; + for (int i = 0; i < values_count; ++i) { + float expected = QuantizedToFloat<T>(input_array(i), f_min, f_max); + float actual = output_array(i); + ASSERT_NEAR(expected, actual, range * 1.1e-7) + << "expected=" << expected << " actual=" << actual + << " v=" << input_array(i) << " i=" << i << " f_min=" << f_min + << " f_max=" << f_max + << " type=" << DataTypeString(DataTypeToEnum<T>::v()); } } -}; +} -TEST_F(QuantizationUtilsTest, FloatToQuantized) { +} // namespace + +void TestFloatToQuantized() { EXPECT_EQ(quint8(0), FloatToQuantized<quint8>(0.0f, 0.0f, 1.0f)); EXPECT_EQ(quint8(0), FloatToQuantized<quint8>(0.0f, 0.0f, 2.0f)); EXPECT_EQ(quint8(128), FloatToQuantized<quint8>(0.5f, 0.0f, 1.0f)); @@ -318,7 +448,7 @@ TEST_F(QuantizationUtilsTest, FloatToQuantized) { FloatToQuantized<qint32>(128.0f, -128.0f, 128.0f)); } -TEST_F(QuantizationUtilsTest, QuantizedToFloat) { +void TestQuantizedToFloat() { EXPECT_LT(fabsf(0.0f - QuantizedToFloat<quint8>(0, 0.0f, 1.0f)), 1 / 255.0f); EXPECT_LT(fabsf(0.0f - QuantizedToFloat<quint8>(0, 0.0f, 2.0f)), 1 / 255.0f); EXPECT_LT(fabsf(0.5f - QuantizedToFloat<quint8>(127, 0.0f, 1.0f)), @@ -349,7 +479,7 @@ TEST_F(QuantizationUtilsTest, QuantizedToFloat) { 1.0); } -TEST_F(QuantizationUtilsTest, AvoidBias) { +void TestAvoidBias() { for (int i = 0; i < 256; ++i) { const float as_float = QuantizedToFloat<quint8>(i, 0.0f, 2.0f); const int back_to_int = FloatToQuantized<quint8>(as_float, 0.0f, 2.0f); @@ -371,7 +501,7 @@ TEST_F(QuantizationUtilsTest, AvoidBias) { } } -TEST_F(QuantizationUtilsTest, RequantizeInNewRange) { +void TestRequantizeInNewRange() { // These are the float values we're going to test the conversions on. 
const size_t values_count = 6; const float values[values_count] = {0.0f, 0.5f, 1.0f, -1.0f, 127.0f, 255.0f}; @@ -407,7 +537,7 @@ TEST_F(QuantizationUtilsTest, RequantizeInNewRange) { } } -TEST_F(QuantizationUtilsTest, RequantizeInNewRangeRealData) { +void TestRequantizeInNewRangeRealData() { const float input_min = -0.739539f; const float input_max = 0.641057f; const float output_min = -2381.49f; @@ -428,7 +558,7 @@ TEST_F(QuantizationUtilsTest, RequantizeInNewRangeRealData) { EXPECT_LT(std::abs(value_as_qint32 - actual_output), 10); } -TEST_F(QuantizationUtilsTest, RequantizeInNewRange32To8Bit) { +void TestRequantizeInNewRange32To8Bit() { // These are the float values we're going to test the conversions on. const size_t values_count = 6; const float values[values_count] = {0.0f, 0.45f, 1.0f, -1.0f, 127.0f, 255.0f}; @@ -464,27 +594,26 @@ TEST_F(QuantizationUtilsTest, RequantizeInNewRange32To8Bit) { } } -TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8Bit) { +void TestRequantizeManyInNewRange32To8Bit() { TestRequantizeManyInNewRange32To8Bit(nullptr /* eigen_device */); } -TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8BitUsingEigen) { +void TestRequantizeManyInNewRange32To8BitUsingEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); EigenThreadPoolWrapper wrapper(&threadpool); Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); TestRequantizeManyInNewRange32To8Bit(&eigen_device); } -TEST_F(QuantizationUtilsTest, RequantizeManyInNewRange32To8BitEigenVsNonEigen) { +void TestRequantizeManyInNewRange32To8BitEigenVsNonEigen() { TestRequantizeManyInNewRangeEigenVsNonEigen<qint32, quint8>(); } -TEST_F(QuantizationUtilsTest, - RequantizeManyInNewRange32To8BitSignedEigenVsNonEigen) { +void TestRequantizeManyInNewRange32To8BitSignedEigenVsNonEigen() { TestRequantizeManyInNewRangeEigenVsNonEigen<qint32, qint8>(); } -TEST_F(QuantizationUtilsTest, FloatTensorToQuantized) { +void TestFloatTensorToQuantized() { const int input_width = 3; const int input_height = 3; const float input_min = 0.0f; @@ -500,7 +629,7 @@ TEST_F(QuantizationUtilsTest, FloatTensorToQuantized) { // Verify that FloatToQuantizedInPlaceUsingEigen is same result as // FloatToQuantized. -TEST_F(QuantizationUtilsTest, FloatToQuantizedInPlaceUsingEigen) { +void TestFloatToQuantizedInPlaceUsingEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); EigenThreadPoolWrapper wrapper(&threadpool); Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); @@ -511,7 +640,7 @@ TEST_F(QuantizationUtilsTest, FloatToQuantizedInPlaceUsingEigen) { TestFloatToQuantizedInPlaceUsingEigen<qint16>(&eigen_device); } -TEST_F(QuantizationUtilsTest, OverflowWithEigen) { +void TestOverflowWithEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); EigenThreadPoolWrapper wrapper(&threadpool); Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); @@ -537,7 +666,7 @@ TEST_F(QuantizationUtilsTest, OverflowWithEigen) { test::ExpectTensorEqual<qint32>(expected, output); } -TEST_F(QuantizationUtilsTest, QuantizedTensorToFloat) { +void TestQuantizedTensorToFloat() { const int input_width = 3; const int input_height = 3; const float input_min = -128.0f; @@ -579,7 +708,7 @@ TEST_F(QuantizationUtilsTest, QuantizedTensorToFloat) { // Verify that QuantizedToFloatInPlaceUsingEigen is same result as // QuantizedToFloat. 
-TEST_F(QuantizationUtilsTest, QuantizedToFloatInPlaceUsingEigen) { +void TestQuantizedToFloatInPlaceUsingEigen() { thread::ThreadPool threadpool(Env::Default(), "test", 2 /* num_threads */); EigenThreadPoolWrapper wrapper(&threadpool); Eigen::ThreadPoolDevice eigen_device(&wrapper, 2 /* num_threads */); @@ -591,4 +720,56 @@ TestQuantizedToFloatInPlaceUsingEigen<qint32>(&eigen_device); } +void BenchmarkRequantizeManyInNewRange() { + TimeRequantizeManyInNewRange<qint32, quint8>(1000, 1000, false); + TimeRequantizeManyInNewRange<qint32, quint8>(1000, 1000, true); + TimeRequantizeManyInNewRange<qint32, quint8>(100000, 100, false); + TimeRequantizeManyInNewRange<qint32, quint8>(100000, 100, true); + TimeRequantizeManyInNewRange<qint32, quint8>(1000000, 10, false); + TimeRequantizeManyInNewRange<qint32, quint8>(1000000, 10, true); + + TimeRequantizeManyInNewRange<quint8, qint32>(1000, 1000, false); + TimeRequantizeManyInNewRange<quint8, qint32>(1000, 1000, true); + TimeRequantizeManyInNewRange<quint8, qint32>(100000, 100, false); + TimeRequantizeManyInNewRange<quint8, qint32>(100000, 100, true); + TimeRequantizeManyInNewRange<quint8, qint32>(1000000, 10, false); + TimeRequantizeManyInNewRange<quint8, qint32>(1000000, 10, true); +} + } // namespace tensorflow + +#if defined(__ANDROID__) +int main(int argc, char** argv) { +#define RUN_TEST(t) \ + LOG(INFO) << "Test: " << #t; \ + tensorflow::t(); +#else +#define RUN_TEST(t) \ + TEST(QuantizationUtilsTest, t) { tensorflow::t(); } +#endif + + RUN_TEST(TestFloatToQuantized); + RUN_TEST(TestQuantizedToFloat); + RUN_TEST(TestAvoidBias); + RUN_TEST(TestRequantizeInNewRange); + RUN_TEST(TestRequantizeInNewRangeRealData); + RUN_TEST(TestRequantizeInNewRange32To8Bit); + RUN_TEST(TestRequantizeManyInNewRange32To8Bit); + RUN_TEST(TestRequantizeManyInNewRange32To8BitUsingEigen); + RUN_TEST(TestRequantizeManyInNewRange32To8BitEigenVsNonEigen); + RUN_TEST(TestRequantizeManyInNewRange32To8BitSignedEigenVsNonEigen); + RUN_TEST(TestFloatTensorToQuantized); + RUN_TEST(TestRequantizeManyInNewRange8To32Bit); + RUN_TEST(TestFloatToQuantizedInPlaceUsingEigen); + RUN_TEST(TestOverflowWithEigen); + RUN_TEST(TestQuantizedTensorToFloat); + RUN_TEST(TestQuantizedToFloatInPlaceUsingEigen); + +#if defined(__ANDROID__) + + tensorflow::BenchmarkRequantizeManyInNewRange(); + + LOG(INFO) << "All tests complete."; + return 0; +} +#endif diff --git a/tensorflow/core/kernels/quantized_add_op.cc b/tensorflow/core/kernels/quantized_add_op.cc new file mode 100644 index 0000000000..8be0c56798 --- /dev/null +++ b/tensorflow/core/kernels/quantized_add_op.cc @@ -0,0 +1,581 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implements a quantized eight-bit version of the addition operation.
+ +#define EIGEN_USE_THREADS + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#define QUANTIZED_ADD_USE_NEON +#include <arm_neon.h> +#endif + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/meta_support.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/bcast.h" + +// There are implementations for three broadcast patterns for add: +// - Scalar * Array +// - Array * Array +// - Array * Shorter Array (repeated to match first) +// +// These handle a lot of common broadcast patterns, and we have NEON SIMD +// versions to accelerate performance on ARM platforms. + +namespace tensorflow { +namespace { + +template <class T, class Toutput> +void ScalarAddition(OpKernelContext* context, const T* full_input, + float full_input_min, float full_input_max, + int64 num_elements, T scalar_input, float scalar_input_min, + float scalar_input_max, float output_min, float output_max, + Toutput* output) { + const Toutput scalar_in_output_range = RequantizeInNewRange<T, Toutput>( + scalar_input, scalar_input_min, scalar_input_max, output_min, output_max); + for (int i = 0; i < num_elements; ++i) { + const Toutput full_input_in_output_range = RequantizeInNewRange<T, Toutput>( + full_input[i], full_input_min, full_input_max, output_min, output_max); + output[i] = full_input_in_output_range + scalar_in_output_range; + } +} + +#ifdef QUANTIZED_ADD_USE_NEON + +template <> +void ScalarAddition(OpKernelContext* context, const quint8* full_input, + float full_input_min, float full_input_max, + int64 num_elements, quint8 scalar_input, + float scalar_input_min, float scalar_input_max, + float output_min, float output_max, qint32* output) { + const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>( + scalar_input, scalar_input_min, scalar_input_max, output_min, output_max); + + const float input_0_float = + QuantizedToFloat<quint8>(0, full_input_min, full_input_max); + const float input_1_float = + QuantizedToFloat<quint8>(1, full_input_min, full_input_max); + const int64 input_0_int64 = + FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max); + const int64 input_1_int64 = + FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max); + const int32 input_mult_int32 = input_1_int64 - input_0_int64; + + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + + const int64x2_t input_0_64x2 = vmovq_n_s64(input_0_int64); + const int32x2_t input_mult_32x2 = vmov_n_s32(input_mult_int32); + const int32x4_t scalar_in_output_range_32x4 = + vmovq_n_s32(scalar_in_output_range); + int64 i = 0; + for (; i < (num_elements - 7); i += 8) { + const uint8* full_input_ptr = &(full_input->value) + i; + const std::array<int32x4_t, 2> output_value = + Requantize8x8To32Neon(full_input_ptr, input_0_64x2, input_mult_32x2); + const int32x4_t result_low_32x4 = + vaddq_s32(output_value[0], scalar_in_output_range_32x4); + const int32x4_t result_high_32x4 = + vaddq_s32(output_value[1], scalar_in_output_range_32x4); + int32* output_ptr = &(output->value) + i; + vst1q_s32(output_ptr + 0, result_low_32x4); + vst1q_s32(output_ptr + 4, result_high_32x4); + } + for (; i < num_elements; ++i) { + const int64 full_input_value = 
static_cast<int64>(full_input[i]); + int64 full_input_in_output_range_64 = + input_0_int64 + (full_input_value * input_mult_int32); + full_input_in_output_range_64 = + std::max(full_input_in_output_range_64, lowest_quantized); + full_input_in_output_range_64 = + std::min(full_input_in_output_range_64, highest_quantized); + const int32 full_input_in_output_range = + static_cast<int32>(full_input_in_output_range_64); + output[i] = full_input_in_output_range + scalar_in_output_range; + } +} + +#else // QUANTIZED_ADD_USE_NEON + +template <> +void ScalarAddition(OpKernelContext* context, const quint8* full_input, + float full_input_min, float full_input_max, + int64 num_elements, quint8 scalar_input, + float scalar_input_min, float scalar_input_max, + float output_min, float output_max, qint32* output) { + const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>( + scalar_input, scalar_input_min, scalar_input_max, output_min, output_max); + + const float input_0_float = + QuantizedToFloat<quint8>(0, full_input_min, full_input_max); + const float input_1_float = + QuantizedToFloat<quint8>(1, full_input_min, full_input_max); + const int64 input_0_int64 = + FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max); + const int64 input_1_int64 = + FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max); + const int32 input_mult_int32 = input_1_int64 - input_0_int64; + + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + + for (int i = 0; i < num_elements; ++i) { + const int64 full_input_value = static_cast<int64>(full_input[i]); + int64 full_input_in_output_range_64 = + input_0_int64 + (full_input_value * input_mult_int32); + full_input_in_output_range_64 = + std::max(full_input_in_output_range_64, lowest_quantized); + full_input_in_output_range_64 = + std::min(full_input_in_output_range_64, highest_quantized); + const int32 full_input_in_output_range = + static_cast<int32>(full_input_in_output_range_64); + output[i] = full_input_in_output_range + scalar_in_output_range; + } +} + +#endif // QUANTIZED_ADD_USE_NEON + +template <class T, class Toutput> +void VectorAddition(OpKernelContext* context, const T* x_data, float min_x, + float max_x, const T* y_data, float min_y, float max_y, + int64 num_elements, float output_min, float output_max, + Toutput* output) { + for (int i = 0; i < num_elements; ++i) { + const Toutput x_in_output_range = RequantizeInNewRange<T, Toutput>( + x_data[i], min_x, max_x, output_min, output_max); + const Toutput y_in_output_range = RequantizeInNewRange<T, Toutput>( + y_data[i], min_y, max_y, output_min, output_max); + output[i] = x_in_output_range + y_in_output_range; + } +} + +#ifdef QUANTIZED_ADD_USE_NEON + +template <> +void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x, + float max_x, const quint8* y_data, float min_y, float max_y, + int64 num_elements, float output_min, float output_max, + qint32* output) { + const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x); + const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x); + const int64 x_0_int64 = + FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max); + const int64 x_1_int64 = + FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max); + const int32 x_mult_int32 = x_1_int64 - x_0_int64; + + const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y); + 
const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y); + const int64 y_0_int64 = + FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max); + const int64 y_1_int64 = + FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max); + const int32 y_mult_int32 = y_1_int64 - y_0_int64; + + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + + const int64x2_t x_0_64x2 = vmovq_n_s64(x_0_int64); + const int32x2_t x_mult_32x2 = vmov_n_s32(x_mult_int32); + + const int64x2_t y_0_64x2 = vmovq_n_s64(y_0_int64); + const int32x2_t y_mult_32x2 = vmov_n_s32(y_mult_int32); + + int64 i = 0; + for (; i < (num_elements - 7); i += 8) { + const uint8* x_ptr = &(x_data->value) + i; + const std::array<int32x4_t, 2> x_output_value = + Requantize8x8To32Neon(x_ptr, x_0_64x2, x_mult_32x2); + const uint8* y_ptr = &(y_data->value) + i; + const std::array<int32x4_t, 2> y_output_value = + Requantize8x8To32Neon(y_ptr, y_0_64x2, y_mult_32x2); + + const int32x4_t result_low_32x4 = + vaddq_s32(x_output_value[0], y_output_value[0]); + const int32x4_t result_high_32x4 = + vaddq_s32(x_output_value[1], y_output_value[1]); + int32* output_ptr = &(output->value) + i; + vst1q_s32(output_ptr + 0, result_low_32x4); + vst1q_s32(output_ptr + 4, result_high_32x4); + } + + for (; i < num_elements; ++i) { + const int64 x_value = static_cast<int64>(x_data[i]); + int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32); + x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized); + x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized); + const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64); + + const int64 y_value = static_cast<int64>(y_data[i]); + int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32); + y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized); + y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized); + const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64); + + output[i] = x_in_output_range + y_in_output_range; + } +} + +#else // QUANTIZED_ADD_USE_NEON + +template <> +void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x, + float max_x, const quint8* y_data, float min_y, float max_y, + int64 num_elements, float output_min, float output_max, + qint32* output) { + const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x); + const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x); + const int64 x_0_int64 = + FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max); + const int64 x_1_int64 = + FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max); + const int32 x_mult_int32 = x_1_int64 - x_0_int64; + + const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y); + const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y); + const int64 y_0_int64 = + FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max); + const int64 y_1_int64 = + FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max); + const int32 y_mult_int32 = y_1_int64 - y_0_int64; + + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + + for (int i = 0; i < num_elements; ++i) { + const int64 x_value = static_cast<int64>(x_data[i]); + int64 
x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32); + x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized); + x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized); + const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64); + + const int64 y_value = static_cast<int64>(y_data[i]); + int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32); + y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized); + y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized); + const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64); + + output[i] = x_in_output_range + y_in_output_range; + } +} + +#endif // QUANTIZED_ADD_USE_NEON + +template <class T, class Toutput> +void VectorTensorAddition(const T* vector_data, float min_vector, + float max_vector, int64 vector_num_elements, + const T* tensor_data, float min_tensor, + float max_tensor, int64 tensor_num_elements, + float output_min, float output_max, Toutput* output) { + for (int i = 0; i < tensor_num_elements; ++i) { + const int64 vector_i = i % vector_num_elements; + const Toutput vector_in_output_range = RequantizeInNewRange<T, Toutput>( + vector_data[vector_i], min_vector, max_vector, output_min, output_max); + const Toutput tensor_in_output_range = RequantizeInNewRange<T, Toutput>( + tensor_data[i], min_tensor, max_tensor, output_min, output_max); + output[i] = vector_in_output_range + tensor_in_output_range; + } +} + +#ifdef QUANTIZED_ADD_USE_NEON + +template <> +void VectorTensorAddition(const quint8* vector_data, float min_vector, + float max_vector, int64 vector_num_elements, + const quint8* tensor_data, float min_tensor, + float max_tensor, int64 tensor_num_elements, + float output_min, float output_max, qint32* output) { + const float vector_0_float = + QuantizedToFloat<quint8>(0, min_vector, max_vector); + const float vector_1_float = + QuantizedToFloat<quint8>(1, min_vector, max_vector); + const int64 vector_0_int64 = + FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max); + const int64 vector_1_int64 = + FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max); + const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64; + + const float tensor_0_float = + QuantizedToFloat<quint8>(0, min_tensor, max_tensor); + const float tensor_1_float = + QuantizedToFloat<quint8>(1, min_tensor, max_tensor); + const int64 tensor_0_int64 = + FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max); + const int64 tensor_1_int64 = + FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max); + const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64; + + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + + const int64x2_t vector_0_64x2 = vmovq_n_s64(vector_0_int64); + const int32x2_t vector_mult_32x2 = vmov_n_s32(vector_mult_int32); + + const int64x2_t tensor_0_64x2 = vmovq_n_s64(tensor_0_int64); + const int32x2_t tensor_mult_32x2 = vmov_n_s32(tensor_mult_int32); + + for (int64 base_i = 0; base_i < tensor_num_elements; + base_i += vector_num_elements) { + int64 i = base_i; + int64 vector_i = 0; + for (; vector_i < (vector_num_elements - 7); vector_i += 8, i += 8) { + const uint8* vector_ptr = &(vector_data->value) + vector_i; + const std::array<int32x4_t, 2> vector_output_value = + Requantize8x8To32Neon(vector_ptr, vector_0_64x2, 
vector_mult_32x2); + const uint8* tensor_ptr = &(tensor_data->value) + i; + const std::array<int32x4_t, 2> tensor_output_value = + Requantize8x8To32Neon(tensor_ptr, tensor_0_64x2, tensor_mult_32x2); + + const int32x4_t result_low_32x4 = + vaddq_s32(vector_output_value[0], tensor_output_value[0]); + const int32x4_t result_high_32x4 = + vaddq_s32(vector_output_value[1], tensor_output_value[1]); + int32* output_ptr = &(output->value) + i; + vst1q_s32(output_ptr + 0, result_low_32x4); + vst1q_s32(output_ptr + 4, result_high_32x4); + } + for (; vector_i < vector_num_elements; ++vector_i, ++i) { + const int64 vector_value = static_cast<int64>(vector_data[vector_i]); + int64 vector_in_output_range_64 = + vector_0_int64 + (vector_value * vector_mult_int32); + vector_in_output_range_64 = + std::max(vector_in_output_range_64, lowest_quantized); + vector_in_output_range_64 = + std::min(vector_in_output_range_64, highest_quantized); + const int32 vector_in_output_range = + static_cast<int32>(vector_in_output_range_64); + + const int64 tensor_value = static_cast<int64>(tensor_data[i]); + int64 tensor_in_output_range_64 = + tensor_0_int64 + (tensor_value * tensor_mult_int32); + tensor_in_output_range_64 = + std::max(tensor_in_output_range_64, lowest_quantized); + tensor_in_output_range_64 = + std::min(tensor_in_output_range_64, highest_quantized); + const int32 tensor_in_output_range = + static_cast<int32>(tensor_in_output_range_64); + + output[i] = vector_in_output_range + tensor_in_output_range; + } + } +} + +#else // QUANTIZED_ADD_USE_NEON + +template <> +void VectorTensorAddition(const quint8* vector_data, float min_vector, + float max_vector, int64 vector_num_elements, + const quint8* tensor_data, float min_tensor, + float max_tensor, int64 tensor_num_elements, + float output_min, float output_max, qint32* output) { + const float vector_0_float = + QuantizedToFloat<quint8>(0, min_vector, max_vector); + const float vector_1_float = + QuantizedToFloat<quint8>(1, min_vector, max_vector); + const int64 vector_0_int64 = + FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max); + const int64 vector_1_int64 = + FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max); + const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64; + + const float tensor_0_float = + QuantizedToFloat<quint8>(0, min_tensor, max_tensor); + const float tensor_1_float = + QuantizedToFloat<quint8>(1, min_tensor, max_tensor); + const int64 tensor_0_int64 = + FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max); + const int64 tensor_1_int64 = + FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max); + const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64; + + const int64 lowest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::lowest()); + const int64 highest_quantized = + static_cast<int64>(Eigen::NumTraits<qint32>::highest()); + + for (int i = 0; i < tensor_num_elements; ++i) { + const int64 vector_i = i % vector_num_elements; + const int64 vector_value = static_cast<int64>(vector_data[vector_i]); + int64 vector_in_output_range_64 = + vector_0_int64 + (vector_value * vector_mult_int32); + vector_in_output_range_64 = + std::max(vector_in_output_range_64, lowest_quantized); + vector_in_output_range_64 = + std::min(vector_in_output_range_64, highest_quantized); + const int32 vector_in_output_range = + static_cast<int32>(vector_in_output_range_64); + + const int64 tensor_value = static_cast<int64>(tensor_data[i]); + int64 
tensor_in_output_range_64 = + tensor_0_int64 + (tensor_value * tensor_mult_int32); + tensor_in_output_range_64 = + std::max(tensor_in_output_range_64, lowest_quantized); + tensor_in_output_range_64 = + std::min(tensor_in_output_range_64, highest_quantized); + const int32 tensor_in_output_range = + static_cast<int32>(tensor_in_output_range_64); + + output[i] = vector_in_output_range + tensor_in_output_range; + } +} + +#endif // QUANTIZED_ADD_USE_NEON + +} // namespace + +template <class T, class Toutput> +class QuantizedAddOp : public OpKernel { + public: + explicit QuantizedAddOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& x = context->input(0); + const Tensor& y = context->input(1); + const float min_x = context->input(2).flat<float>()(0); + const float max_x = context->input(3).flat<float>()(0); + const float min_y = context->input(4).flat<float>()(0); + const float max_y = context->input(5).flat<float>()(0); + + BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape())); + if (!bcast.IsValid()) { + context->SetStatus(errors::InvalidArgument( + "Incompatible shapes: ", x.shape().DebugString(), " vs. ", + y.shape().DebugString())); + return; + } + Tensor* z; + OP_REQUIRES_OK(context, context->allocate_output( + 0, BCast::ToShape(bcast.output_shape()), &z)); + + // Make sure that we have valid quantization ranges for the input buffers. + // If the difference between the min and max is negative or zero, it makes + // it hard to do meaningful intermediate operations on the values. + OP_REQUIRES(context, (max_x > min_x), + errors::InvalidArgument("max_x must be larger than min_x.")); + OP_REQUIRES(context, (max_y > min_y), + errors::InvalidArgument("max_y must be larger than min_y.")); + const T* x_data = x.flat<T>().data(); + const T* y_data = y.flat<T>().data(); + Toutput* z_data = z->flat<Toutput>().data(); + + // We want the range of the output to be symmetrical around zero so that + // adding zero leaves the result unchanged, and to contain the largest of + // the two input values with some room to spare. 
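+    // As a concrete illustration of the range arithmetic below: with
+    // min_x = -1.0f, max_x = 2.0f, min_y = -0.5f, and max_y = 0.5f,
+    // smallest_min is -1.0f, largest_max is 2.0f, and biggest_range is 2.0f,
+    // so the output is quantized over the symmetric range
+    // [-32768.0f, 32768.0f] (2.0f * (1 << 14) on either side of zero).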
+    const float smallest_min = std::min(min_x, min_y);
+    const float largest_max = std::max(max_x, max_y);
+    const float biggest_range =
+        std::max(std::abs(smallest_min), std::abs(largest_max));
+    const float output_range = (biggest_range * (1 << 14));
+    const float min_z_value = -output_range;
+    const float max_z_value = output_range;
+
+    const int ndims = bcast.x_reshape().size();
+    if (ndims <= 1) {
+      if (x.NumElements() == 1) {
+        ScalarAddition<T, Toutput>(context, y_data, min_y, max_y,
+                                   y.NumElements(), x_data[0], min_x, max_x,
+                                   min_z_value, max_z_value, z_data);
+      } else if (y.NumElements() == 1) {
+        ScalarAddition<T, Toutput>(context, x_data, min_x, max_x,
+                                   x.NumElements(), y_data[0], min_y, max_y,
+                                   min_z_value, max_z_value, z_data);
+      } else {
+        VectorAddition<T, Toutput>(context, x_data, min_x, max_x, y_data, min_y,
+                                   max_y, x.NumElements(), min_z_value,
+                                   max_z_value, z_data);
+      }
+    } else if (ndims == 2) {
+      const T* vector_data;
+      int64 vector_num_elements;
+      float vector_min;
+      float vector_max;
+      const T* tensor_data;
+      int64 tensor_num_elements;
+      float tensor_min;
+      float tensor_max;
+      if (x.NumElements() < y.NumElements()) {
+        vector_data = x_data;
+        vector_num_elements = x.NumElements();
+        vector_min = min_x;
+        vector_max = max_x;
+        tensor_data = y_data;
+        tensor_num_elements = y.NumElements();
+        tensor_min = min_y;
+        tensor_max = max_y;
+      } else {
+        vector_data = y_data;
+        vector_num_elements = y.NumElements();
+        vector_min = min_y;
+        vector_max = max_y;
+        tensor_data = x_data;
+        tensor_num_elements = x.NumElements();
+        tensor_min = min_x;
+        tensor_max = max_x;
+      }
+      VectorTensorAddition<T, Toutput>(
+          vector_data, vector_min, vector_max, vector_num_elements, tensor_data,
+          tensor_min, tensor_max, tensor_num_elements, min_z_value, max_z_value,
+          z_data);
+    } else {
+      LOG(INFO) << "ndims=" << ndims;
+      LOG(INFO) << "bcast.x_reshape()="
+                << TensorShape(bcast.x_reshape()).DebugString();
+      LOG(INFO) << "bcast.y_reshape()="
+                << TensorShape(bcast.y_reshape()).DebugString();
+      LOG(INFO) << "bcast.x_bcast()="
+                << TensorShape(bcast.x_bcast()).DebugString();
+      LOG(INFO) << "bcast.y_bcast()="
+                << TensorShape(bcast.y_bcast()).DebugString();
+
+      context->SetStatus(errors::Unimplemented(
+          "Broadcast between ", context->input(0).shape().DebugString(),
+          " and ", context->input(1).shape().DebugString(),
+          " is not supported yet."));
+      return;
+    }
+
+    Tensor* z_min = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &z_min));
+    z_min->flat<float>()(0) = min_z_value;
+
+    Tensor* z_max = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &z_max));
+    z_max->flat<float>()(0) = max_z_value;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("QuantizedAdd")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("T1")
+                            .TypeConstraint<quint8>("T2")
+                            .TypeConstraint<qint32>("Toutput"),
+                        QuantizedAddOp<quint8, qint32>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/quantized_add_op_test.cc b/tensorflow/core/kernels/quantized_add_op_test.cc
new file mode 100644
index 0000000000..74d16b282d
--- /dev/null
+++ b/tensorflow/core/kernels/quantized_add_op_test.cc
@@ -0,0 +1,311 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include <functional> +#include <memory> +#include <vector> + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +using namespace ops; // NOLINT(build/namespaces) + +namespace { + +void TestAdd(const std::vector<int64>& x_shape, + const std::vector<float>& x_values, float x_min_value, + float x_max_value, const std::vector<int64>& y_shape, + const std::vector<float>& y_values, float y_min_value, + float y_max_value, const std::vector<int64>& expected_shape, + const std::vector<float>& expected_values, double tolerance) { + Scope root = Scope::NewRootScope(); + + Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape)); + test::FillValues<float>(&x_float_tensor, x_values); + Tensor x_quantized_tensor(DT_QUINT8, x_float_tensor.shape()); + FloatTensorToQuantizedInPlace<quint8>(x_float_tensor, x_min_value, + x_max_value, &x_quantized_tensor); + Output x = + Const(root.WithOpName("x"), Input::Initializer(x_quantized_tensor)); + Output x_min = Const(root.WithOpName("x_min"), x_min_value); + Output x_max = Const(root.WithOpName("x_max"), x_max_value); + + Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape)); + test::FillValues<float>(&y_float_tensor, y_values); + Tensor y_quantized_tensor(DT_QUINT8, y_float_tensor.shape()); + FloatTensorToQuantizedInPlace<quint8>(y_float_tensor, y_min_value, + y_max_value, &y_quantized_tensor); + Output y = + Const(root.WithOpName("y"), Input::Initializer(y_quantized_tensor)); + Output y_min = Const(root.WithOpName("y_min"), y_min_value); + Output y_max = Const(root.WithOpName("y_max"), y_max_value); + + ops::QuantizedAdd add = ops::QuantizedAdd(root.WithOpName("add"), x, y, x_min, + x_max, y_min, y_max); + + TF_EXPECT_OK(root.status()); + + ClientSession session(root); + std::vector<Tensor> outputs; + + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), + {add.z, add.min_z, add.max_z}, &outputs)); + + const Tensor& z_quantized = outputs[0]; + const float z_min = outputs[1].flat<float>()(0); + const float z_max = outputs[2].flat<float>()(0); + + Tensor z_float = QuantizedTensorToFloat<qint32>(z_quantized, z_min, z_max); + Tensor expected_z_float(DT_FLOAT, TensorShape(expected_shape)); + test::FillValues<float>(&expected_z_float, expected_values); + test::ExpectTensorNear<float>(expected_z_float, z_float, tolerance); +} + +void TestAddShape(const std::vector<int64>& x_shape, + const std::vector<int64>& y_shape) { + const size_t x_num_elements = TensorShape(x_shape).num_elements(); + std::vector<float> x_values(x_num_elements); + for (int i = 0; i < x_num_elements; 
++i) { + x_values[i] = i % 256; + } + const float x_min_value = 0.0f; + const float x_max_value = 256.0f; + + const size_t y_num_elements = TensorShape(y_shape).num_elements(); + std::vector<float> y_values(y_num_elements); + for (int i = 0; i < y_num_elements; ++i) { + y_values[i] = ((i + 23) % 123) - 50; + } + const float y_min_value = -150.0f; + const float y_max_value = 150.0f; + + Scope root = Scope::NewRootScope(); + + Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape)); + test::FillValues<float>(&x_float_tensor, x_values); + Output x = Const(root.WithOpName("x"), Input::Initializer(x_float_tensor)); + + Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape)); + test::FillValues<float>(&y_float_tensor, y_values); + Output y = Const(root.WithOpName("y"), Input::Initializer(y_float_tensor)); + + Add add = Add(root.WithOpName("add"), x, y); + + TF_EXPECT_OK(root.status()); + + ClientSession session(root); + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {add.z}, &outputs)); + + const Tensor& expected_values_tensor = outputs[0]; + const float* expected_values_data = + expected_values_tensor.flat<float>().data(); + std::vector<float> expected_values( + expected_values_data, + expected_values_data + expected_values_tensor.NumElements()); + std::vector<int64> expected_shape; + for (const int64 dim : expected_values_tensor.shape().dim_sizes()) { + expected_shape.push_back(dim); + } + TestAdd(x_shape, x_values, x_min_value, x_max_value, y_shape, y_values, + y_min_value, y_max_value, expected_shape, expected_values, 256.0); +} + +void TimeAdd(const std::vector<int64>& x_shape, + const std::vector<int64>& y_shape, int64 iterations) { + TestAddShape(x_shape, y_shape); + + Scope root = Scope::NewRootScope(); + + Tensor x_quantized_tensor(DT_QUINT8, TensorShape(x_shape)); + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_QUINT8); + Output x_min = Const(root.WithOpName("x_min"), 0.0f); + Output x_max = Const(root.WithOpName("x_max"), 1.0f); + + Tensor y_quantized_tensor(DT_QUINT8, TensorShape(y_shape)); + Output y = + Const(root.WithOpName("y"), Input::Initializer(y_quantized_tensor)); + Output y_min = Const(root.WithOpName("y_min"), 0.0f); + Output y_max = Const(root.WithOpName("y_max"), 1.0f); + + ops::QuantizedAdd add = ops::QuantizedAdd(root.WithOpName("add"), placeholder, + y, x_min, x_max, y_min, y_max); + + TF_EXPECT_OK(root.status()); + + ClientSession session(root); + std::vector<Tensor> outputs; + + int64 total_duration = 0; + for (int i = 0; i < iterations; ++i) { + const int64 start_time = Env::Default()->NowMicros(); + TF_EXPECT_OK(session.Run({{placeholder, x_quantized_tensor}}, + {add.z, add.min_z, add.max_z}, &outputs)); + const int64 end_time = Env::Default()->NowMicros(); + total_duration += end_time - start_time; + } + const int64 one_run_duration = total_duration / iterations; + + const int64 num_ops = outputs[0].NumElements(); + + const double million_ops_per_second = + (iterations * num_ops) / static_cast<double>(total_duration); + + LOG(INFO) << "TimeAdd: " << TensorShape(x_shape).DebugString() << " * " + << TensorShape(y_shape).DebugString() + << ": iterations=" << iterations + << ", MOps/s=" << million_ops_per_second + << ", one_run_duration=" << one_run_duration + << ", total_duration=" << total_duration; +} + +} // namespace + +void TestManualScalar() { + TestAdd( + {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f, + 10.0f, {1}, {10.0f}, -100.0f, 100.0f, {10}, + {11.0f, 12.0f, 13.0f, 14.0f, 
15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f}, + 1.0f); + TestAdd( + {1}, {10.0f}, -100.0f, 100.0f, {10}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f, + 10.0f, {10}, + {11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f}, + 1.0f); +} + +void TestScalar() { + TestAddShape({100}, {1}); + TestAddShape({1}, {100}); +} + +void TestManualVector() { + TestAdd({10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, + 0.0f, 10.0f, {10}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f, + 10.0f, {10}, + {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f}, + 1.0f); +} + +void TestVector() { TestAddShape({100}, {100}); } + +void TestManualVectorPlusTensor() { + TestAdd( + {10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f, + 10.0f, {2, 10}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f}, + 0.0f, 20.0f, {2, 10}, + {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, + 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f}, + 1.0f); + TestAdd({2, 10}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, + 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f}, + 0.0f, 20.0f, {10}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, 0.0f, + 10.0f, {2, 10}, {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, + 16.0f, 18.0f, 20.0f, 12.0f, 14.0f, 16.0f, 18.0f, + 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f}, + 1.0f); + TestAdd( + {5, 2}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}, + 0.0f, 10.0f, {2, 5, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f}, + 0.0f, 20.0f, {2, 5, 2}, + {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, + 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f}, + 1.0f); +} + +void TestVectorPlusTensor() { + TestAddShape({100}, {2, 100}); + TestAddShape({2, 100}, {100}); + TestAddShape({5, 2}, {2, 5, 2}); +} + +void BenchmarkTensorScalar() { + TimeAdd({200}, {1}, 1000); + TimeAdd({10000}, {1}, 100); + TimeAdd({1000000}, {1}, 10); + TimeAdd({10000000}, {1}, 1); +} + +void BenchmarkVector() { + TimeAdd({200}, {200}, 1000); + TimeAdd({10000}, {10000}, 100); + TimeAdd({1000000}, {1000000}, 10); + TimeAdd({10000000}, {10000000}, 1); +} + +void BenchmarkVectorPlusTensor() { + TimeAdd({10, 20}, {20}, 100); + TimeAdd({10, 1000}, {1000}, 10); + TimeAdd({1000, 1000}, {1000}, 1); + TimeAdd({10000, 1000}, {1000}, 1); + TimeAdd({100, 100}, {100}, 10); + TimeAdd({10000, 100}, {100}, 1); + TimeAdd({100000, 100}, {100}, 1); +} + +#if !defined(__ANDROID__) + +#define RUN_TEST(t) \ + TEST(QuantizedAddOpTest, t) { t(); } + +RUN_TEST(TestManualScalar); +RUN_TEST(TestManualVector); +RUN_TEST(TestManualVectorPlusTensor); +RUN_TEST(TestScalar); +RUN_TEST(TestVector); +RUN_TEST(TestVectorPlusTensor); + +#undef RUN_TEST + +#endif // __ANDROID__ + +} // end namespace tensorflow + +#if defined(__ANDROID__) +int main(int argc, char** argv) { + LOG(INFO) << "TestManualScalar:"; + tensorflow::TestManualScalar(); + LOG(INFO) << "TestManualVector:"; + tensorflow::TestManualVector(); + LOG(INFO) << "TestManualVectorPlusTensor:"; + tensorflow::TestManualVectorPlusTensor(); + tensorflow::BenchmarkTensorScalar(); + tensorflow::BenchmarkVector(); + tensorflow::BenchmarkVectorPlusTensor(); + LOG(INFO) << "All tests complete"; + 
return 0; +} +#endif // __ANDROID__ diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 28c4ec643e..3cf8a98193 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -2252,6 +2252,35 @@ max_z: The float value that the highest quantized output value represents. broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); +REGISTER_OP("QuantizedAdd") + .Input("x: T1") + .Input("y: T2") + .Input("min_x: float") + .Input("max_x: float") + .Input("min_y: float") + .Input("max_y: float") + .Output("z: Toutput") + .Output("min_z: float") + .Output("max_z: float") + .Attr("T1: quantizedtype") + .Attr("T2: quantizedtype") + .Attr("Toutput: quantizedtype = DT_QINT32") + .SetIsCommutative() + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + .Doc(R"doc( +Returns x + y element-wise, working on quantized buffers. + +min_x: The float value that the lowest quantized `x` value represents. +max_x: The float value that the highest quantized `x` value represents. +min_y: The float value that the lowest quantized `y` value represents. +max_y: The float value that the highest quantized `y` value represents. +min_z: The float value that the lowest quantized output value represents. +max_z: The float value that the highest quantized output value represents. + +*NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about +broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +)doc"); + REGISTER_OP("QuantizeDownAndShrinkRange") .Input("input: Tinput") .Input("input_min: float") diff --git a/tensorflow/core/platform/test.cc b/tensorflow/core/platform/test.cc index 920f983982..d71f72e322 100644 --- a/tensorflow/core/platform/test.cc +++ b/tensorflow/core/platform/test.cc @@ -15,15 +15,17 @@ limitations under the License. #include "tensorflow/core/platform/types.h" -#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \ - defined(PLATFORM_GOOGLE_ANDROID) +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) #include "tensorflow/core/platform/google/build_config/googletest.h" #endif +#include <cstdlib> +#include <iostream> + namespace tensorflow { namespace testing { -#if defined(PLATFORM_GOOGLE) || defined(__ANDROID__) +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) string TmpDir() { return FLAGS_test_tmpdir; } int RandomSeed() { return FLAGS_test_random_seed; } #else diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc index 78078ab6ab..5497ad008b 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc @@ -56,6 +56,13 @@ struct QuantizedOpInfo { // conversion process can transform them. 
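+// The "Add" entry below registers QuantizedAdd for this transform. As
+// registered in this change, QuantizedAdd takes its quantization ranges as
+// adjacent min/max pairs after the data inputs (x, y, min_x, max_x, min_y,
+// max_y), matching the CONTIGUOUS_MIN_MAX ordering given in the entry.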
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() { static const std::vector<QuantizedOpInfo> op_list = { + {"Add", + {}, + {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}}, + DT_QUINT8, + DT_QINT32, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, {"AvgPool", {"ksize", "strides", "padding"}, {{"T", DT_QUINT8}}, diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc index f8fe13ca13..d02655f3f9 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc @@ -105,9 +105,9 @@ class QuantizeNodesTest : public ::testing::Test { &quantized_graph_def); // Reshape is not included here because it can be added as part of the // quantization process. - const std::set<string> quantizable_ops = {"BiasAdd", "Concat", "Conv2D", - "MatMul", "Relu", "Relu6", - "AvgPool", "MaxPool", "Mul"}; + const std::set<string> quantizable_ops = { + "Add", "BiasAdd", "Concat", "Conv2D", "MatMul", + "Relu", "Relu6", "AvgPool", "MaxPool", "Mul"}; for (const NodeDef& node : quantized_graph_def.node()) { EXPECT_EQ(0, quantizable_ops.count(node.op())) << "Found quantizable node " << node.op() << " for node named " @@ -277,6 +277,41 @@ class QuantizeNodesTest : public ::testing::Test { TestQuantizedVersusFloatGraph(float_graph_def, {}, {"mul"}); } + void TestQuantizeAdd() { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + std::vector<int64> x_shape({10, 100}); + const size_t x_num_elements = TensorShape(x_shape).num_elements(); + std::vector<float> x_values(x_num_elements); + for (int i = 0; i < x_num_elements; ++i) { + x_values[i] = (i % 256) / 256.0f; + } + + std::vector<int64> y_shape({100}); + const size_t y_num_elements = TensorShape(y_shape).num_elements(); + std::vector<float> y_values(y_num_elements); + for (int i = 0; i < y_num_elements; ++i) { + y_values[i] = ((i + 23) % 123) - 50; + } + + Scope root = Scope::NewRootScope(); + + Tensor x_float_tensor(DT_FLOAT, TensorShape(x_shape)); + test::FillValues<float>(&x_float_tensor, x_values); + Output x = Const(root.WithOpName("x"), Input::Initializer(x_float_tensor)); + + Tensor y_float_tensor(DT_FLOAT, TensorShape(y_shape)); + test::FillValues<float>(&y_float_tensor, y_values); + Output y = Const(root.WithOpName("y"), Input::Initializer(y_float_tensor)); + + Add add = Add(root.WithOpName("add"), x, y); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"add"}); + } + void TestQuantizeConv2D(int depth, int input_width, int input_height, int input_batch_count, int filter_size, int filter_count, int stride, const string& padding, @@ -1382,6 +1417,8 @@ TEST_F(QuantizeNodesTest, TestQuantizeMatMulSmall) { TEST_F(QuantizeNodesTest, TestQuantizeMul) { TestQuantizeMul(); } +TEST_F(QuantizeNodesTest, TestQuantizeAdd) { TestQuantizeAdd(); } + TEST_F(QuantizeNodesTest, TestOddPaddingProblem) { // Tests one error case we ran into in a real graph. TestQuantizeConv2D(1, 4, 4, 1, 3, 1, 2, "SAME", |