diff options
author | Pete Warden <petewarden@google.com> | 2016-12-22 16:27:51 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-22 16:46:05 -0800 |
commit | e27bc777939ba93458f14bc67721c593cbe40ac9 (patch) | |
tree | 78c2009b245d6bdfb553760c3421998b1c801dce /tensorflow/core/kernels/quantized_conv_ops.cc | |
parent | 4e937d33b390bd5fed17cc38d672be48d3cdd570 (diff) |
Fix for quantized convolution optimization
Change: 142809066
Diffstat (limited to 'tensorflow/core/kernels/quantized_conv_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/quantized_conv_ops.cc | 21 |
1 files changed, 12 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc index eff404e11f..f08ebacb03 100644 --- a/tensorflow/core/kernels/quantized_conv_ops.cc +++ b/tensorflow/core/kernels/quantized_conv_ops.cc @@ -292,6 +292,7 @@ class Im2ColConvFunctor { const int64 patch_count = (input_batches * output_height * output_width); const int64 chunk_count = (patch_count + (patches_per_chunk - 1)) / patches_per_chunk; + for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { const int64 patch_index_start = chunk_index * patches_per_chunk; const int64 patch_index_end = @@ -381,13 +382,15 @@ class Im2ColConvFunctor { const int lda = filter_value_count; const int ldb = filter_count; const int ldc = filter_count; + T3* chunk_output_data = output_data + (patch_index_start * filter_count); + if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() && std::is_same<T2, quint8>() && std::is_same<T3, qint32>() && (output_offset == 0) && (output_mult == 1) && (output_shift == 0) && (transpose_c == false)) { meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer, - filter_data, output_data, m, n, k, -input_offset, - -filter_offset, lda, ldb, ldc); + filter_data, chunk_output_data, m, n, k, + -input_offset, -filter_offset, lda, ldb, ldc); } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() && std::is_same<T3, qint32>() && (output_offset == 0) && (output_mult == 1) && (output_shift == 0)) { @@ -396,7 +399,7 @@ class Im2ColConvFunctor { // slower reference implementation if not. const uint8* im2col_data_as_uint8 = &(im2col_buffer->value); const uint8* filter_data_as_uint8 = &(filter_data->value); - int32* output_data_as_int32 = &(output_data->value); + int32* output_data_as_int32 = &(chunk_output_data->value); // All of the transpose_* variables are currently compile-time consts, // so we could just hard-code these values too, but that would break if // anybody changed those values in the future (e.g. to match the ability @@ -431,8 +434,8 @@ class Im2ColConvFunctor { } else { ReferenceGemm<T1, T2, T3>( transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer, - input_offset, lda, filter_data, filter_offset, ldb, output_data, - output_shift, output_offset, output_mult, ldc); + input_offset, lda, filter_data, filter_offset, ldb, + chunk_output_data, output_shift, output_offset, output_mult, ldc); } } } @@ -491,10 +494,10 @@ class QuantizedConv2DOp : public OpKernel { // The last dimension for input is in_depth. It must be the same as the // filter's in_depth. const int64 in_depth = input.dim_size(3); - OP_REQUIRES( - context, in_depth == filter.dim_size(2), - errors::InvalidArgument("input and filter must have the same depth: ", - in_depth, " vs ", filter.dim_size(2))); + OP_REQUIRES(context, in_depth == filter.dim_size(2), + errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, + " vs ", filter.dim_size(2))); // The last dimension for filter is out_depth. const int64 out_depth = filter.dim_size(3); |