path: root/tensorflow/core/util/example_proto_fast_parsing.cc
author A. Unique TensorFlower <gardener@tensorflow.org> 2016-09-07 14:37:07 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-09-07 15:48:18 -0700
commit 7705791619f5e851687e9a63b4315087e189f8be (patch)
tree bc4f1da0194c55349c8bea154a6dd90cf2878d60 /tensorflow/core/util/example_proto_fast_parsing.cc
parent d3b34b7e5600741080d10289ddb5d9eafdf53a82 (diff)
Further improve performance of ParseExample, fixing a regression on specific benchmarks.
- Ensure that each thread receives an equal piece of work.
- Use at least 8 threads (given 8 or more examples) and at most 64.
- Use InlinedVector instead of std::vector for parsed feature lists.
Change: 132488652
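To make the splitting concrete, here is a minimal standalone sketch of the logic this change introduces. The constants and formulas mirror the diff below; the function names are illustrative stand-ins, not the actual TensorFlow helpers.

#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

size_t NumMinibatches(const std::vector<std::string>& serialized) {
  const size_t kMiniBatchSizeBytes = 50000;  // per-minibatch size target
  size_t count = 0;
  size_t minibatch_bytes = 0;
  for (const std::string& s : serialized) {
    if (minibatch_bytes == 0) ++count;  // this example starts a new minibatch
    minibatch_bytes += s.size() + 1;
    if (minibatch_bytes > kMiniBatchSizeBytes) minibatch_bytes = 0;
  }
  // Clamp to at least 8 minibatches (given 8 or more examples) and at most 64,
  // so small batches still fan out and huge batches do not oversubscribe.
  const size_t min_minibatches = std::min<size_t>(8, serialized.size());
  const size_t max_minibatches = 64;
  return std::max(min_minibatches, std::min(max_minibatches, count));
}

// Minibatch i covers examples [First(i), First(i + 1)): an (almost) equal
// share of the batch for every worker.
size_t FirstExampleOfMinibatch(size_t minibatch, size_t num_examples,
                               size_t num_minibatches) {
  return (num_examples * minibatch) / num_minibatches;
}

For example, with 10 examples and 4 minibatches the boundaries come out as 0, 2, 5, 7, 10, so the ranges [0,2), [2,5), [5,7), [7,10) differ in size by at most one example.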
Diffstat (limited to 'tensorflow/core/util/example_proto_fast_parsing.cc')
-rw-r--r-- tensorflow/core/util/example_proto_fast_parsing.cc | 91
1 file changed, 56 insertions(+), 35 deletions(-)
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index e55c812fff..a8b91859ad 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
@@ -34,6 +35,10 @@ namespace tensorflow {
namespace example {
namespace {
+
+template <typename T>
+using SmallVector = gtl::InlinedVector<T, 4>;
+
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
a->EnableAliasing(true);
@@ -86,7 +91,7 @@ class Feature {
return Status::OK();
}
- bool ParseBytesList(std::vector<string>* bytes_list) {
+ bool ParseBytesList(SmallVector<string>* bytes_list) {
DCHECK(bytes_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
@@ -110,7 +115,7 @@ class Feature {
return true;
}
- bool ParseFloatList(std::vector<float>* float_list) {
+ bool ParseFloatList(SmallVector<float>* float_list) {
DCHECK(float_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
@@ -152,7 +157,7 @@ class Feature {
return true;
}
- bool ParseInt64List(std::vector<int64>* int64_list) {
+ bool ParseInt64List(SmallVector<int64>* int64_list) {
DCHECK(int64_list != nullptr);
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
@@ -293,7 +298,7 @@ bool TestFastParse(const string& serialized, Example* example) {
case DT_INVALID:
break;
case DT_STRING: {
- std::vector<string> list;
+ SmallVector<string> list;
if (!entry.second.ParseBytesList(&list)) return false;
auto* result_list = value.mutable_bytes_list();
for (auto& bytes : list) {
@@ -302,7 +307,7 @@ bool TestFastParse(const string& serialized, Example* example) {
break;
}
case DT_FLOAT: {
- std::vector<float> list;
+ SmallVector<float> list;
if (!entry.second.ParseFloatList(&list)) return false;
auto* result_list = value.mutable_float_list();
for (float f : list) {
@@ -311,7 +316,7 @@ bool TestFastParse(const string& serialized, Example* example) {
break;
}
case DT_INT64: {
- std::vector<int64> list;
+ SmallVector<int64> list;
if (!entry.second.ParseInt64List(&list)) return false;
auto* result_list = value.mutable_int64_list();
for (int64 i : list) {
@@ -334,28 +339,32 @@ using Config = FastParseExampleConfig;
void ParallelFor(const std::function<void(size_t)>& f, size_t n,
thread::ThreadPool* thread_pool) {
- DCHECK(thread_pool != nullptr);
if (n == 0) return;
- BlockingCounter counter(n - 1);
- for (size_t i = 1; i < n; ++i) {
- thread_pool->Schedule([i, &f, &counter] {
+ if (thread_pool == nullptr) {
+ for (size_t i = 0; i < n; ++i) {
f(i);
- counter.DecrementCount();
- });
+ }
+ } else {
+ BlockingCounter counter(n - 1);
+ for (size_t i = 1; i < n; ++i) {
+ thread_pool->Schedule([i, &f, &counter] {
+ f(i);
+ counter.DecrementCount();
+ });
+ }
+ f(0);
+ counter.Wait();
}
- f(0);
- counter.Wait();
}
enum class Type { Sparse, Dense };
struct SparseBuffer {
- // TODO(lew): Use InlinedVector.
// Features are in one of the 3 vectors below depending on config's dtype.
// Other 2 vectors remain empty.
- std::vector<string> bytes_list;
- std::vector<float> float_list;
- std::vector<int64> int64_list;
+ SmallVector<string> bytes_list;
+ SmallVector<float> float_list;
+ SmallVector<int64> int64_list;
// Features of example i are elements with indices
// from example_end_indices[i-1] to example_end_indices[i]-1 on the
@@ -432,7 +441,7 @@ Status FastParseSerializedExample(
switch (config.dense[d].dtype) {
case DT_INT64: {
- std::vector<int64> list;
+ SmallVector<int64> list;
if (!feature.ParseInt64List(&list)) return parse_error(feature_name);
if (list.size() != num_elements) {
return shape_error(list.size(), "int64");
@@ -442,7 +451,7 @@ Status FastParseSerializedExample(
break;
}
case DT_FLOAT: {
- std::vector<float> list;
+ SmallVector<float> list;
if (!feature.ParseFloatList(&list)) return parse_error(feature_name);
if (list.size() != num_elements) {
return shape_error(list.size(), "float");
@@ -452,7 +461,7 @@ Status FastParseSerializedExample(
break;
}
case DT_STRING: {
- std::vector<string> list;
+ SmallVector<string> list;
if (!feature.ParseBytesList(&list)) return parse_error(feature_name);
if (list.size() != num_elements) {
return shape_error(list.size(), "bytes");
@@ -580,7 +589,6 @@ Status FastParseExample(const Config& config,
gtl::ArraySlice<string> serialized,
gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, Result* result) {
- DCHECK(thread_pool != nullptr);
DCHECK(result != nullptr);
// Check config so we can safely CHECK(false) in switches on config.*.dtype
for (auto& c : config.sparse) {
@@ -626,36 +634,49 @@ Status FastParseExample(const Config& config,
}
// This parameter affects performance in a big and data-dependent way.
- const size_t kMiniBatchSizeBytes = 100000;
+ const size_t kMiniBatchSizeBytes = 50000;
- // Split examples into mini-batches for parallel processing.
- auto first_example_of_minibatch = [&] {
- std::vector<size_t> result;
+ // Calculate number of minibatches.
+ // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
+ // Apply 'special logic' below for small and big regimes.
+ const size_t num_minibatches = [&] {
+ size_t result = 0;
size_t minibatch_bytes = 0;
for (size_t i = 0; i < serialized.size(); i++) {
if (minibatch_bytes == 0) { // start minibatch
- result.push_back(i);
+ result++;
}
minibatch_bytes += serialized[i].size() + 1;
if (minibatch_bytes > kMiniBatchSizeBytes) {
minibatch_bytes = 0;
}
}
- return result;
+ // 'special logic'
+ const size_t min_minibatches = std::min<size_t>(8, serialized.size());
+ const size_t max_minibatches = 64;
+ return std::max<size_t>(min_minibatches,
+ std::min<size_t>(max_minibatches, result));
}();
- size_t num_minibatches = first_example_of_minibatch.size();
+ auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
+ return (serialized.size() * minibatch) / num_minibatches;
+ };
+
+ // TODO(lew): A big performance low-hanging fruit here is to improve
+ // num_minibatches calculation to take into account actual amount of work
+ // needed, as the size in bytes is not perfect. Linear combination of
+ // size in bytes and average number of features per example is promising.
+ // Even better: measure time instead of estimating, but this is too costly
+ // in small batches.
+ // Maybe accept outside parameter #num_minibatches?
// Do minibatches in parallel.
std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
std::vector<Status> status_of_minibatch(num_minibatches);
-
auto ProcessMiniBatch = [&](size_t minibatch) {
sparse_buffers[minibatch].resize(config.sparse.size());
- size_t start = first_example_of_minibatch[minibatch];
- size_t end = minibatch + 1 < num_minibatches
- ? first_example_of_minibatch[minibatch + 1]
- : serialized.size();
+ size_t start = first_example_of_minibatch(minibatch);
+ size_t end = first_example_of_minibatch(minibatch + 1);
for (size_t e = start; e < end; ++e) {
status_of_minibatch[minibatch] = FastParseSerializedExample(
serialized[e],
@@ -711,7 +732,7 @@ Status FastParseExample(const Config& config,
// Update indices.
int64* ix_p = &indices->matrix<int64>()(offset, 0);
size_t delta = 0;
- size_t example_index = first_example_of_minibatch[i];
+ size_t example_index = first_example_of_minibatch(i);
for (size_t example_end_index : buffer.example_end_indices) {
size_t feature_index = 0;
for (; delta < example_end_index; ++delta) {
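The other behavioral change above is that ParallelFor, and therefore FastParseExample, now accepts a null thread pool and simply runs the work inline, which is why both DCHECK(thread_pool != nullptr) guards were removed. Below is a minimal sketch of that control flow; SimpleBlockingCounter and the Pool template parameter are simplified stand-ins (assumptions for illustration), not tensorflow::BlockingCounter or tensorflow::thread::ThreadPool.

#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>

// Stand-in for tensorflow::BlockingCounter: Wait() returns once
// DecrementCount() has been called `count` times.
class SimpleBlockingCounter {
 public:
  explicit SimpleBlockingCounter(size_t count) : count_(count) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  size_t count_;
};

// Pool is any type providing Schedule(std::function<void()>);
// passing nullptr runs every piece serially on the calling thread.
template <typename Pool>
void ParallelForSketch(const std::function<void(size_t)>& f, size_t n,
                       Pool* pool) {
  if (n == 0) return;
  if (pool == nullptr) {  // serial fallback added by this change
    for (size_t i = 0; i < n; ++i) f(i);
    return;
  }
  SimpleBlockingCounter counter(n - 1);
  for (size_t i = 1; i < n; ++i) {
    pool->Schedule([i, &f, &counter] {
      f(i);
      counter.DecrementCount();
    });
  }
  f(0);            // the calling thread does one piece of the work itself
  counter.Wait();  // then blocks until the scheduled pieces finish
}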