diff options
author | 2016-08-19 10:44:37 -0800 | |
---|---|---|
committer | 2016-08-19 11:48:07 -0700 | |
commit | 0b9f0f53ddbf693bb30afb211a6d514a1fce1c22 (patch) | |
tree | fcc68dbf0d05ba6e093b58274645aaaa9d9b7c39 /tensorflow | |
parent | 859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (diff) |
Implement fast ParseExample.
Change: 130775324
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/example_parsing_ops.cc | 151 | ||||
-rw-r--r-- | tensorflow/core/kernels/example_parsing_ops_test.cc | 31 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.cc | 761 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.h | 88 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing_test.cc | 184 | ||||
-rw-r--r-- | tensorflow/core/util/presized_cuckoo_map.h | 8 | ||||
-rw-r--r-- | tensorflow/core/util/presized_cuckoo_map_test.cc | 18 |
8 files changed, 1102 insertions, 141 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 224badc7d3..60c5976b9b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -266,6 +266,7 @@ tf_cuda_library( "util/cuda_kernel_helper.h", "util/device_name_utils.h", "util/events_writer.h", + "util/example_proto_fast_parsing.h", "util/example_proto_helper.h", "util/guarded_philox_random.h", "util/memmapped_file_system.h", @@ -1459,6 +1460,7 @@ tf_cc_tests( "util/command_line_flags_test.cc", "util/device_name_utils_test.cc", "util/events_writer_test.cc", + "util/example_proto_fast_parsing_test.cc", "util/example_proto_helper_test.cc", "util/memmapped_file_system_test.cc", "util/presized_cuckoo_map_test.cc", diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index a7091645fa..6338638a3b 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/example_proto_fast_parsing.h" #include "tensorflow/core/util/example_proto_helper.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" #include "tensorflow/core/util/work_sharder.h" @@ -67,8 +69,7 @@ class ExampleParserOp : public OpKernel { sparse_keys_t[di] = sparse_keys[di].scalar<string>()(); } - bool has_names = (names->NumElements() > 0); - if (has_names) { + if (names->NumElements() > 0) { OP_REQUIRES( ctx, TensorShapeUtils::IsVector(names->shape()), errors::InvalidArgument("Expected names to be a vector, got shape: ", @@ -79,7 +80,6 @@ class ExampleParserOp : public OpKernel { "Expected len(names) == len(serialized), but got: ", 
names->NumElements(), " vs. ", serialized->NumElements())); } - auto names_t = names->flat<string>(); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()), errors::InvalidArgument( @@ -106,136 +106,43 @@ class ExampleParserOp : public OpKernel { } } - auto serialized_t = serialized->vec<string>(); + example::Result result; + + example::FastParseExampleConfig config; + for (int d = 0; d < attrs_.num_dense; ++d) { + config.dense.push_back({dense_keys_t[d], attrs_.dense_types[d], + attrs_.dense_shapes[d], dense_defaults[d]}); + } + for (int d = 0; d < attrs_.num_sparse; ++d) { + config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]}); + } + + auto serialized_t = serialized->flat<string>(); + auto names_t = names->flat<string>(); + gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size()); + gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size()); - const int64 batch_size = serialized_t.size(); + OP_REQUIRES_OK( + ctx, + FastParseExample( + config, slice, names_slice, + ctx->device()->tensorflow_cpu_worker_threads()->workers, &result)); + OpOutputList dense_values; OpOutputList sparse_indices; OpOutputList sparse_values; OpOutputList sparse_shapes; - OpOutputList dense_values; - + OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values)); OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices)); OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values)); OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes)); - OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values)); - - // Setup Dense features and the output_dense_values Tensor* vector. 
- std::vector<FixedLenFeature> fixed_len_features(attrs_.num_dense); - std::vector<Tensor*> output_dense_values(attrs_.num_dense); - for (int d = 0; d < attrs_.num_dense; ++d) { - // Preallocate dense_values, since we know their sizes - TensorShape out_shape; - out_shape.AddDim(batch_size); - for (const int64 dim : attrs_.dense_shapes[d].dim_sizes()) { - out_shape.AddDim(dim); - } - Tensor* out = nullptr; - dense_values.allocate(d, out_shape, &out); - - FixedLenFeature config; - config.key = dense_keys_t[d]; - config.dtype = attrs_.dense_types[d]; - config.shape = attrs_.dense_shapes[d]; - config.default_value = dense_defaults[d]; - fixed_len_features[d] = config; - output_dense_values[d] = dense_values[d]; + dense_values.set(d, result.dense_values[d]); } - - // sparse_values_tmp will be attrs_.num_sparse size map of batch_size length - // tensor vector's, containing the sparse values from the input layer. - // After these are all stored, we can allocate properly sized outputs - // and copy data over. Doing it this way saves us the trouble of either - // performing deserialization twice, or alternatively storing all copies of - // the full Example protos. - std::vector<std::vector<Tensor>> sparse_values_tmp( - attrs_.num_sparse, std::vector<Tensor>(batch_size)); - - // Setup Sparse features. - std::vector<VarLenFeature> var_len_features(attrs_.num_sparse); for (int d = 0; d < attrs_.num_sparse; ++d) { - VarLenFeature config; - config.key = sparse_keys_t[d]; - config.dtype = attrs_.sparse_types[d]; - var_len_features[d] = config; - } - - auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); - - // Estimate the cost of parsing each batch element. 
- int64 work_unit_size = 1000 + 100 * attrs_.num_sparse; - for (int d = 0; d < attrs_.num_dense; ++d) { - work_unit_size += 100 + attrs_.dense_shapes[d].num_elements(); - } - - mutex mu; - - auto DoWork = [&ctx, &mu, &serialized_t, has_names, &names_t, - &fixed_len_features, &var_len_features, &output_dense_values, - &sparse_values_tmp](int64 start, int64 limit) { - // Processing each Example in the batch starts here. - for (std::size_t b = static_cast<size_t>(start); - b < static_cast<size_t>(limit); ++b) { - // Benchmarks indicate that a tight Arena+Example is most performant. - protobuf::Arena arena; - // ex is owned by the arena. - Example* ex = protobuf::Arena::CreateMessage<Example>(&arena); - bool parse_success = ParseProtoUnlimited(ex, serialized_t(b)); - if (!TF_PREDICT_TRUE(parse_success)) { - mutex_lock l(mu); - ctx->CtxFailure(errors::InvalidArgument( - "Could not parse example input, value: '", serialized_t(b), "'")); - return; - } - const string& example_name = (has_names) ? 
names_t(b) : "<unknown>"; - Status s = SingleExampleProtoToTensors( - *ex, example_name, b, fixed_len_features, var_len_features, - &output_dense_values, &sparse_values_tmp); - if (!TF_PREDICT_TRUE(s.ok())) { - mutex_lock l(mu); - ctx->CtxFailureWithWarning(s); - } - } - }; - - Shard(worker_threads.num_threads, worker_threads.workers, batch_size, - work_unit_size, DoWork); - - if (!TF_PREDICT_TRUE(ctx->status().ok())) { - return; - } - - // Copy from sparse_values_tmp into final resting Tensors - // ------------------------- - for (int d = 0; d < attrs_.num_sparse; ++d) { - const VarLenFeature& feature_config = var_len_features[d]; - const std::vector<Tensor>& sparse_values_tmp_tensors = - sparse_values_tmp[d]; - VarLenFeatureBatchShapes sparse_tensor_batch_shapes; - GetSparseTensorShapes(feature_config, sparse_values_tmp_tensors, - batch_size, &sparse_tensor_batch_shapes); - - Tensor* sp_indices_d = nullptr; - Tensor* sp_values_d = nullptr; - Tensor* sp_shape_d = nullptr; - - sparse_indices.allocate(d, sparse_tensor_batch_shapes.indices_shape, - &sp_indices_d); - sparse_values.allocate(d, sparse_tensor_batch_shapes.values_shape, - &sp_values_d); - sparse_shapes.allocate(d, TensorShape({2}), &sp_shape_d); - - auto shape_t = sp_shape_d->vec<int64>(); - shape_t(0) = batch_size; - shape_t(1) = sparse_tensor_batch_shapes.max_num_features; - - int64 offset = 0; - for (int b = 0; b < batch_size; ++b) { - const int64 num_elements = CopyIntoSparseTensor( - sparse_values_tmp_tensors[b], b, offset, sp_indices_d, sp_values_d); - offset += num_elements; - } + sparse_indices.set(d, result.sparse_indices[d]); + sparse_values.set(d, result.sparse_values[d]); + sparse_shapes.set(d, result.sparse_shapes[d]); } } diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc index e58cecff14..187d72685e 100644 --- a/tensorflow/core/kernels/example_parsing_ops_test.cc +++ b/tensorflow/core/kernels/example_parsing_ops_test.cc @@ 
-161,7 +161,7 @@ static Graph* ParseExample(int batch_size, int num_keys) { // Benchmark settings (Sparse, Dense) X (Bytes, Int64, Float) typedef BenchmarkOptions<ExampleStore<BytesFiller>, false> SparseString; typedef BenchmarkOptions<ExampleStore<BytesFiller>, true> DenseString; -typedef BenchmarkOptions<ExampleStore<Int64Filler>, false> SparseIn64; +typedef BenchmarkOptions<ExampleStore<Int64Filler>, false> SparseInt64; typedef BenchmarkOptions<ExampleStore<Int64Filler>, true> DenseInt64; typedef BenchmarkOptions<ExampleStore<FloatFiller>, false> SparseFloat; typedef BenchmarkOptions<ExampleStore<FloatFiller>, true> DenseFloat; @@ -176,20 +176,19 @@ typedef BenchmarkOptions<ExampleStore<FloatFiller>, true> DenseFloat; } \ BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K); -#define BM_AllParseExample(B, K) \ - BM_ParseExample(SparseString, B, K); \ - BM_ParseExample(DenseString, B, K); \ - BM_ParseExample(SparseIn64, B, K); \ - BM_ParseExample(DenseInt64, B, K); \ - BM_ParseExample(SparseFloat, B, K); \ - BM_ParseExample(DenseFloat, B, K); - -BM_AllParseExample(128, 10); -BM_AllParseExample(128, 100); -BM_AllParseExample(128, 1000); - -BM_AllParseExample(512, 10); -BM_AllParseExample(512, 100); -BM_AllParseExample(512, 1000); +#define BM_AllParseExample(Type) \ + BM_ParseExample(Type, 128, 10); \ + BM_ParseExample(Type, 512, 10); \ + BM_ParseExample(Type, 128, 100); \ + BM_ParseExample(Type, 512, 100); \ + BM_ParseExample(Type, 128, 1000); \ + BM_ParseExample(Type, 512, 1000); + +BM_AllParseExample(SparseString); +BM_AllParseExample(DenseString); +BM_AllParseExample(SparseInt64); +BM_AllParseExample(DenseInt64); +BM_AllParseExample(SparseFloat); +BM_AllParseExample(DenseFloat); } // end namespace tensorflow diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc new file mode 100644 index 0000000000..fee6b8885f --- /dev/null +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -0,0 +1,761 @@ +/* 
Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/util/example_proto_fast_parsing.h" + +#include <vector> + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb_text.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/presized_cuckoo_map.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { +namespace example { + +namespace { +template <typename A> +auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) { + a->EnableAliasing(true); +} + +template <typename A> +void EnableAliasing(A&& a) {} + +uint8 PeekTag(protobuf::io::CodedInputStream* stream) { + DCHECK(stream != nullptr); + const void* ptr; + int size; + if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0; + return *static_cast<const uint8*>(ptr); +} + +constexpr uint8 kVarintTag(uint tag) { return (tag << 3) | 0; } +constexpr uint8 kDelimitedTag(uint 
tag) { return (tag << 3) | 2; } +constexpr uint8 kFixed32Tag(uint tag) { return (tag << 3) | 5; } + +namespace parsed { + +// ParseDataType has to be called first, then appropriate ParseZzzzList. +class Feature { + public: + Feature() {} + Feature(StringPiece serialized) : serialized_(serialized) {} + + Status ParseDataType(DataType* dtype) { + DCHECK(dtype != nullptr); + if (serialized_.empty()) { + *dtype = DT_INVALID; + return Status::OK(); + } + uint8 oneof_tag = static_cast<uint8>(*serialized_.data()); + serialized_.remove_prefix(1); + switch (oneof_tag) { + case kDelimitedTag(1): + *dtype = DT_STRING; + break; + case kDelimitedTag(2): + *dtype = DT_FLOAT; + break; + case kDelimitedTag(3): + *dtype = DT_INT64; + break; + default: + return errors::InvalidArgument("Unsuported datatype."); + } + return Status::OK(); + } + + bool ParseBytesList(std::vector<string>* bytes_list) { + DCHECK(bytes_list != nullptr); + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size()); + + EnableAliasing(&stream); + + uint32 length; + if (!stream.ReadVarint32(&length)) return false; + auto limit = stream.PushLimit(length); + + while (!stream.ExpectAtEnd()) { + if (!stream.ExpectTag(kDelimitedTag(1))) return false; + // parse string + uint32 bytes_length; + if (!stream.ReadVarint32(&bytes_length)) return false; + string bytes; + if (!stream.ReadString(&bytes, bytes_length)) return false; + bytes_list->push_back(std::move(bytes)); + } + stream.PopLimit(limit); + return true; + } + + bool ParseFloatList(std::vector<float>* float_list) { + DCHECK(float_list != nullptr); + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size()); + EnableAliasing(&stream); + uint32 length; + if (!stream.ReadVarint32(&length)) return false; + auto limit = stream.PushLimit(length); + + if (!stream.ExpectAtEnd()) { + uint8 peek_tag = PeekTag(&stream); + if (peek_tag != kDelimitedTag(1) && 
peek_tag != kFixed32Tag(1)) { + return false; + } + + if (peek_tag == kDelimitedTag(1)) { // packed + if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag + uint32 packed_length; + if (!stream.ReadVarint32(&packed_length)) return false; + auto packed_limit = stream.PushLimit(packed_length); + + while (!stream.ExpectAtEnd()) { + uint32 buffer32; + if (!stream.ReadLittleEndian32(&buffer32)) return false; + float_list->push_back(bit_cast<float>(buffer32)); + } + + stream.PopLimit(packed_limit); + } else { // non-packed + while (!stream.ExpectAtEnd()) { + if (!stream.ExpectTag(kFixed32Tag(1))) return false; + uint32 buffer32; + if (!stream.ReadLittleEndian32(&buffer32)) return false; + float_list->push_back(bit_cast<float>(buffer32)); + } + } + } + + stream.PopLimit(limit); + return true; + } + + bool ParseInt64List(std::vector<int64>* int64_list) { + DCHECK(int64_list != nullptr); + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size()); + EnableAliasing(&stream); + uint32 length; + if (!stream.ReadVarint32(&length)) return false; + auto limit = stream.PushLimit(length); + + if (!stream.ExpectAtEnd()) { + uint8 peek_tag = PeekTag(&stream); + if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) { + return false; + } + if (peek_tag == kDelimitedTag(1)) { // packed + if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag + uint32 packed_length; + if (!stream.ReadVarint32(&packed_length)) return false; + auto packed_limit = stream.PushLimit(packed_length); + + while (!stream.ExpectAtEnd()) { + uint64 n; // There is no API for int64 + if (!stream.ReadVarint64(&n)) return false; + int64_list->push_back(n); + } + + stream.PopLimit(packed_limit); + } else { // non-packed + while (!stream.ExpectAtEnd()) { + if (!stream.ExpectTag(kVarintTag(1))) return false; + uint64 n; // There is no API for int64 + if (!stream.ReadVarint64(&n)) return false; + int64_list->push_back(n); + } + } + 
} + stream.PopLimit(limit); + return true; + } + + StringPiece GetSerialized() const { return serialized_; } + + private: + // TODO(lew): Pair of uint8* would be more natural. + StringPiece serialized_; +}; + +using FeatureMapEntry = std::pair<StringPiece, Feature>; +using Example = std::vector<FeatureMapEntry>; + +} // namespace parsed + +bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) { + DCHECK(stream != nullptr); + DCHECK(result != nullptr); + uint32 length; + if (!stream->ReadVarint32(&length)) return false; + if (length == 0) { + *result = StringPiece(nullptr, 0); + return true; + } + const void* stream_alias; + int stream_size; + if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) { + return false; + } + if (static_cast<uint32>(stream_size) < length) return false; + *result = StringPiece(static_cast<const char*>(stream_alias), length); + stream->Skip(length); + return true; +} + +bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream, + parsed::FeatureMapEntry* feature_map_entry) { + DCHECK(stream != nullptr); + DCHECK(feature_map_entry != nullptr); + uint32 length; + if (!stream->ReadVarint32(&length)) return false; + auto limit = stream->PushLimit(length); + if (!stream->ExpectTag(kDelimitedTag(1))) return false; + if (!ParseString(stream, &feature_map_entry->first)) return false; + if (!stream->ExpectTag(kDelimitedTag(2))) return false; + StringPiece feature_string_piece; + if (!ParseString(stream, &feature_string_piece)) return false; + feature_map_entry->second = parsed::Feature(feature_string_piece); + if (!stream->ExpectAtEnd()) return false; + stream->PopLimit(limit); + return true; +} + +bool ParseFeatures(protobuf::io::CodedInputStream* stream, + parsed::Example* example) { + DCHECK(stream != nullptr); + DCHECK(example != nullptr); + uint32 length; + if (!stream->ReadVarint32(&length)) return false; + auto limit = stream->PushLimit(length); + while (!stream->ExpectAtEnd()) { + 
parsed::FeatureMapEntry feature_map_entry; + if (!stream->ExpectTag(kDelimitedTag(1))) return false; + if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false; + example->push_back(std::move(feature_map_entry)); + } + stream->PopLimit(limit); + return true; +} + +bool ParseExample(protobuf::io::CodedInputStream* stream, + parsed::Example* example) { + DCHECK(stream != nullptr); + DCHECK(example != nullptr); + if (stream->ExpectTag(kDelimitedTag(1))) { + if (!ParseFeatures(stream, example)) return false; + } + if (!stream->ExpectAtEnd()) return false; + return true; +} + +bool ParseExample(StringPiece serialized, parsed::Example* example) { + DCHECK(example != nullptr); + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(serialized.data()), serialized.size()); + EnableAliasing(&stream); + return ParseExample(&stream, example); +} + +} // namespace + +bool TestFastParse(const string& serialized, Example* example) { + DCHECK(example != nullptr); + parsed::Example parsed_example; + if (!ParseExample(serialized, &parsed_example)) return false; + auto& features = *example->mutable_features(); + for (parsed::FeatureMapEntry& entry : parsed_example) { + auto& value = (*features.mutable_feature())[entry.first.ToString()]; + DataType dtype; + if (!entry.second.ParseDataType(&dtype).ok()) return false; + switch (dtype) { + case DT_INVALID: + break; + case DT_STRING: { + std::vector<string> list; + if (!entry.second.ParseBytesList(&list)) return false; + auto* result_list = value.mutable_bytes_list(); + for (auto& bytes : list) { + result_list->add_value(std::move(bytes)); + } + break; + } + case DT_FLOAT: { + std::vector<float> list; + if (!entry.second.ParseFloatList(&list)) return false; + auto* result_list = value.mutable_float_list(); + for (float f : list) { + result_list->add_value(f); + } + break; + } + case DT_INT64: { + std::vector<int64> list; + if (!entry.second.ParseInt64List(&list)) return false; + auto* result_list = 
value.mutable_int64_list(); + for (int64 i : list) { + result_list->add_value(i); + } + break; + } + default: + CHECK(false) << "Should not happen."; + } + } + return true; +} + +// ----------------------------------------------------------------------------- + +namespace { + +using Config = FastParseExampleConfig; + +void ParallelFor(const std::function<void(size_t)>& f, size_t n, + thread::ThreadPool* thread_pool) { + DCHECK(thread_pool != nullptr); + if (n == 0) return; + BlockingCounter counter(n - 1); + for (size_t i = 1; i < n; ++i) { + thread_pool->Schedule([i, &f, &counter] { + f(i); + counter.DecrementCount(); + }); + } + f(0); + counter.Wait(); +} + +enum class Type { Sparse, Dense }; + +struct SparseBuffer { + // TODO(lew): Use InlinedVector. + // Features are in one of the 3 vectors below depending on config's dtype. + // Other 2 vectors remain empty. + std::vector<string> bytes_list; + std::vector<float> float_list; + std::vector<int64> int64_list; + + // Features of example i are elements with indices + // from example_end_indices[i-1] to example_end_indices[i]-1 on the + // appropriate xxxxx_list + std::vector<size_t> example_end_indices; +}; + +struct SeededHasher { + uint64 operator()(StringPiece s) const { + return Hash64(s.data(), s.size(), seed); + } + uint64 seed{0xDECAFCAFFE}; +}; + +Status FastParseSerializedExample( + const string& serialized_example, const string& example_name, + const size_t example_index, const Config& config, + const PresizedCuckooMap<std::pair<size_t, Type>>& config_index, + SeededHasher hasher, std::vector<Tensor>* output_dense, + std::vector<SparseBuffer>* output_sparse) { + DCHECK(output_dense != nullptr); + DCHECK(output_sparse != nullptr); + parsed::Example parsed_example; + if (!ParseExample(serialized_example, &parsed_example)) { + return errors::InvalidArgument("Could not parse example input, value: '", + serialized_example, "'"); + } + constexpr size_t kMax = std::numeric_limits<size_t>::max(); + 
std::vector<size_t> sparse_features_found(config.sparse.size(), kMax); + std::vector<size_t> dense_features_found(config.dense.size(), kMax); + + // Handle features present in the example. + for (parsed::FeatureMapEntry& name_and_feature : parsed_example) { + parsed::Feature& feature = name_and_feature.second; + std::pair<size_t, Type> d_and_type; + uint64 h = hasher(name_and_feature.first); + if (!config_index.Find(h, &d_and_type)) continue; + size_t d = d_and_type.first; + + auto parse_error = [&](StringPiece feature_name) { + return errors::InvalidArgument("Name: ", example_name, ", Key: ", + feature_name, ", Index: ", example_index, + ". Can't parse serialized Example."); + }; + + if (d_and_type.second == Type::Dense) { + DataType example_dtype; + TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype)); + if (example_dtype == DT_INVALID) continue; + + dense_features_found[d] = example_index; + if (example_dtype != config.dense[d].dtype) { + return errors::InvalidArgument( + "Name: ", example_name, ", Feature: ", config.dense[d].feature_name, + ". Data types don't match. ", "Data type: ", + DataTypeString(example_dtype), "Expected type: ", + DataTypeString(config.dense[d].dtype)); + } + const string& feature_name = config.dense[d].feature_name; + const TensorShape& shape = config.dense[d].shape; + Tensor& out = (*output_dense)[d]; + + const std::size_t num_elements = shape.num_elements(); + const std::size_t offset = example_index * num_elements; + + auto shape_error = [&](size_t size, StringPiece type_str) { + return errors::InvalidArgument( + "Name: ", example_name, ", Key: ", feature_name, ", Index: ", + example_index, ". Number of ", type_str, + " values != expected. 
" + "Values size: ", + size, " but output shape: ", shape.DebugString()); + }; + + switch (config.dense[d].dtype) { + case DT_INT64: { + std::vector<int64> list; + if (!feature.ParseInt64List(&list)) return parse_error(feature_name); + if (list.size() != num_elements) { + return shape_error(list.size(), "int64"); + } + auto out_p = out.flat<int64>().data() + offset; + std::copy_n(list.begin(), list.size(), out_p); + break; + } + case DT_FLOAT: { + std::vector<float> list; + if (!feature.ParseFloatList(&list)) return parse_error(feature_name); + if (list.size() != num_elements) { + return shape_error(list.size(), "float"); + } + auto out_p = out.flat<float>().data() + offset; + std::copy_n(list.begin(), list.size(), out_p); + break; + } + case DT_STRING: { + std::vector<string> list; + if (!feature.ParseBytesList(&list)) return parse_error(feature_name); + if (list.size() != num_elements) { + return shape_error(list.size(), "bytes"); + } + auto out_p = out.flat<string>().data() + offset; + for (size_t i = 0; i < list.size(); ++i) { + out_p[i] = std::move(list[i]); + } + break; + } + default: + CHECK(false) << "Should not happen."; + } + } else { + // Handle sparse features. + sparse_features_found[d] = example_index; + const string& feature_name = config.sparse[d].feature_name; + SparseBuffer& out = (*output_sparse)[d]; + DataType example_dtype; + TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype)); + if (example_dtype != DT_INVALID && + example_dtype != config.sparse[d].dtype) { + return errors::InvalidArgument( + "Name: ", example_name, ", Feature: ", + config.sparse[d].feature_name, ". Data types don't match. 
", + "Expected type: ", DataTypeString(config.sparse[d].dtype)); + } + + switch (config.sparse[d].dtype) { + case DT_INT64: { + if (example_dtype != DT_INVALID) { + if (!feature.ParseInt64List(&out.int64_list)) { + return parse_error(feature_name); + } + } + out.example_end_indices.push_back(out.int64_list.size()); + break; + } + case DT_FLOAT: { + if (example_dtype != DT_INVALID) { + if (!feature.ParseFloatList(&out.float_list)) { + return parse_error(feature_name); + } + } + out.example_end_indices.push_back(out.float_list.size()); + break; + } + case DT_STRING: { + if (example_dtype != DT_INVALID) { + if (!feature.ParseBytesList(&out.bytes_list)) { + return parse_error(feature_name); + } + } + out.example_end_indices.push_back(out.bytes_list.size()); + break; + } + default: + CHECK(false) << "Should not happen."; + } + } + } + + // Handle missing dense features. + for (size_t d = 0; d < config.dense.size(); ++d) { + if (dense_features_found[d] == example_index) continue; + if (config.dense[d].default_value.NumElements() == 0) { + return errors::InvalidArgument("Name: ", example_name, ", Feature: ", + config.dense[d].feature_name, + " is required but could not be found."); + } + + const Tensor& in = config.dense[d].default_value; + Tensor& out = (*output_dense)[d]; + const std::size_t num_elements = in.shape().num_elements(); + const std::size_t offset = example_index * num_elements; + + switch (config.dense[d].dtype) { + case DT_INT64: { + std::copy_n(in.flat<int64>().data(), num_elements, + out.flat<int64>().data() + offset); + break; + } + case DT_FLOAT: { + std::copy_n(in.flat<float>().data(), num_elements, + out.flat<float>().data() + offset); + break; + } + case DT_STRING: { + std::copy_n(in.flat<string>().data(), num_elements, + out.flat<string>().data() + offset); + break; + } + default: + CHECK(false) << "Should not happen."; + } + } + + // Handle missing sparse features. 
+  for (size_t d = 0; d < config.sparse.size(); ++d) {
+    if (sparse_features_found[d] == example_index) continue;
+    SparseBuffer& out = (*output_sparse)[d];
+    size_t prev_example_end_index =
+        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
+    out.example_end_indices.push_back(prev_example_end_index);
+  }
+
+  return Status::OK();
+}
+
+// Returns OK for the dtypes the parser supports (DT_INT64, DT_FLOAT,
+// DT_STRING) and InvalidArgument for anything else.
+Status CheckConfigDataType(DataType dtype) {
+  switch (dtype) {
+    case DT_INT64:
+    case DT_FLOAT:
+    case DT_STRING:
+      return Status::OK();
+    default:
+      return errors::InvalidArgument("Invalid config dtype: ",
+                                     DataTypeString(dtype));
+  }
+}
+
+}  // namespace
+
+// Parses `serialized` in parallel mini-batches according to `config` and
+// appends the resulting dense and sparse tensors to `result`.
+// `example_names` is forwarded to the per-example parser and may be empty, in
+// which case "<unknown>" is passed instead.
+// Returns InvalidArgument for an unsupported config dtype, Internal if no
+// collision-free feature-name index could be built, or the status of the
+// first failing mini-batch; OK otherwise.
+Status FastParseExample(const Config& config,
+                        gtl::ArraySlice<string> serialized,
+                        gtl::ArraySlice<string> example_names,
+                        thread::ThreadPool* thread_pool, Result* result) {
+  DCHECK(thread_pool != nullptr);
+  DCHECK(result != nullptr);
+  // Check config so we can safely CHECK(false) in switches on config.*.dtype
+  for (auto& c : config.sparse) {
+    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+  }
+  for (auto& c : config.dense) {
+    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+  }
+
+  size_t config_size = config.dense.size() + config.sparse.size();
+  SeededHasher hasher;
+  // Build config index.
+  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
+  bool ok = true;
+  for (size_t i = 0; i < 1000; ++i) {
+    // Every attempt starts from a clean slate. Without this reset a single
+    // collision would make all later attempts (with a fresh seed and an empty
+    // index) look like failures as well.
+    ok = true;
+    for (size_t d = 0; d < config.dense.size(); ++d) {
+      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
+                                      {d, Type::Dense});
+    }
+    for (size_t d = 0; d < config.sparse.size(); ++d) {
+      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
+                                      {d, Type::Sparse});
+    }
+    if (ok) break;
+    LOG(WARNING) << "Collision found. This should happen only if you have "
+                    "around 2^32 entries in your config.";
+    hasher.seed++;
+    config_index.Clear(config_size);
+  }
+  if (!ok) {
+    return errors::Internal(
+        "Could not avoid collision. This should not happen.");
+  }
+
+  // Allocate dense output (sparse have to be buffered).
+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    TensorShape out_shape;
+    out_shape.AddDim(serialized.size());
+    for (const int64 dim : config.dense[d].shape.dim_sizes()) {
+      out_shape.AddDim(dim);
+    }
+    result->dense_values.emplace_back(config.dense[d].dtype, out_shape);
+  }
+
+  // This parameter affects performance in a big and data-dependent way.
+  const size_t kMiniBatchSizeBytes = 100000;
+
+  // Split examples into mini-batches for parallel processing. A mini-batch is
+  // closed as soon as its accumulated size exceeds kMiniBatchSizeBytes.
+  auto first_example_of_minibatch = [&] {
+    // Named differently from the `result` parameter to avoid shadowing it.
+    std::vector<size_t> first_examples;
+    size_t minibatch_bytes = 0;
+    for (size_t i = 0; i < serialized.size(); i++) {
+      if (minibatch_bytes == 0) {  // start minibatch
+        first_examples.push_back(i);
+      }
+      // The +1 keeps the counter non-zero for empty strings, so an empty
+      // example never opens a second mini-batch.
+      minibatch_bytes += serialized[i].size() + 1;
+      if (minibatch_bytes > kMiniBatchSizeBytes) {
+        minibatch_bytes = 0;
+      }
+    }
+    return first_examples;
+  }();
+
+  size_t num_minibatches = first_example_of_minibatch.size();
+
+  // Do minibatches in parallel. Each mini-batch writes only to its own slot
+  // of sparse_buffers and status_of_minibatch.
+  std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
+  std::vector<Status> status_of_minibatch(num_minibatches);
+
+  auto ProcessMiniBatch = [&](size_t minibatch) {
+    sparse_buffers[minibatch].resize(config.sparse.size());
+    size_t start = first_example_of_minibatch[minibatch];
+    size_t end = minibatch + 1 < num_minibatches
+                     ? first_example_of_minibatch[minibatch + 1]
+                     : serialized.size();
+    for (size_t e = start; e < end; ++e) {
+      status_of_minibatch[minibatch] = FastParseSerializedExample(
+          serialized[e],
+          (example_names.size() > 0 ? example_names[e] : "<unknown>"), e,
+          config, config_index, hasher, &result->dense_values,
+          &sparse_buffers[minibatch]);
+      if (!status_of_minibatch[minibatch].ok()) break;
+    }
+  };
+
+  ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);
+
+  for (Status& status : status_of_minibatch) {
+    TF_RETURN_IF_ERROR(status);
+  }
+
+  // Merge SparseBuffers from all minibatches for every config.sparse.
+  auto MergeMinibatches = [&](size_t d) {
+    // Loop over minibatches
+    size_t total_num_features = 0;
+    size_t max_num_features = 0;
+    for (auto& sparse_values_tmp : sparse_buffers) {
+      std::vector<size_t>& end_indices =
+          sparse_values_tmp[d].example_end_indices;
+      total_num_features += end_indices.back();
+      max_num_features = std::max(max_num_features, end_indices[0]);
+      for (size_t i = 1; i < end_indices.size(); ++i) {
+        size_t example_size = end_indices[i] - end_indices[i - 1];
+        max_num_features = std::max(max_num_features, example_size);
+      }
+    }
+
+    TensorShape indices_shape;
+    indices_shape.AddDim(total_num_features);
+    indices_shape.AddDim(2);
+    result->sparse_indices.emplace_back(DT_INT64, indices_shape);
+    Tensor* indices = &result->sparse_indices.back();
+
+    TensorShape values_shape;
+    values_shape.AddDim(total_num_features);
+    result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
+    Tensor* values = &result->sparse_values.back();
+
+    result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
+    auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
+    shapes_shape_t(0) = serialized.size();
+    shapes_shape_t(1) = max_num_features;
+
+    size_t offset = 0;
+    for (size_t i = 0; i < sparse_buffers.size(); ++i) {
+      // Non-const on purpose: the string values below are moved out of the
+      // buffer, which is not read again after this merge.
+      SparseBuffer& buffer = sparse_buffers[i][d];
+
+      // Update indices.
+      // The write position is derived from the base pointer (two int64 per
+      // row) instead of indexing element (offset, 0), which does not exist
+      // when there are no features left to write.
+      int64* ix_p = indices->matrix<int64>().data() + 2 * offset;
+      size_t delta = 0;
+      size_t example_index = first_example_of_minibatch[i];
+      for (size_t example_end_index : buffer.example_end_indices) {
+        size_t feature_index = 0;
+        for (; delta < example_end_index; ++delta) {
+          // Column 0: example index
+          *ix_p = example_index;
+          // Column 1: index of the feature within the example
+          *(ix_p + 1) = feature_index;
+          ix_p += 2;
+          ++feature_index;
+        }
+        ++example_index;
+      }
+
+      // Copy values over.
+      switch (config.sparse[d].dtype) {
+        case DT_INT64: {
+          std::copy(buffer.int64_list.begin(), buffer.int64_list.end(),
+                    values->flat<int64>().data() + offset);
+          break;
+        }
+        case DT_FLOAT: {
+          std::copy(buffer.float_list.begin(), buffer.float_list.end(),
+                    values->flat<float>().data() + offset);
+          break;
+        }
+        case DT_STRING: {
+          std::move(buffer.bytes_list.begin(), buffer.bytes_list.end(),
+                    values->flat<string>().data() + offset);
+          break;
+        }
+        default:
+          CHECK(false) << "Should not happen.";
+      }
+
+      offset += delta;
+    }
+  };
+
+  for (size_t d = 0; d < config.sparse.size(); ++d) {
+    MergeMinibatches(d);
+  }
+
+  return Status::OK();
+}
+
+}  // namespace example
+}  // namespace tensorflow
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
new file mode 100644
index 0000000000..6ed9d57838
--- /dev/null
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -0,0 +1,88 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { +namespace example { + +// FastParseExampleConfig defines how to parse features in Example. +// Each sub-config is responsible for one feature identified with feature_name. +// FastParseExampleConfig can't have two sub-configs with the same feature_name. +// dtype identifies the type of output vector and the kind of Feature expected +// in Example. +struct FastParseExampleConfig { + struct Dense { + string feature_name; + DataType dtype; + // These 2 fields correspond exactly to dense_shapes and dense_defaults in + // ParseExample op. + // Documentation is available in: tensorflow/core/ops/parsing_ops.cc + TensorShape shape; + Tensor default_value; + }; + + struct Sparse { + string feature_name; + DataType dtype; + }; + + std::vector<Dense> dense; + std::vector<Sparse> sparse; +}; + +// This is exactly the output of TF's ParseExample Op. 
+// Documentation is available in: tensorflow/core/ops/parsing_ops.cc +struct Result { + std::vector<Tensor> sparse_indices; + std::vector<Tensor> sparse_values; + std::vector<Tensor> sparse_shapes; + std::vector<Tensor> dense_values; +}; + +// Parses a batch of serialized Example protos and converts them into result +// according to given config. +// Given example names have to either be empty or the same size as serialized. +// example_names are used only for error messages. +Status FastParseExample(const FastParseExampleConfig& config, + gtl::ArraySlice<string> serialized, + gtl::ArraySlice<string> example_names, + thread::ThreadPool* thread_pool, Result* result); + +// This function parses serialized Example and populates given example. +// It uses the same specialized parser as FastParseExample which is efficient. +// But then constructs Example which is relatively slow. +// It is exported here as a convenient API to test parser part separately. +bool TestFastParse(const string& serialized, Example* example); + +} // namespace example +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc new file mode 100644 index 0000000000..6d3b548851 --- /dev/null +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace example {
+namespace {
+
+constexpr char kDenseInt64Key[] = "dense_int64";
+constexpr char kDenseFloatKey[] = "dense_float";
+constexpr char kDenseStringKey[] = "dense_string";
+
+constexpr char kSparseInt64Key[] = "sparse_int64";
+constexpr char kSparseFloatKey[] = "sparse_float";
+constexpr char kSparseStringKey[] = "sparse_string";
+
+// Returns `serialized` as a double-quoted string in which every byte is
+// written as a "\x"-prefixed, zero-padded hex escape, so that a failing input
+// can be logged and pasted back into a test.
+// Takes the input by const reference; the function only reads it.
+string SerializedToReadable(const string& serialized) {
+  string result;
+  result += '"';
+  for (char c : serialized) {
+    result += strings::StrCat("\\x", strings::Hex(c, strings::ZERO_PAD_2));
+  }
+  result += '"';
+  return result;
+}
+
+// Returns the wire-format serialization of `example`.
+string Serialize(const Example& example) {
+  string serialized;
+  example.SerializeToString(&serialized);
+  return serialized;
+}
+
+// Parses `serialized` with both the regular proto parser and TestFastParse and
+// expects the two resulting Examples to have the same DebugString. On a
+// mismatch the offending input is logged in readable form.
+void TestCorrectness(const string& serialized) {
+  Example example;
+  Example fast_example;
+  EXPECT_TRUE(example.ParseFromString(serialized));
+  EXPECT_TRUE(TestFastParse(serialized, &fast_example));
+  EXPECT_EQ(example.DebugString(), fast_example.DebugString());
+  if (example.DebugString() != fast_example.DebugString()) {
+    LOG(ERROR) << "Bad serialized: " << SerializedToReadable(serialized);
+  }
+}
+
+// Fast parsing does not differentiate between EmptyExample and EmptyFeatures
+// TEST(FastParse, EmptyExample) {
+//   Example example;
+//   TestCorrectness(Serialize(example));
+// }
+
+TEST(FastParse, NonPacked) {
+  TestCorrectness(
+      "\x0a\x0e\x0a\x0c\x0a\x03\x61\x67\x65\x12\x05\x1a\x03\x0a\x01\x0d");
+}
+
+TEST(FastParse, Packed) {
+  TestCorrectness(
+      "\x0a\x0d\x0a\x0b\x0a\x03\x61\x67\x65\x12\x04\x1a\x02\x08\x0d");
+}
+
+TEST(FastParse, EmptyFeatures) {
+  Example example;
+  example.mutable_features();
+  TestCorrectness(Serialize(example));
+}
+
+// Converts `json` to a binary tensorflow.Example using the generated
+// descriptor pool and runs TestCorrectness on the result.
+void TestCorrectnessJson(const string& json) {
+  auto resolver = protobuf::util::NewTypeResolverForDescriptorPool(
+      "type.googleapis.com", protobuf::DescriptorPool::generated_pool());
+  string serialized;
+  auto s = protobuf::util::JsonToBinaryString(
+      resolver, "type.googleapis.com/tensorflow.Example", json, &serialized);
+  EXPECT_TRUE(s.ok()) << s;
+  delete resolver;
+  TestCorrectness(serialized);
+}
+
+TEST(FastParse, JsonUnivalent) {
+  TestCorrectnessJson(
+      "{'features': {"
+      " 'feature': {'age': {'int64_list': {'value': [0]} }}, "
+      " 'feature': {'flo': {'float_list': {'value': [1.1]} }}, "
+      " 'feature': {'byt': {'bytes_list': {'value': ['WW8='] }}}"
+      "}}");
+}
+
+TEST(FastParse, JsonMultivalent) {
+  TestCorrectnessJson(
+      "{'features': {"
+      " 'feature': {'age': {'int64_list': {'value': [0, 13, 23]} }}, "
+      " 'feature': {'flo': {'float_list': {'value': [1.1, 1.2, 1.3]} }}, "
+      " 'feature': {'byt': {'bytes_list': {'value': ['WW8=', 'WW8K'] }}}"
+      "}}");
+}
+
+TEST(FastParse, SingleInt64) {
+  Example example;
+  (*example.mutable_features()->mutable_feature())["age"]
+      .mutable_int64_list()
+      ->add_value(13);
+  TestCorrectness(Serialize(example));
+}
+
+TEST(FastParse, SomeFeatures) {
+  Example example;
+
+  (*example.mutable_features()->mutable_feature())[""];
+
+  (*example.mutable_features()->mutable_feature())["empty_bytes_list"]
+      .mutable_bytes_list();
+  (*example.mutable_features()->mutable_feature())["empty_float_list"]
+      .mutable_float_list();
+  (*example.mutable_features()->mutable_feature())["empty_int64_list"]
+      .mutable_int64_list();
+
+  BytesList* bytes_list =
+      (*example.mutable_features()->mutable_feature())["bytes_list"]
+          .mutable_bytes_list();
+  bytes_list->add_value("bytes1");
+  bytes_list->add_value("bytes2");
+
+  FloatList* float_list =
+      (*example.mutable_features()->mutable_feature())["float_list"]
+          .mutable_float_list();
+  float_list->add_value(1.0);
+  float_list->add_value(2.0);
+
+  Int64List* int64_list =
+      (*example.mutable_features()->mutable_feature())["int64_list"]
+          .mutable_int64_list();
+  int64_list->add_value(3);
+  int64_list->add_value(270);
+  int64_list->add_value(86942);
+
+  TestCorrectness(Serialize(example));
+}
+
+// Builds and serializes an Example with 200 bytes, 200 float and 200 int64
+// features, one value each. Feature names are kFeatureNameLength copies of a
+// type letter ('b', 'f' or 'i') followed by the feature's index.
+string MakeSerializedExample() {
+  Example example;
+  const int kFeatureNameLength = 10;
+  const int kFeatureValueLength = 20;
+  const int kBytesFeatureCount = 200;
+  const int kFloatFeatureCount = 200;
+  const int kInt64FeatureCount = 200;
+  auto& fmap = *example.mutable_features()->mutable_feature();
+  // Note the argument order of the fill constructor: string(count, ch).
+  // Writing string('b', 10) would yield 98 newline characters instead.
+  for (int i = 0; i < kBytesFeatureCount; ++i) {
+    fmap[strings::StrCat(string(kFeatureNameLength, 'b'), i)]
+        .mutable_bytes_list()
+        ->add_value(string(kFeatureValueLength, 'v'));
+  }
+  for (int i = 0; i < kFloatFeatureCount; ++i) {
+    fmap[strings::StrCat(string(kFeatureNameLength, 'f'), i)]
+        .mutable_float_list()
+        ->add_value(123123123.123);
+  }
+  for (int i = 0; i < kInt64FeatureCount; ++i) {
+    fmap[strings::StrCat(string(kFeatureNameLength, 'i'), i)]
+        .mutable_int64_list()
+        ->add_value(10 * i);
+  }
+  string serialized;
+  example.SerializeToString(&serialized);
+  return serialized;
+}
+
+}  // namespace
+
+}  // namespace example
+}  // namespace tensorflow
diff --git a/tensorflow/core/util/presized_cuckoo_map.h b/tensorflow/core/util/presized_cuckoo_map.h
index b488d32e03..5244ab693a 100644
--- a/tensorflow/core/util/presized_cuckoo_map.h
+++ b/tensorflow/core/util/presized_cuckoo_map.h
@@ -50,7 +50,10 @@ class PresizedCuckooMap {
   // The key type is fixed as a pre-hashed key for this specialized use.
typedef uint64 key_type; - explicit PresizedCuckooMap(uint64 num_entries) : cpq_(new CuckooPathQueue) { + explicit PresizedCuckooMap(uint64 num_entries) { Clear(num_entries); } + + void Clear(uint64 num_entries) { + cpq_.reset(new CuckooPathQueue()); double n(num_entries); n /= kLoadFactor; num_buckets_ = (static_cast<uint64>(n) / kSlotsPerBucket); @@ -62,6 +65,7 @@ class PresizedCuckooMap { for (int i = 0; i < kSlotsPerBucket; i++) { empty_bucket.keys[i] = kUnusedSlot; } + buckets_.clear(); buckets_.resize(num_buckets_, empty_bucket); #if !defined(__GCUDACC__) && !defined(__GCUDACC_HOST__) buckets_divisor_ = Eigen::internal::TensorIntDivisor<uint64>(num_buckets_); @@ -317,7 +321,7 @@ class PresizedCuckooMap { std::vector<Bucket> buckets_; Eigen::internal::TensorIntDivisor<uint64> buckets_divisor_; // for fast mod - const std::unique_ptr<CuckooPathQueue> cpq_; + std::unique_ptr<CuckooPathQueue> cpq_; CuckooPathEntry visited_[kVisitedListSize]; TF_DISALLOW_COPY_AND_ASSIGN(PresizedCuckooMap); diff --git a/tensorflow/core/util/presized_cuckoo_map_test.cc b/tensorflow/core/util/presized_cuckoo_map_test.cc index 64ad315518..fe8e5dcfbd 100644 --- a/tensorflow/core/util/presized_cuckoo_map_test.cc +++ b/tensorflow/core/util/presized_cuckoo_map_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/core/util/presized_cuckoo_map.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/util/presized_cuckoo_map.h" namespace tensorflow { namespace { @@ -64,6 +64,22 @@ TEST(PresizedCuckooMapTest, ZeroSizeMap) { } } +TEST(PresizedCuckooMapTest, RepeatedClear) { + PresizedCuckooMap<int> pscm(2); + int out; + for (int i = 0; i < 100; ++i) { + pscm.InsertUnique(0, 0); + pscm.InsertUnique(1, 1); + EXPECT_TRUE(pscm.Find(0, &out)); + EXPECT_EQ(0, out); + EXPECT_TRUE(pscm.Find(1, &out)); + EXPECT_EQ(1, out); + pscm.Clear(2); + EXPECT_FALSE(pscm.Find(0, &out)); + EXPECT_FALSE(pscm.Find(1, &out)); + } +} + void RunFill(int64 table_size) { PresizedCuckooMap<int> pscm(table_size); for (int64 i = 0; i < table_size; i++) { |