aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-16 06:51:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-16 08:03:19 -0700
commit9f2fa2ec4a68bb9e88ee20146927f84e4f9fe199 (patch)
tree8126294c685e8735e3fb145c3131c658730da19a /tensorflow/contrib/quantization
parent2c4cba87fa0c8b3f003fb84544d2c68140583a0e (diff)
In contrib/quantization, use eigen threadpool when calling gemmlowp to
avoid creating a new one each time. Change: 127624630
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/kernels/quantization_utils.h56
-rw-r--r--tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc54
-rw-r--r--tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc35
3 files changed, 104 insertions, 41 deletions
diff --git a/tensorflow/contrib/quantization/kernels/quantization_utils.h b/tensorflow/contrib/quantization/kernels/quantization_utils.h
index 5c3716d24e..c9a3c77797 100644
--- a/tensorflow/contrib/quantization/kernels/quantization_utils.h
+++ b/tensorflow/contrib/quantization/kernels/quantization_utils.h
@@ -25,7 +25,9 @@ limitations under the License.
// to avoid a dependency on floating-point hardware.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "external/gemmlowp/public/gemmlowp.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
@@ -487,6 +489,60 @@ void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input,
}
}
+// See gemmlowp/internal/multi_thread_gemm.h for definitions of
+// Prepare, Wait, StartWorker, and CreateWorkers.
+class TensorflowGemmlowpWorkersPool {
+ public:
+ TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
+ : workers_(workers) {}
+
+ void Prepare(int workers_count) {
+ counter_to_decrement_when_ready_.Reset(workers_count);
+ }
+
+ void Wait() { counter_to_decrement_when_ready_.Wait(); }
+
+ void StartWorker(int index, gemmlowp::Task* task) {
+ CHECK(workers_ != nullptr);
+ // <index> is ignored - the tensorflow threadpool does not support assigning
+ // to a specific thread.
+ workers_->Schedule([this, task]() {
+ // TODO(cwhipkey): get a local_allocator from a thread local.
+ gemmlowp::Allocator local_allocator;
+ CHECK(task != nullptr);
+ task->local_allocator = &local_allocator;
+ task->Run();
+ delete task;
+ counter_to_decrement_when_ready_.DecrementCount();
+ });
+ }
+
+ void CreateWorkers(std::size_t workers_count) {}
+
+ private:
+ thread::ThreadPool* const workers_;
+
+ // The BlockingCounter used to wait for the workers.
+ gemmlowp::BlockingCounter counter_to_decrement_when_ready_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool);
+};
+
+class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
+ public:
+ TensorflowGemmContext(int num_threads, thread::ThreadPool* workers)
+ : workers_pool_(workers) {
+ set_max_num_threads(num_threads);
+ }
+
+ TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; }
+
+ private:
+ TensorflowGemmlowpWorkersPool workers_pool_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext);
+};
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_QUANTIZATION_UTILS_H_
diff --git a/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc b/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc
index 46ec0a337a..647e68ea12 100644
--- a/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc
+++ b/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc
@@ -46,13 +46,13 @@ namespace tensorflow {
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
public:
- void operator()(const T1* input_data, int input_batches, int input_height,
- int input_width, int input_depth, int input_offset,
- const T2* filter_data, int filter_height, int filter_width,
- int filter_count, int filter_offset, int stride,
- Padding padding, T3* output_data, int output_height,
- int output_width, int output_shift, int output_offset,
- int output_mult) {
+ void operator()(OpKernelContext* op_context, const T1* input_data,
+ int input_batches, int input_height, int input_width,
+ int input_depth, int input_offset, const T2* filter_data,
+ int filter_height, int filter_width, int filter_count,
+ int filter_offset, int stride, Padding padding,
+ T3* output_data, int output_height, int output_width,
+ int output_shift, int output_offset, int output_mult) {
// Set up some constants we need for the output down-shifting and
// saturation.
const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
@@ -186,13 +186,13 @@ class ReferenceConvFunctor {
template <class T1, class T2, class T3>
class Im2ColConvFunctor {
public:
- void operator()(const T1* input_data, int input_batches, int input_height,
- int input_width, int input_depth, int input_offset,
- const T2* filter_data, int filter_height, int filter_width,
- int filter_count, int filter_offset, int stride,
- Padding padding, T3* output_data, int output_height,
- int output_width, int output_shift, int output_offset,
- int output_mult) {
+ void operator()(OpKernelContext* op_context, const T1* input_data,
+ int input_batches, int input_height, int input_width,
+ int input_depth, int input_offset, const T2* filter_data,
+ int filter_height, int filter_width, int filter_count,
+ int filter_offset, int stride, Padding padding,
+ T3* output_data, int output_height, int output_width,
+ int output_shift, int output_offset, int output_mult) {
if (input_offset < 0) {
// Only log the first few occurrences of this warning.
static int warning_count = 0;
@@ -206,11 +206,11 @@ class Im2ColConvFunctor {
<< " avoid this situation.";
}
ReferenceConvFunctor<T1, T2, T3> conv_functor;
- conv_functor(input_data, input_batches, input_height, input_width,
- input_depth, input_offset, filter_data, filter_height,
- filter_width, filter_count, filter_offset, stride, padding,
- output_data, output_height, output_width, output_shift,
- output_offset, output_mult);
+ conv_functor(op_context, input_data, input_batches, input_height,
+ input_width, input_depth, input_offset, filter_data,
+ filter_height, filter_width, filter_count, filter_offset,
+ stride, padding, output_data, output_height, output_width,
+ output_shift, output_offset, output_mult);
return;
}
@@ -369,7 +369,11 @@ class Im2ColConvFunctor {
gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
output_data_as_int32, m, n, ldc);
const std::tuple<> empty_pipeline = {};
- gemmlowp::GemmContext context;
+
+ auto& worker_threads =
+ *(op_context->device()->tensorflow_cpu_worker_threads());
+ TensorflowGemmContext context(worker_threads.num_threads,
+ worker_threads.workers);
gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
gemmlowp::DefaultL8R8BitDepthParams>(
&context, lhs, rhs, &result, -input_offset, -filter_offset,
@@ -483,11 +487,11 @@ class QuantizedConv2DOp : public OpKernel {
// This will call different implementations (e.g. reference or optimized)
// depending on the template parameter.
ConvFunctor<T1, T2, T3> conv_functor;
- conv_functor(input.flat<T1>().data(), batch, input_rows, input_cols,
- in_depth, offset_input, filter.flat<T2>().data(), filter_rows,
- filter_cols, out_depth, offset_filter, stride, padding_,
- output->flat<T3>().data(), out_rows, out_cols, shift_output,
- offset_output, mult_output);
+ conv_functor(context, input.flat<T1>().data(), batch, input_rows,
+ input_cols, in_depth, offset_input, filter.flat<T2>().data(),
+ filter_rows, filter_cols, out_depth, offset_filter, stride,
+ padding_, output->flat<T3>().data(), out_rows, out_cols,
+ shift_output, offset_output, mult_output);
float min_output_value;
float max_output_value;
diff --git a/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc b/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc
index 856ee8c2f4..21abce932a 100644
--- a/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc
+++ b/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc
@@ -28,9 +28,9 @@ namespace tensorflow {
// combinations of transpose attributes we need to support, and they have to be
// compile-time constants to work with the templates used internally.
template <bool TransposeA, bool TransposeB, bool TransposeC>
-void GemmlowpMultiply(const quint8* a_data, const quint8* b_data,
- qint32* c_data, int m, int n, int k, int offset_a,
- int offset_b, int lda, int ldb, int ldc) {
+void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
+ const quint8* b_data, qint32* c_data, int m, int n, int k,
+ int offset_a, int offset_b, int lda, int ldb, int ldc) {
const uint8* a_data_as_uint8 = &(a_data->value);
const uint8* b_data_as_uint8 = &(b_data->value);
int32* c_data_as_int32 = &(c_data->value);
@@ -47,7 +47,10 @@ void GemmlowpMultiply(const quint8* a_data, const quint8* b_data,
gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
ldc);
const std::tuple<> empty_pipeline = {};
- gemmlowp::GemmContext context;
+ auto& worker_threads =
+ *(op_context->device()->tensorflow_cpu_worker_threads());
+ TensorflowGemmContext context(worker_threads.num_threads,
+ worker_threads.workers);
gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
gemmlowp::DefaultL8R8BitDepthParams>(
&context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
@@ -130,23 +133,23 @@ class QuantizedMatMulOp : public OpKernel {
(shift_c == 0) && (transpose_c == false)) {
if (transpose_a_) {
if (transpose_b_) {
- GemmlowpMultiply<true, true, false>(a_data, b_data, c_data, m, n, k,
- offset_a, offset_b, lda, ldb,
- ldc);
+ GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
+ m, n, k, offset_a, offset_b, lda,
+ ldb, ldc);
} else {
- GemmlowpMultiply<true, false, false>(a_data, b_data, c_data, m, n, k,
- offset_a, offset_b, lda, ldb,
- ldc);
+ GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
+ m, n, k, offset_a, offset_b, lda,
+ ldb, ldc);
}
} else {
if (transpose_b_) {
- GemmlowpMultiply<false, true, false>(a_data, b_data, c_data, m, n, k,
- offset_a, offset_b, lda, ldb,
- ldc);
+ GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
+ m, n, k, offset_a, offset_b, lda,
+ ldb, ldc);
} else {
- GemmlowpMultiply<false, false, false>(a_data, b_data, c_data, m, n, k,
- offset_a, offset_b, lda, ldb,
- ldc);
+ GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
+ m, n, k, offset_a, offset_b,
+ lda, ldb, ldc);
}
}
} else {