diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-07 14:37:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-07 15:48:18 -0700 |
commit | 7705791619f5e851687e9a63b4315087e189f8be (patch) | |
tree | bc4f1da0194c55349c8bea154a6dd90cf2878d60 /tensorflow/core/util/example_proto_fast_parsing.cc | |
parent | d3b34b7e5600741080d10289ddb5d9eafdf53a82 (diff) |
Further improve performance of ParseExample fixing regression on specific benchmarks.
- Ensure that each thread receives equal piece.
- At least 8 threads (given 8 or more examples) and at most 64
- Use InlinedVector
Change: 132488652
Diffstat (limited to 'tensorflow/core/util/example_proto_fast_parsing.cc')
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.cc | 91 |
1 files changed, 56 insertions, 35 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index e55c812fff..a8b91859ad 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/presized_cuckoo_map.h" @@ -34,6 +35,10 @@ namespace tensorflow { namespace example { namespace { + +template <typename T> +using SmallVector = gtl::InlinedVector<T, 4>; + template <typename A> auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) { a->EnableAliasing(true); @@ -86,7 +91,7 @@ class Feature { return Status::OK(); } - bool ParseBytesList(std::vector<string>* bytes_list) { + bool ParseBytesList(SmallVector<string>* bytes_list) { DCHECK(bytes_list != nullptr); protobuf::io::CodedInputStream stream( reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size()); @@ -110,7 +115,7 @@ class Feature { return true; } - bool ParseFloatList(std::vector<float>* float_list) { + bool ParseFloatList(SmallVector<float>* float_list) { DCHECK(float_list != nullptr); protobuf::io::CodedInputStream stream( reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size()); @@ -152,7 +157,7 @@ class Feature { return true; } - bool ParseInt64List(std::vector<int64>* int64_list) { + bool ParseInt64List(SmallVector<int64>* int64_list) { DCHECK(int64_list != nullptr); protobuf::io::CodedInputStream stream( reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size()); @@ -293,7 +298,7 @@ bool TestFastParse(const string& serialized, Example* example) { case DT_INVALID: break; case DT_STRING: { - std::vector<string> list; + SmallVector<string> list; if (!entry.second.ParseBytesList(&list)) return false; auto* result_list = value.mutable_bytes_list(); for (auto& bytes : list) { @@ -302,7 +307,7 @@ bool TestFastParse(const string& serialized, Example* example) { break; } case DT_FLOAT: { - std::vector<float> list; + SmallVector<float> list; if (!entry.second.ParseFloatList(&list)) return false; auto* result_list = value.mutable_float_list(); for (float f : list) { @@ -311,7 +316,7 @@ bool TestFastParse(const string& serialized, Example* example) { break; } case DT_INT64: { - std::vector<int64> list; + SmallVector<int64> list; if (!entry.second.ParseInt64List(&list)) return false; auto* result_list = value.mutable_int64_list(); for (int64 i : list) { @@ -334,28 +339,32 @@ using Config = FastParseExampleConfig; void ParallelFor(const std::function<void(size_t)>& f, size_t n, thread::ThreadPool* thread_pool) { - DCHECK(thread_pool != nullptr); if (n == 0) return; - BlockingCounter counter(n - 1); - for (size_t i = 1; i < n; ++i) { - thread_pool->Schedule([i, &f, &counter] { + if (thread_pool == nullptr) { + for (size_t i = 0; i < n; ++i) { f(i); - counter.DecrementCount(); - }); + } + } else { + BlockingCounter counter(n - 1); + for (size_t i = 1; i < n; ++i) { + thread_pool->Schedule([i, &f, &counter] { + f(i); + counter.DecrementCount(); + }); + } + f(0); + counter.Wait(); } - f(0); - counter.Wait(); } enum class Type { Sparse, Dense }; struct SparseBuffer { - // TODO(lew): Use InlinedVector. // Features are in one of the 3 vectors below depending on config's dtype. // Other 2 vectors remain empty. - std::vector<string> bytes_list; - std::vector<float> float_list; - std::vector<int64> int64_list; + SmallVector<string> bytes_list; + SmallVector<float> float_list; + SmallVector<int64> int64_list; // Features of example i are elements with indices // from example_end_indices[i-1] to example_end_indices[i]-1 on the @@ -432,7 +441,7 @@ Status FastParseSerializedExample( switch (config.dense[d].dtype) { case DT_INT64: { - std::vector<int64> list; + SmallVector<int64> list; if (!feature.ParseInt64List(&list)) return parse_error(feature_name); if (list.size() != num_elements) { return shape_error(list.size(), "int64"); @@ -442,7 +451,7 @@ Status FastParseSerializedExample( break; } case DT_FLOAT: { - std::vector<float> list; + SmallVector<float> list; if (!feature.ParseFloatList(&list)) return parse_error(feature_name); if (list.size() != num_elements) { return shape_error(list.size(), "float"); @@ -452,7 +461,7 @@ Status FastParseSerializedExample( break; } case DT_STRING: { - std::vector<string> list; + SmallVector<string> list; if (!feature.ParseBytesList(&list)) return parse_error(feature_name); if (list.size() != num_elements) { return shape_error(list.size(), "bytes"); @@ -580,7 +589,6 @@ Status FastParseExample(const Config& config, gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, thread::ThreadPool* thread_pool, Result* result) { - DCHECK(thread_pool != nullptr); DCHECK(result != nullptr); // Check config so we can safely CHECK(false) in switches on config.*.dtype for (auto& c : config.sparse) { @@ -626,36 +634,49 @@ Status FastParseExample(const Config& config, } // This parameter affects performance in a big and data-dependent way. - const size_t kMiniBatchSizeBytes = 100000; + const size_t kMiniBatchSizeBytes = 50000; - // Split examples into mini-batches for parallel processing. - auto first_example_of_minibatch = [&] { - std::vector<size_t> result; + // Calculate number of minibatches. + // In main regime make each minibatch around kMiniBatchSizeBytes bytes. + // Apply 'special logic' below for small and big regimes. + const size_t num_minibatches = [&] { + size_t result = 0; size_t minibatch_bytes = 0; for (size_t i = 0; i < serialized.size(); i++) { if (minibatch_bytes == 0) { // start minibatch - result.push_back(i); + result++; } minibatch_bytes += serialized[i].size() + 1; if (minibatch_bytes > kMiniBatchSizeBytes) { minibatch_bytes = 0; } } - return result; + // 'special logic' + const size_t min_minibatches = std::min<size_t>(8, serialized.size()); + const size_t max_minibatches = 64; + return std::max<size_t>(min_minibatches, + std::min<size_t>(max_minibatches, result)); }(); - size_t num_minibatches = first_example_of_minibatch.size(); + auto first_example_of_minibatch = [&](size_t minibatch) -> size_t { + return (serialized.size() * minibatch) / num_minibatches; + }; + + // TODO(lew): A big performance low-hanging fruit here is to improve + // num_minibatches calculation to take into account actual amount of work + // needed, as the size in bytes is not perfect. Linear combination of + // size in bytes and average number of features per example is promising. + // Even better: measure time instead of estimating, but this is too costly + // in small batches. + // Maybe accept outside parameter #num_minibatches? // Do minibatches in parallel. std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches); std::vector<Status> status_of_minibatch(num_minibatches); - auto ProcessMiniBatch = [&](size_t minibatch) { sparse_buffers[minibatch].resize(config.sparse.size()); - size_t start = first_example_of_minibatch[minibatch]; - size_t end = minibatch + 1 < num_minibatches - ? first_example_of_minibatch[minibatch + 1] - : serialized.size(); + size_t start = first_example_of_minibatch(minibatch); + size_t end = first_example_of_minibatch(minibatch + 1); for (size_t e = start; e < end; ++e) { status_of_minibatch[minibatch] = FastParseSerializedExample( serialized[e], @@ -711,7 +732,7 @@ Status FastParseExample(const Config& config, // Update indices. int64* ix_p = &indices->matrix<int64>()(offset, 0); size_t delta = 0; - size_t example_index = first_example_of_minibatch[i]; + size_t example_index = first_example_of_minibatch(i); for (size_t example_end_index : buffer.example_end_indices) { size_t feature_index = 0; for (; delta < example_end_index; ++delta) { |