aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/quantized_conv_ops.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2016-12-22 16:27:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-22 16:46:05 -0800
commite27bc777939ba93458f14bc67721c593cbe40ac9 (patch)
tree78c2009b245d6bdfb553760c3421998b1c801dce /tensorflow/core/kernels/quantized_conv_ops.cc
parent4e937d33b390bd5fed17cc38d672be48d3cdd570 (diff)
Fix for quantized convolution optimization
Change: 142809066
Diffstat (limited to 'tensorflow/core/kernels/quantized_conv_ops.cc')
-rw-r--r--tensorflow/core/kernels/quantized_conv_ops.cc21
1 files changed, 12 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index eff404e11f..f08ebacb03 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -292,6 +292,7 @@ class Im2ColConvFunctor {
const int64 patch_count = (input_batches * output_height * output_width);
const int64 chunk_count =
(patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
+
for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
const int64 patch_index_start = chunk_index * patches_per_chunk;
const int64 patch_index_end =
@@ -381,13 +382,15 @@ class Im2ColConvFunctor {
const int lda = filter_value_count;
const int ldb = filter_count;
const int ldc = filter_count;
+ T3* chunk_output_data = output_data + (patch_index_start * filter_count);
+
if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
(output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
(transpose_c == false)) {
meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
- filter_data, output_data, m, n, k, -input_offset,
- -filter_offset, lda, ldb, ldc);
+ filter_data, chunk_output_data, m, n, k,
+ -input_offset, -filter_offset, lda, ldb, ldc);
} else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
std::is_same<T3, qint32>() && (output_offset == 0) &&
(output_mult == 1) && (output_shift == 0)) {
@@ -396,7 +399,7 @@ class Im2ColConvFunctor {
// slower reference implementation if not.
const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
const uint8* filter_data_as_uint8 = &(filter_data->value);
- int32* output_data_as_int32 = &(output_data->value);
+ int32* output_data_as_int32 = &(chunk_output_data->value);
// All of the transpose_* variables are currently compile-time consts,
// so we could just hard-code these values too, but that would break if
// anybody changed those values in the future (e.g. to match the ability
@@ -431,8 +434,8 @@ class Im2ColConvFunctor {
} else {
ReferenceGemm<T1, T2, T3>(
transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
- input_offset, lda, filter_data, filter_offset, ldb, output_data,
- output_shift, output_offset, output_mult, ldc);
+ input_offset, lda, filter_data, filter_offset, ldb,
+ chunk_output_data, output_shift, output_offset, output_mult, ldc);
}
}
}
@@ -491,10 +494,10 @@ class QuantizedConv2DOp : public OpKernel {
// The last dimension for input is in_depth. It must be the same as the
// filter's in_depth.
const int64 in_depth = input.dim_size(3);
- OP_REQUIRES(
- context, in_depth == filter.dim_size(2),
- errors::InvalidArgument("input and filter must have the same depth: ",
- in_depth, " vs ", filter.dim_size(2)));
+ OP_REQUIRES(context, in_depth == filter.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ", in_depth,
+ " vs ", filter.dim_size(2)));
// The last dimension for filter is out_depth.
const int64 out_depth = filter.dim_size(3);