author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-19 10:44:37 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-08-19 11:48:07 -0700
commit    0b9f0f53ddbf693bb30afb211a6d514a1fce1c22 (patch)
tree      fcc68dbf0d05ba6e093b58274645aaaa9d9b7c39 /tensorflow
parent    859e47fd8ebcdd7fb1411fd0090d0e95a801e7cb (diff)
Implement fast ParseExample.
Change: 130775324
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/BUILD                                    |   2
-rw-r--r--  tensorflow/core/kernels/example_parsing_ops.cc           | 151
-rw-r--r--  tensorflow/core/kernels/example_parsing_ops_test.cc      |  31
-rw-r--r--  tensorflow/core/util/example_proto_fast_parsing.cc       | 761
-rw-r--r--  tensorflow/core/util/example_proto_fast_parsing.h        |  88
-rw-r--r--  tensorflow/core/util/example_proto_fast_parsing_test.cc  | 184
-rw-r--r--  tensorflow/core/util/presized_cuckoo_map.h               |   8
-rw-r--r--  tensorflow/core/util/presized_cuckoo_map_test.cc         |  18
8 files changed, 1102 insertions, 141 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 224badc7d3..60c5976b9b 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -266,6 +266,7 @@ tf_cuda_library(
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
"util/events_writer.h",
+ "util/example_proto_fast_parsing.h",
"util/example_proto_helper.h",
"util/guarded_philox_random.h",
"util/memmapped_file_system.h",
@@ -1459,6 +1460,7 @@ tf_cc_tests(
"util/command_line_flags_test.cc",
"util/device_name_utils_test.cc",
"util/events_writer_test.cc",
+ "util/example_proto_fast_parsing_test.cc",
"util/example_proto_helper_test.cc",
"util/memmapped_file_system_test.cc",
"util/presized_cuckoo_map_test.cc",
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index a7091645fa..6338638a3b 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
#include "tensorflow/core/util/example_proto_helper.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#include "tensorflow/core/util/work_sharder.h"
@@ -67,8 +69,7 @@ class ExampleParserOp : public OpKernel {
sparse_keys_t[di] = sparse_keys[di].scalar<string>()();
}
- bool has_names = (names->NumElements() > 0);
- if (has_names) {
+ if (names->NumElements() > 0) {
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(names->shape()),
errors::InvalidArgument("Expected names to be a vector, got shape: ",
@@ -79,7 +80,6 @@ class ExampleParserOp : public OpKernel {
"Expected len(names) == len(serialized), but got: ",
names->NumElements(), " vs. ", serialized->NumElements()));
}
- auto names_t = names->flat<string>();
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
errors::InvalidArgument(
@@ -106,136 +106,43 @@ class ExampleParserOp : public OpKernel {
}
}
- auto serialized_t = serialized->vec<string>();
+ example::Result result;
+
+ example::FastParseExampleConfig config;
+ for (int d = 0; d < attrs_.num_dense; ++d) {
+ config.dense.push_back({dense_keys_t[d], attrs_.dense_types[d],
+ attrs_.dense_shapes[d], dense_defaults[d]});
+ }
+ for (int d = 0; d < attrs_.num_sparse; ++d) {
+ config.sparse.push_back({sparse_keys_t[d], attrs_.sparse_types[d]});
+ }
+
+ auto serialized_t = serialized->flat<string>();
+ auto names_t = names->flat<string>();
+ gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
+ gtl::ArraySlice<string> names_slice(names_t.data(), names_t.size());
- const int64 batch_size = serialized_t.size();
+ OP_REQUIRES_OK(
+ ctx,
+ FastParseExample(
+ config, slice, names_slice,
+ ctx->device()->tensorflow_cpu_worker_threads()->workers, &result));
+ OpOutputList dense_values;
OpOutputList sparse_indices;
OpOutputList sparse_values;
OpOutputList sparse_shapes;
- OpOutputList dense_values;
-
+ OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
- OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
-
- // Setup Dense features and the output_dense_values Tensor* vector.
- std::vector<FixedLenFeature> fixed_len_features(attrs_.num_dense);
- std::vector<Tensor*> output_dense_values(attrs_.num_dense);
-
for (int d = 0; d < attrs_.num_dense; ++d) {
- // Preallocate dense_values, since we know their sizes
- TensorShape out_shape;
- out_shape.AddDim(batch_size);
- for (const int64 dim : attrs_.dense_shapes[d].dim_sizes()) {
- out_shape.AddDim(dim);
- }
- Tensor* out = nullptr;
- dense_values.allocate(d, out_shape, &out);
-
- FixedLenFeature config;
- config.key = dense_keys_t[d];
- config.dtype = attrs_.dense_types[d];
- config.shape = attrs_.dense_shapes[d];
- config.default_value = dense_defaults[d];
- fixed_len_features[d] = config;
- output_dense_values[d] = dense_values[d];
+ dense_values.set(d, result.dense_values[d]);
}
-
- // sparse_values_tmp will be attrs_.num_sparse size map of batch_size length
- // tensor vector's, containing the sparse values from the input layer.
- // After these are all stored, we can allocate properly sized outputs
- // and copy data over. Doing it this way saves us the trouble of either
- // performing deserialization twice, or alternatively storing all copies of
- // the full Example protos.
- std::vector<std::vector<Tensor>> sparse_values_tmp(
- attrs_.num_sparse, std::vector<Tensor>(batch_size));
-
- // Setup Sparse features.
- std::vector<VarLenFeature> var_len_features(attrs_.num_sparse);
for (int d = 0; d < attrs_.num_sparse; ++d) {
- VarLenFeature config;
- config.key = sparse_keys_t[d];
- config.dtype = attrs_.sparse_types[d];
- var_len_features[d] = config;
- }
-
- auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
-
- // Estimate the cost of parsing each batch element.
- int64 work_unit_size = 1000 + 100 * attrs_.num_sparse;
- for (int d = 0; d < attrs_.num_dense; ++d) {
- work_unit_size += 100 + attrs_.dense_shapes[d].num_elements();
- }
-
- mutex mu;
-
- auto DoWork = [&ctx, &mu, &serialized_t, has_names, &names_t,
- &fixed_len_features, &var_len_features, &output_dense_values,
- &sparse_values_tmp](int64 start, int64 limit) {
- // Processing each Example in the batch starts here.
- for (std::size_t b = static_cast<size_t>(start);
- b < static_cast<size_t>(limit); ++b) {
- // Benchmarks indicate that a tight Arena+Example is most performant.
- protobuf::Arena arena;
- // ex is owned by the arena.
- Example* ex = protobuf::Arena::CreateMessage<Example>(&arena);
- bool parse_success = ParseProtoUnlimited(ex, serialized_t(b));
- if (!TF_PREDICT_TRUE(parse_success)) {
- mutex_lock l(mu);
- ctx->CtxFailure(errors::InvalidArgument(
- "Could not parse example input, value: '", serialized_t(b), "'"));
- return;
- }
- const string& example_name = (has_names) ? names_t(b) : "<unknown>";
- Status s = SingleExampleProtoToTensors(
- *ex, example_name, b, fixed_len_features, var_len_features,
- &output_dense_values, &sparse_values_tmp);
- if (!TF_PREDICT_TRUE(s.ok())) {
- mutex_lock l(mu);
- ctx->CtxFailureWithWarning(s);
- }
- }
- };
-
- Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
- work_unit_size, DoWork);
-
- if (!TF_PREDICT_TRUE(ctx->status().ok())) {
- return;
- }
-
- // Copy from sparse_values_tmp into final resting Tensors
- // -------------------------
- for (int d = 0; d < attrs_.num_sparse; ++d) {
- const VarLenFeature& feature_config = var_len_features[d];
- const std::vector<Tensor>& sparse_values_tmp_tensors =
- sparse_values_tmp[d];
- VarLenFeatureBatchShapes sparse_tensor_batch_shapes;
- GetSparseTensorShapes(feature_config, sparse_values_tmp_tensors,
- batch_size, &sparse_tensor_batch_shapes);
-
- Tensor* sp_indices_d = nullptr;
- Tensor* sp_values_d = nullptr;
- Tensor* sp_shape_d = nullptr;
-
- sparse_indices.allocate(d, sparse_tensor_batch_shapes.indices_shape,
- &sp_indices_d);
- sparse_values.allocate(d, sparse_tensor_batch_shapes.values_shape,
- &sp_values_d);
- sparse_shapes.allocate(d, TensorShape({2}), &sp_shape_d);
-
- auto shape_t = sp_shape_d->vec<int64>();
- shape_t(0) = batch_size;
- shape_t(1) = sparse_tensor_batch_shapes.max_num_features;
-
- int64 offset = 0;
- for (int b = 0; b < batch_size; ++b) {
- const int64 num_elements = CopyIntoSparseTensor(
- sparse_values_tmp_tensors[b], b, offset, sp_indices_d, sp_values_d);
- offset += num_elements;
- }
+ sparse_indices.set(d, result.sparse_indices[d]);
+ sparse_values.set(d, result.sparse_values[d]);
+ sparse_shapes.set(d, result.sparse_shapes[d]);
}
}
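
With this change the kernel above becomes a thin adapter: it packs its attrs into an example::FastParseExampleConfig, hands the batch to FastParseExample together with the device's worker thread pool, and forwards the tensors in the returned Result straight to its output lists. The same entry point can also be driven outside an OpKernel. The following is a minimal sketch, not part of this commit; the feature names, shapes, and thread-pool setup are illustrative, and the types come from the new example_proto_fast_parsing.h introduced later in this diff.

    #include <vector>

    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/util/example_proto_fast_parsing.h"

    namespace tensorflow {

    Status ParseBatch(const std::vector<string>& serialized_examples) {
      example::FastParseExampleConfig config;
      // A dense int64 feature of shape [1]; an empty default tensor means the
      // feature is required in every Example.
      config.dense.push_back({"age", DT_INT64, TensorShape({1}), Tensor()});
      // A variable-length string feature, returned as a sparse tensor.
      config.sparse.push_back({"tags", DT_STRING});

      thread::ThreadPool pool(Env::Default(), "fast_parse_example", 4);
      example::Result result;
      TF_RETURN_IF_ERROR(example::FastParseExample(
          config, serialized_examples, /*example_names=*/{}, &pool, &result));

      // result.dense_values[0] has shape [batch_size, 1]; the "tags" feature
      // comes back as result.sparse_indices[0] / sparse_values[0] /
      // sparse_shapes[0].
      return Status::OK();
    }

    }  // namespace tensorflow
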
diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc
index e58cecff14..187d72685e 100644
--- a/tensorflow/core/kernels/example_parsing_ops_test.cc
+++ b/tensorflow/core/kernels/example_parsing_ops_test.cc
@@ -161,7 +161,7 @@ static Graph* ParseExample(int batch_size, int num_keys) {
// Benchmark settings (Sparse, Dense) X (Bytes, Int64, Float)
typedef BenchmarkOptions<ExampleStore<BytesFiller>, false> SparseString;
typedef BenchmarkOptions<ExampleStore<BytesFiller>, true> DenseString;
-typedef BenchmarkOptions<ExampleStore<Int64Filler>, false> SparseIn64;
+typedef BenchmarkOptions<ExampleStore<Int64Filler>, false> SparseInt64;
typedef BenchmarkOptions<ExampleStore<Int64Filler>, true> DenseInt64;
typedef BenchmarkOptions<ExampleStore<FloatFiller>, false> SparseFloat;
typedef BenchmarkOptions<ExampleStore<FloatFiller>, true> DenseFloat;
@@ -176,20 +176,19 @@ typedef BenchmarkOptions<ExampleStore<FloatFiller>, true> DenseFloat;
} \
BENCHMARK(BM_ParseExample##_##TYPE##_##B##_##K);
-#define BM_AllParseExample(B, K) \
- BM_ParseExample(SparseString, B, K); \
- BM_ParseExample(DenseString, B, K); \
- BM_ParseExample(SparseIn64, B, K); \
- BM_ParseExample(DenseInt64, B, K); \
- BM_ParseExample(SparseFloat, B, K); \
- BM_ParseExample(DenseFloat, B, K);
-
-BM_AllParseExample(128, 10);
-BM_AllParseExample(128, 100);
-BM_AllParseExample(128, 1000);
-
-BM_AllParseExample(512, 10);
-BM_AllParseExample(512, 100);
-BM_AllParseExample(512, 1000);
+#define BM_AllParseExample(Type) \
+ BM_ParseExample(Type, 128, 10); \
+ BM_ParseExample(Type, 512, 10); \
+ BM_ParseExample(Type, 128, 100); \
+ BM_ParseExample(Type, 512, 100); \
+ BM_ParseExample(Type, 128, 1000); \
+ BM_ParseExample(Type, 512, 1000);
+
+BM_AllParseExample(SparseString);
+BM_AllParseExample(DenseString);
+BM_AllParseExample(SparseInt64);
+BM_AllParseExample(DenseInt64);
+BM_AllParseExample(SparseFloat);
+BM_AllParseExample(DenseFloat);
} // end namespace tensorflow
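
The benchmark macro is reorganized so that the fixed argument is the filler type rather than the (batch, keys) pair; BM_ParseExample(TYPE, B, K) still pastes its arguments into the benchmark name, so one invocation of the new macro registers the whole size grid for a single type. Schematically (an informal expansion, not code from this commit):

    // BM_AllParseExample(SparseInt64) now expands, roughly, to:
    BENCHMARK(BM_ParseExample_SparseInt64_128_10);
    BENCHMARK(BM_ParseExample_SparseInt64_512_10);
    BENCHMARK(BM_ParseExample_SparseInt64_128_100);
    BENCHMARK(BM_ParseExample_SparseInt64_512_100);
    BENCHMARK(BM_ParseExample_SparseInt64_128_1000);
    BENCHMARK(BM_ParseExample_SparseInt64_512_1000);
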
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
new file mode 100644
index 0000000000..fee6b8885f
--- /dev/null
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -0,0 +1,761 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+
+#include <vector>
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb_text.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/presized_cuckoo_map.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+namespace example {
+
+namespace {
+template <typename A>
+auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
+ a->EnableAliasing(true);
+}
+
+template <typename A>
+void EnableAliasing(A&& a) {}
+
+uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
+ DCHECK(stream != nullptr);
+ const void* ptr;
+ int size;
+ if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
+ return *static_cast<const uint8*>(ptr);
+}
+
+constexpr uint8 kVarintTag(uint tag) { return (tag << 3) | 0; }
+constexpr uint8 kDelimitedTag(uint tag) { return (tag << 3) | 2; }
+constexpr uint8 kFixed32Tag(uint tag) { return (tag << 3) | 5; }
+
+namespace parsed {
+
+// ParseDataType has to be called first, then appropriate ParseZzzzList.
+class Feature {
+ public:
+ Feature() {}
+ Feature(StringPiece serialized) : serialized_(serialized) {}
+
+ Status ParseDataType(DataType* dtype) {
+ DCHECK(dtype != nullptr);
+ if (serialized_.empty()) {
+ *dtype = DT_INVALID;
+ return Status::OK();
+ }
+ uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
+ serialized_.remove_prefix(1);
+ switch (oneof_tag) {
+ case kDelimitedTag(1):
+ *dtype = DT_STRING;
+ break;
+ case kDelimitedTag(2):
+ *dtype = DT_FLOAT;
+ break;
+ case kDelimitedTag(3):
+ *dtype = DT_INT64;
+ break;
+ default:
+ return errors::InvalidArgument("Unsuported datatype.");
+ }
+ return Status::OK();
+ }
+
+ bool ParseBytesList(std::vector<string>* bytes_list) {
+ DCHECK(bytes_list != nullptr);
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
+
+ EnableAliasing(&stream);
+
+ uint32 length;
+ if (!stream.ReadVarint32(&length)) return false;
+ auto limit = stream.PushLimit(length);
+
+ while (!stream.ExpectAtEnd()) {
+ if (!stream.ExpectTag(kDelimitedTag(1))) return false;
+ // parse string
+ uint32 bytes_length;
+ if (!stream.ReadVarint32(&bytes_length)) return false;
+ string bytes;
+ if (!stream.ReadString(&bytes, bytes_length)) return false;
+ bytes_list->push_back(std::move(bytes));
+ }
+ stream.PopLimit(limit);
+ return true;
+ }
+
+ bool ParseFloatList(std::vector<float>* float_list) {
+ DCHECK(float_list != nullptr);
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
+ EnableAliasing(&stream);
+ uint32 length;
+ if (!stream.ReadVarint32(&length)) return false;
+ auto limit = stream.PushLimit(length);
+
+ if (!stream.ExpectAtEnd()) {
+ uint8 peek_tag = PeekTag(&stream);
+ if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
+ return false;
+ }
+
+ if (peek_tag == kDelimitedTag(1)) { // packed
+ if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
+ uint32 packed_length;
+ if (!stream.ReadVarint32(&packed_length)) return false;
+ auto packed_limit = stream.PushLimit(packed_length);
+
+ while (!stream.ExpectAtEnd()) {
+ uint32 buffer32;
+ if (!stream.ReadLittleEndian32(&buffer32)) return false;
+ float_list->push_back(bit_cast<float>(buffer32));
+ }
+
+ stream.PopLimit(packed_limit);
+ } else { // non-packed
+ while (!stream.ExpectAtEnd()) {
+ if (!stream.ExpectTag(kFixed32Tag(1))) return false;
+ uint32 buffer32;
+ if (!stream.ReadLittleEndian32(&buffer32)) return false;
+ float_list->push_back(bit_cast<float>(buffer32));
+ }
+ }
+ }
+
+ stream.PopLimit(limit);
+ return true;
+ }
+
+ bool ParseInt64List(std::vector<int64>* int64_list) {
+ DCHECK(int64_list != nullptr);
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
+ EnableAliasing(&stream);
+ uint32 length;
+ if (!stream.ReadVarint32(&length)) return false;
+ auto limit = stream.PushLimit(length);
+
+ if (!stream.ExpectAtEnd()) {
+ uint8 peek_tag = PeekTag(&stream);
+ if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
+ return false;
+ }
+ if (peek_tag == kDelimitedTag(1)) { // packed
+ if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
+ uint32 packed_length;
+ if (!stream.ReadVarint32(&packed_length)) return false;
+ auto packed_limit = stream.PushLimit(packed_length);
+
+ while (!stream.ExpectAtEnd()) {
+ uint64 n; // There is no API for int64
+ if (!stream.ReadVarint64(&n)) return false;
+ int64_list->push_back(n);
+ }
+
+ stream.PopLimit(packed_limit);
+ } else { // non-packed
+ while (!stream.ExpectAtEnd()) {
+ if (!stream.ExpectTag(kVarintTag(1))) return false;
+ uint64 n; // There is no API for int64
+ if (!stream.ReadVarint64(&n)) return false;
+ int64_list->push_back(n);
+ }
+ }
+ }
+ stream.PopLimit(limit);
+ return true;
+ }
+
+ StringPiece GetSerialized() const { return serialized_; }
+
+ private:
+ // TODO(lew): Pair of uint8* would be more natural.
+ StringPiece serialized_;
+};
+
+using FeatureMapEntry = std::pair<StringPiece, Feature>;
+using Example = std::vector<FeatureMapEntry>;
+
+} // namespace parsed
+
+bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
+ DCHECK(stream != nullptr);
+ DCHECK(result != nullptr);
+ uint32 length;
+ if (!stream->ReadVarint32(&length)) return false;
+ if (length == 0) {
+ *result = StringPiece(nullptr, 0);
+ return true;
+ }
+ const void* stream_alias;
+ int stream_size;
+ if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
+ return false;
+ }
+ if (static_cast<uint32>(stream_size) < length) return false;
+ *result = StringPiece(static_cast<const char*>(stream_alias), length);
+ stream->Skip(length);
+ return true;
+}
+
+bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
+ parsed::FeatureMapEntry* feature_map_entry) {
+ DCHECK(stream != nullptr);
+ DCHECK(feature_map_entry != nullptr);
+ uint32 length;
+ if (!stream->ReadVarint32(&length)) return false;
+ auto limit = stream->PushLimit(length);
+ if (!stream->ExpectTag(kDelimitedTag(1))) return false;
+ if (!ParseString(stream, &feature_map_entry->first)) return false;
+ if (!stream->ExpectTag(kDelimitedTag(2))) return false;
+ StringPiece feature_string_piece;
+ if (!ParseString(stream, &feature_string_piece)) return false;
+ feature_map_entry->second = parsed::Feature(feature_string_piece);
+ if (!stream->ExpectAtEnd()) return false;
+ stream->PopLimit(limit);
+ return true;
+}
+
+bool ParseFeatures(protobuf::io::CodedInputStream* stream,
+ parsed::Example* example) {
+ DCHECK(stream != nullptr);
+ DCHECK(example != nullptr);
+ uint32 length;
+ if (!stream->ReadVarint32(&length)) return false;
+ auto limit = stream->PushLimit(length);
+ while (!stream->ExpectAtEnd()) {
+ parsed::FeatureMapEntry feature_map_entry;
+ if (!stream->ExpectTag(kDelimitedTag(1))) return false;
+ if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
+ example->push_back(std::move(feature_map_entry));
+ }
+ stream->PopLimit(limit);
+ return true;
+}
+
+bool ParseExample(protobuf::io::CodedInputStream* stream,
+ parsed::Example* example) {
+ DCHECK(stream != nullptr);
+ DCHECK(example != nullptr);
+ if (stream->ExpectTag(kDelimitedTag(1))) {
+ if (!ParseFeatures(stream, example)) return false;
+ }
+ if (!stream->ExpectAtEnd()) return false;
+ return true;
+}
+
+bool ParseExample(StringPiece serialized, parsed::Example* example) {
+ DCHECK(example != nullptr);
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
+ EnableAliasing(&stream);
+ return ParseExample(&stream, example);
+}
+
+} // namespace
+
+bool TestFastParse(const string& serialized, Example* example) {
+ DCHECK(example != nullptr);
+ parsed::Example parsed_example;
+ if (!ParseExample(serialized, &parsed_example)) return false;
+ auto& features = *example->mutable_features();
+ for (parsed::FeatureMapEntry& entry : parsed_example) {
+ auto& value = (*features.mutable_feature())[entry.first.ToString()];
+ DataType dtype;
+ if (!entry.second.ParseDataType(&dtype).ok()) return false;
+ switch (dtype) {
+ case DT_INVALID:
+ break;
+ case DT_STRING: {
+ std::vector<string> list;
+ if (!entry.second.ParseBytesList(&list)) return false;
+ auto* result_list = value.mutable_bytes_list();
+ for (auto& bytes : list) {
+ result_list->add_value(std::move(bytes));
+ }
+ break;
+ }
+ case DT_FLOAT: {
+ std::vector<float> list;
+ if (!entry.second.ParseFloatList(&list)) return false;
+ auto* result_list = value.mutable_float_list();
+ for (float f : list) {
+ result_list->add_value(f);
+ }
+ break;
+ }
+ case DT_INT64: {
+ std::vector<int64> list;
+ if (!entry.second.ParseInt64List(&list)) return false;
+ auto* result_list = value.mutable_int64_list();
+ for (int64 i : list) {
+ result_list->add_value(i);
+ }
+ break;
+ }
+ default:
+ CHECK(false) << "Should not happen.";
+ }
+ }
+ return true;
+}
+
+// -----------------------------------------------------------------------------
+
+namespace {
+
+using Config = FastParseExampleConfig;
+
+void ParallelFor(const std::function<void(size_t)>& f, size_t n,
+ thread::ThreadPool* thread_pool) {
+ DCHECK(thread_pool != nullptr);
+ if (n == 0) return;
+ BlockingCounter counter(n - 1);
+ for (size_t i = 1; i < n; ++i) {
+ thread_pool->Schedule([i, &f, &counter] {
+ f(i);
+ counter.DecrementCount();
+ });
+ }
+ f(0);
+ counter.Wait();
+}
+
+enum class Type { Sparse, Dense };
+
+struct SparseBuffer {
+ // TODO(lew): Use InlinedVector.
+ // Features are in one of the 3 vectors below depending on config's dtype.
+ // Other 2 vectors remain empty.
+ std::vector<string> bytes_list;
+ std::vector<float> float_list;
+ std::vector<int64> int64_list;
+
+ // Features of example i are elements with indices
+ // from example_end_indices[i-1] to example_end_indices[i]-1 on the
+ // appropriate xxxxx_list
+ std::vector<size_t> example_end_indices;
+};
+
+struct SeededHasher {
+ uint64 operator()(StringPiece s) const {
+ return Hash64(s.data(), s.size(), seed);
+ }
+ uint64 seed{0xDECAFCAFFE};
+};
+
+Status FastParseSerializedExample(
+ const string& serialized_example, const string& example_name,
+ const size_t example_index, const Config& config,
+ const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
+ SeededHasher hasher, std::vector<Tensor>* output_dense,
+ std::vector<SparseBuffer>* output_sparse) {
+ DCHECK(output_dense != nullptr);
+ DCHECK(output_sparse != nullptr);
+ parsed::Example parsed_example;
+ if (!ParseExample(serialized_example, &parsed_example)) {
+ return errors::InvalidArgument("Could not parse example input, value: '",
+ serialized_example, "'");
+ }
+ constexpr size_t kMax = std::numeric_limits<size_t>::max();
+ std::vector<size_t> sparse_features_found(config.sparse.size(), kMax);
+ std::vector<size_t> dense_features_found(config.dense.size(), kMax);
+
+ // Handle features present in the example.
+ for (parsed::FeatureMapEntry& name_and_feature : parsed_example) {
+ parsed::Feature& feature = name_and_feature.second;
+ std::pair<size_t, Type> d_and_type;
+ uint64 h = hasher(name_and_feature.first);
+ if (!config_index.Find(h, &d_and_type)) continue;
+ size_t d = d_and_type.first;
+
+ auto parse_error = [&](StringPiece feature_name) {
+ return errors::InvalidArgument("Name: ", example_name, ", Key: ",
+ feature_name, ", Index: ", example_index,
+ ". Can't parse serialized Example.");
+ };
+
+ if (d_and_type.second == Type::Dense) {
+ DataType example_dtype;
+ TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
+ if (example_dtype == DT_INVALID) continue;
+
+ dense_features_found[d] = example_index;
+ if (example_dtype != config.dense[d].dtype) {
+ return errors::InvalidArgument(
+ "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
+ ". Data types don't match. ", "Data type: ",
+ DataTypeString(example_dtype), "Expected type: ",
+ DataTypeString(config.dense[d].dtype));
+ }
+ const string& feature_name = config.dense[d].feature_name;
+ const TensorShape& shape = config.dense[d].shape;
+ Tensor& out = (*output_dense)[d];
+
+ const std::size_t num_elements = shape.num_elements();
+ const std::size_t offset = example_index * num_elements;
+
+ auto shape_error = [&](size_t size, StringPiece type_str) {
+ return errors::InvalidArgument(
+ "Name: ", example_name, ", Key: ", feature_name, ", Index: ",
+ example_index, ". Number of ", type_str,
+ " values != expected. "
+ "Values size: ",
+ size, " but output shape: ", shape.DebugString());
+ };
+
+ switch (config.dense[d].dtype) {
+ case DT_INT64: {
+ std::vector<int64> list;
+ if (!feature.ParseInt64List(&list)) return parse_error(feature_name);
+ if (list.size() != num_elements) {
+ return shape_error(list.size(), "int64");
+ }
+ auto out_p = out.flat<int64>().data() + offset;
+ std::copy_n(list.begin(), list.size(), out_p);
+ break;
+ }
+ case DT_FLOAT: {
+ std::vector<float> list;
+ if (!feature.ParseFloatList(&list)) return parse_error(feature_name);
+ if (list.size() != num_elements) {
+ return shape_error(list.size(), "float");
+ }
+ auto out_p = out.flat<float>().data() + offset;
+ std::copy_n(list.begin(), list.size(), out_p);
+ break;
+ }
+ case DT_STRING: {
+ std::vector<string> list;
+ if (!feature.ParseBytesList(&list)) return parse_error(feature_name);
+ if (list.size() != num_elements) {
+ return shape_error(list.size(), "bytes");
+ }
+ auto out_p = out.flat<string>().data() + offset;
+ for (size_t i = 0; i < list.size(); ++i) {
+ out_p[i] = std::move(list[i]);
+ }
+ break;
+ }
+ default:
+ CHECK(false) << "Should not happen.";
+ }
+ } else {
+ // Handle sparse features.
+ sparse_features_found[d] = example_index;
+ const string& feature_name = config.sparse[d].feature_name;
+ SparseBuffer& out = (*output_sparse)[d];
+ DataType example_dtype;
+ TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
+ if (example_dtype != DT_INVALID &&
+ example_dtype != config.sparse[d].dtype) {
+ return errors::InvalidArgument(
+ "Name: ", example_name, ", Feature: ",
+ config.sparse[d].feature_name, ". Data types don't match. ",
+ "Expected type: ", DataTypeString(config.sparse[d].dtype));
+ }
+
+ switch (config.sparse[d].dtype) {
+ case DT_INT64: {
+ if (example_dtype != DT_INVALID) {
+ if (!feature.ParseInt64List(&out.int64_list)) {
+ return parse_error(feature_name);
+ }
+ }
+ out.example_end_indices.push_back(out.int64_list.size());
+ break;
+ }
+ case DT_FLOAT: {
+ if (example_dtype != DT_INVALID) {
+ if (!feature.ParseFloatList(&out.float_list)) {
+ return parse_error(feature_name);
+ }
+ }
+ out.example_end_indices.push_back(out.float_list.size());
+ break;
+ }
+ case DT_STRING: {
+ if (example_dtype != DT_INVALID) {
+ if (!feature.ParseBytesList(&out.bytes_list)) {
+ return parse_error(feature_name);
+ }
+ }
+ out.example_end_indices.push_back(out.bytes_list.size());
+ break;
+ }
+ default:
+ CHECK(false) << "Should not happen.";
+ }
+ }
+ }
+
+ // Handle missing dense features.
+ for (size_t d = 0; d < config.dense.size(); ++d) {
+ if (dense_features_found[d] == example_index) continue;
+ if (config.dense[d].default_value.NumElements() == 0) {
+ return errors::InvalidArgument("Name: ", example_name, ", Feature: ",
+ config.dense[d].feature_name,
+ " is required but could not be found.");
+ }
+
+ const Tensor& in = config.dense[d].default_value;
+ Tensor& out = (*output_dense)[d];
+ const std::size_t num_elements = in.shape().num_elements();
+ const std::size_t offset = example_index * num_elements;
+
+ switch (config.dense[d].dtype) {
+ case DT_INT64: {
+ std::copy_n(in.flat<int64>().data(), num_elements,
+ out.flat<int64>().data() + offset);
+ break;
+ }
+ case DT_FLOAT: {
+ std::copy_n(in.flat<float>().data(), num_elements,
+ out.flat<float>().data() + offset);
+ break;
+ }
+ case DT_STRING: {
+ std::copy_n(in.flat<string>().data(), num_elements,
+ out.flat<string>().data() + offset);
+ break;
+ }
+ default:
+ CHECK(false) << "Should not happen.";
+ }
+ }
+
+ // Handle missing sparse features.
+ for (size_t d = 0; d < config.sparse.size(); ++d) {
+ if (sparse_features_found[d] == example_index) continue;
+ SparseBuffer& out = (*output_sparse)[d];
+ size_t prev_example_end_index =
+ out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
+ out.example_end_indices.push_back(prev_example_end_index);
+ }
+
+ return Status::OK();
+}
+
+Status CheckConfigDataType(DataType dtype) {
+ switch (dtype) {
+ case DT_INT64:
+ case DT_FLOAT:
+ case DT_STRING:
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid config dtype: ",
+ DataTypeString(dtype));
+ }
+}
+
+} // namespace
+
+Status FastParseExample(const Config& config,
+ gtl::ArraySlice<string> serialized,
+ gtl::ArraySlice<string> example_names,
+ thread::ThreadPool* thread_pool, Result* result) {
+ DCHECK(thread_pool != nullptr);
+ DCHECK(result != nullptr);
+ // Check config so we can safely CHECK(false) in switches on config.*.dtype
+ for (auto& c : config.sparse) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ }
+ for (auto& c : config.dense) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ }
+
+ size_t config_size = config.dense.size() + config.sparse.size();
+ SeededHasher hasher;
+ // Build config index.
+ PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
+ bool ok = true;
+ for (size_t i = 0; i < 1000; ++i) {
+ for (size_t d = 0; d < config.dense.size(); ++d) {
+ ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
+ {d, Type::Dense});
+ }
+ for (size_t d = 0; d < config.sparse.size(); ++d) {
+ ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
+ {d, Type::Sparse});
+ }
+ if (ok) break;
+ LOG(WARNING) << "Collision found. This should happen only if you have "
+ "around 2^32 entries in your config.";
+ hasher.seed++;
+ config_index.Clear(config_size);
+ }
+ if (!ok) {
+ return errors::Internal(
+ "Could not avoid collision. This should not happen.");
+ }
+
+ // Allocate dense output (sparse have to be buffered).
+ for (size_t d = 0; d < config.dense.size(); ++d) {
+ TensorShape out_shape;
+ out_shape.AddDim(serialized.size());
+ for (const int64 dim : config.dense[d].shape.dim_sizes()) {
+ out_shape.AddDim(dim);
+ }
+ result->dense_values.emplace_back(config.dense[d].dtype, out_shape);
+ }
+
+ // This parameter affects performance in a big and data-dependent way.
+ const size_t kMiniBatchSizeBytes = 100000;
+
+ // Split examples into mini-batches for parallel processing.
+ auto first_example_of_minibatch = [&] {
+ std::vector<size_t> result;
+ size_t minibatch_bytes = 0;
+ for (size_t i = 0; i < serialized.size(); i++) {
+ if (minibatch_bytes == 0) { // start minibatch
+ result.push_back(i);
+ }
+ minibatch_bytes += serialized[i].size() + 1;
+ if (minibatch_bytes > kMiniBatchSizeBytes) {
+ minibatch_bytes = 0;
+ }
+ }
+ return result;
+ }();
+
+ size_t num_minibatches = first_example_of_minibatch.size();
+
+ // Do minibatches in parallel.
+ std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
+ std::vector<Status> status_of_minibatch(num_minibatches);
+
+ auto ProcessMiniBatch = [&](size_t minibatch) {
+ sparse_buffers[minibatch].resize(config.sparse.size());
+ size_t start = first_example_of_minibatch[minibatch];
+ size_t end = minibatch + 1 < num_minibatches
+ ? first_example_of_minibatch[minibatch + 1]
+ : serialized.size();
+ for (size_t e = start; e < end; ++e) {
+ status_of_minibatch[minibatch] = FastParseSerializedExample(
+ serialized[e],
+ (example_names.size() > 0 ? example_names[e] : "<unknown>"), e,
+ config, config_index, hasher, &result->dense_values,
+ &sparse_buffers[minibatch]);
+ if (!status_of_minibatch[minibatch].ok()) break;
+ }
+ };
+
+ ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);
+
+ for (Status& status : status_of_minibatch) {
+ TF_RETURN_IF_ERROR(status);
+ }
+
+ // Merge SparseBuffers from all minibatches for every config.sparse.
+ auto MergeMinibatches = [&](size_t d) {
+ // Loop over minibatches
+ size_t total_num_features = 0;
+ size_t max_num_features = 0;
+ for (auto& sparse_values_tmp : sparse_buffers) {
+ std::vector<size_t>& end_indices =
+ sparse_values_tmp[d].example_end_indices;
+ total_num_features += end_indices.back();
+ max_num_features = std::max(max_num_features, end_indices[0]);
+ for (size_t i = 1; i < end_indices.size(); ++i) {
+ size_t example_size = end_indices[i] - end_indices[i - 1];
+ max_num_features = std::max(max_num_features, example_size);
+ }
+ }
+
+ TensorShape indices_shape;
+ indices_shape.AddDim(total_num_features);
+ indices_shape.AddDim(2);
+ result->sparse_indices.emplace_back(DT_INT64, indices_shape);
+ Tensor* indices = &result->sparse_indices.back();
+
+ TensorShape values_shape;
+ values_shape.AddDim(total_num_features);
+ result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
+ Tensor* values = &result->sparse_values.back();
+
+ result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
+ auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
+ shapes_shape_t(0) = serialized.size();
+ shapes_shape_t(1) = max_num_features;
+
+ size_t offset = 0;
+ for (size_t i = 0; i < sparse_buffers.size(); ++i) {
+ const SparseBuffer& buffer = sparse_buffers[i][d];
+
+ // Update indices.
+ int64* ix_p = &indices->matrix<int64>()(offset, 0);
+ size_t delta = 0;
+ size_t example_index = first_example_of_minibatch[i];
+ for (size_t example_end_index : buffer.example_end_indices) {
+ size_t feature_index = 0;
+ for (; delta < example_end_index; ++delta) {
+ // Column 0: example index
+ *ix_p = example_index;
+ // Column 1: the feature index within the example
+ *(ix_p + 1) = feature_index;
+ ix_p += 2;
+ ++feature_index;
+ }
+ ++example_index;
+ }
+
+ // Copy values over.
+ switch (config.sparse[d].dtype) {
+ case DT_INT64: {
+ std::copy(buffer.int64_list.begin(), buffer.int64_list.end(),
+ values->flat<int64>().data() + offset);
+ break;
+ }
+ case DT_FLOAT: {
+ std::copy(buffer.float_list.begin(), buffer.float_list.end(),
+ values->flat<float>().data() + offset);
+ break;
+ }
+ case DT_STRING: {
+ std::move(buffer.bytes_list.begin(), buffer.bytes_list.end(),
+ values->flat<string>().data() + offset);
+ break;
+ }
+ default:
+ CHECK(false) << "Should not happen.";
+ }
+
+ offset += delta;
+ }
+ };
+
+ for (size_t d = 0; d < config.sparse.size(); ++d) {
+ MergeMinibatches(d);
+ }
+
+ return Status::OK();
+}
+
+} // namespace example
+} // namespace tensorflow
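
The parser above works directly on protobuf wire-format tag bytes: a tag is (field_number << 3) | wire_type, with wire type 0 for varints, 2 for length-delimited data, and 5 for 32-bit fixed values, which is all that kVarintTag, kDelimitedTag and kFixed32Tag compute. A few concrete values, written out as an informal sanity check (not part of the commit):

    // Tag bytes this parser matches against, spelled out:
    static_assert(kDelimitedTag(1) == 0x0a, "Example.features / BytesList.value / packed lists");
    static_assert(kDelimitedTag(2) == 0x12, "Feature value inside the feature-map entry");
    static_assert(kDelimitedTag(3) == 0x1a, "Feature.int64_list");
    static_assert(kVarintTag(1)    == 0x08, "non-packed Int64List.value element");
    static_assert(kFixed32Tag(1)   == 0x0d, "non-packed FloatList.value element");
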
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
new file mode 100644
index 0000000000..6ed9d57838
--- /dev/null
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -0,0 +1,88 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+namespace example {
+
+// FastParseExampleConfig defines how to parse features in Example.
+// Each sub-config is responsible for one feature identified with feature_name.
+// FastParseExampleConfig can't have two sub-configs with the same feature_name.
+// dtype identifies the type of output vector and the kind of Feature expected
+// in Example.
+struct FastParseExampleConfig {
+ struct Dense {
+ string feature_name;
+ DataType dtype;
+ // These 2 fields correspond exactly to dense_shapes and dense_defaults in
+ // ParseExample op.
+ // Documentation is available in: tensorflow/core/ops/parsing_ops.cc
+ TensorShape shape;
+ Tensor default_value;
+ };
+
+ struct Sparse {
+ string feature_name;
+ DataType dtype;
+ };
+
+ std::vector<Dense> dense;
+ std::vector<Sparse> sparse;
+};
+
+// This is exactly the output of TF's ParseExample Op.
+// Documentation is available in: tensorflow/core/ops/parsing_ops.cc
+struct Result {
+ std::vector<Tensor> sparse_indices;
+ std::vector<Tensor> sparse_values;
+ std::vector<Tensor> sparse_shapes;
+ std::vector<Tensor> dense_values;
+};
+
+// Parses a batch of serialized Example protos and converts them into result
+// according to given config.
+// Given example names have to either be empty or the same size as serialized.
+// example_names are used only for error messages.
+Status FastParseExample(const FastParseExampleConfig& config,
+ gtl::ArraySlice<string> serialized,
+ gtl::ArraySlice<string> example_names,
+ thread::ThreadPool* thread_pool, Result* result);
+
+// This function parses the serialized Example and populates the given example.
+// It uses the same specialized parser as FastParseExample, which is efficient,
+// but then constructs an Example proto, which is relatively slow.
+// It is exported here as a convenient API for testing the parser separately.
+bool TestFastParse(const string& serialized, Example* example);
+
+} // namespace example
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_
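
To make the Result layout concrete: sparse outputs follow the ParseExample op convention, i.e. for sparse feature d, sparse_indices[d] holds [example_index, index_within_example] pairs, sparse_values[d] holds the values in example order, and sparse_shapes[d] is [batch_size, max_num_features]. A hypothetical two-Example batch with a single sparse string feature "tags" (two values in example 0, one in example 1) would come back as:

    // config.sparse = {{"tags", DT_STRING}}, batch of 2 serialized Examples:
    //   sparse_indices[0] : int64  [3, 2] = {{0, 0}, {0, 1}, {1, 0}}
    //   sparse_values[0]  : string [3]    = {"a", "b", "c"}
    //   sparse_shapes[0]  : int64  [2]    = {2, 2}   // {batch_size, max_num_features}
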
diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc
new file mode 100644
index 0000000000..6d3b548851
--- /dev/null
+++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace example {
+namespace {
+
+constexpr char kDenseInt64Key[] = "dense_int64";
+constexpr char kDenseFloatKey[] = "dense_float";
+constexpr char kDenseStringKey[] = "dense_string";
+
+constexpr char kSparseInt64Key[] = "sparse_int64";
+constexpr char kSparseFloatKey[] = "sparse_float";
+constexpr char kSparseStringKey[] = "sparse_string";
+
+string SerializedToReadable(string serialized) {
+ string result;
+ result += '"';
+ for (char c : serialized)
+ result += strings::StrCat("\\x", strings::Hex(c, strings::ZERO_PAD_2));
+ result += '"';
+ return result;
+}
+
+string Serialize(const Example& example) {
+ string serialized;
+ example.SerializeToString(&serialized);
+ return serialized;
+}
+
+void TestCorrectness(const string& serialized) {
+ Example example;
+ Example fast_example;
+ EXPECT_TRUE(example.ParseFromString(serialized));
+ EXPECT_TRUE(TestFastParse(serialized, &fast_example));
+ EXPECT_EQ(example.DebugString(), fast_example.DebugString());
+ if (example.DebugString() != fast_example.DebugString()) {
+ LOG(ERROR) << "Bad serialized: " << SerializedToReadable(serialized);
+ }
+}
+
+// Fast parsing does not differentiate between EmptyExample and EmptyFeatures
+// TEST(FastParse, EmptyExample) {
+// Example example;
+// TestCorrectness(example);
+// }
+
+TEST(FastParse, NonPacked) {
+ TestCorrectness(
+ "\x0a\x0e\x0a\x0c\x0a\x03\x61\x67\x65\x12\x05\x1a\x03\x0a\x01\x0d");
+}
+
+TEST(FastParse, Packed) {
+ TestCorrectness(
+ "\x0a\x0d\x0a\x0b\x0a\x03\x61\x67\x65\x12\x04\x1a\x02\x08\x0d");
+}
+
+TEST(FastParse, EmptyFeatures) {
+ Example example;
+ example.mutable_features();
+ TestCorrectness(Serialize(example));
+}
+
+void TestCorrectnessJson(const string& json) {
+ auto resolver = protobuf::util::NewTypeResolverForDescriptorPool(
+ "type.googleapis.com", protobuf::DescriptorPool::generated_pool());
+ string serialized;
+ auto s = protobuf::util::JsonToBinaryString(
+ resolver, "type.googleapis.com/tensorflow.Example", json, &serialized);
+ EXPECT_TRUE(s.ok()) << s;
+ delete resolver;
+ TestCorrectness(serialized);
+}
+
+TEST(FastParse, JsonUnivalent) {
+ TestCorrectnessJson(
+ "{'features': {"
+ " 'feature': {'age': {'int64_list': {'value': [0]} }}, "
+ " 'feature': {'flo': {'float_list': {'value': [1.1]} }}, "
+ " 'feature': {'byt': {'bytes_list': {'value': ['WW8='] }}}"
+ "}}");
+}
+
+TEST(FastParse, JsonMultivalent) {
+ TestCorrectnessJson(
+ "{'features': {"
+ " 'feature': {'age': {'int64_list': {'value': [0, 13, 23]} }}, "
+ " 'feature': {'flo': {'float_list': {'value': [1.1, 1.2, 1.3]} }}, "
+ " 'feature': {'byt': {'bytes_list': {'value': ['WW8=', 'WW8K'] }}}"
+ "}}");
+}
+
+TEST(FastParse, SingleInt64) {
+ Example example;
+ (*example.mutable_features()->mutable_feature())["age"]
+ .mutable_int64_list()
+ ->add_value(13);
+ TestCorrectness(Serialize(example));
+}
+
+TEST(FastParse, SomeFeatures) {
+ Example example;
+
+ (*example.mutable_features()->mutable_feature())[""];
+
+ (*example.mutable_features()->mutable_feature())["empty_bytes_list"]
+ .mutable_bytes_list();
+ (*example.mutable_features()->mutable_feature())["empty_float_list"]
+ .mutable_float_list();
+ (*example.mutable_features()->mutable_feature())["empty_int64_list"]
+ .mutable_int64_list();
+
+ BytesList* bytes_list =
+ (*example.mutable_features()->mutable_feature())["bytes_list"]
+ .mutable_bytes_list();
+ bytes_list->add_value("bytes1");
+ bytes_list->add_value("bytes2");
+
+ FloatList* float_list =
+ (*example.mutable_features()->mutable_feature())["float_list"]
+ .mutable_float_list();
+ float_list->add_value(1.0);
+ float_list->add_value(2.0);
+
+ Int64List* int64_list =
+ (*example.mutable_features()->mutable_feature())["int64_list"]
+ .mutable_int64_list();
+ int64_list->add_value(3);
+ int64_list->add_value(270);
+ int64_list->add_value(86942);
+
+ TestCorrectness(Serialize(example));
+}
+
+string MakeSerializedExample() {
+ Example example;
+ const int kFeatureNameLength = 10;
+ const int kFeatureValueLength = 20;
+ const int kBytesFeatureCount = 200;
+ const int kFloatFeatureCount = 200;
+ const int kInt64FeatureCount = 200;
+ auto& fmap = *example.mutable_features()->mutable_feature();
+ for (int i = 0; i < kBytesFeatureCount; ++i) {
+ fmap[strings::StrCat(string(kFeatureNameLength, 'b'), i)]
+ .mutable_bytes_list()
+ ->add_value(string(kFeatureValueLength, 'v'));
+ }
+ for (int i = 0; i < kFloatFeatureCount; ++i) {
+ fmap[strings::StrCat(string(kFeatureNameLength, 'f'), i)]
+ .mutable_float_list()
+ ->add_value(123123123.123);
+ }
+ for (int i = 0; i < kInt64FeatureCount; ++i) {
+ fmap[strings::StrCat(string(kFeatureNameLength, 'i'), i)]
+ .mutable_int64_list()
+ ->add_value(10 * i);
+ }
+ string serialized;
+ example.SerializeToString(&serialized);
+ return serialized;
+}
+
+} // namespace
+
+} // namespace example
+} // namespace tensorflow
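
For readers decoding the hand-written byte strings in the NonPacked/Packed tests above: both serialize an Example whose single feature "age" is an Int64List holding the value 13, and they differ only in how the innermost Int64List.value field is encoded (one length-delimited block of varints vs. one tagged varint per element), which is exactly the branch ParseInt64List distinguishes. An informal breakdown of the first string, not part of the test:

    // "\x0a\x0e"              Example.features        (field 1, length-delimited, 14 bytes)
    //   "\x0a\x0c"            feature map entry       (field 1, length-delimited, 12 bytes)
    //     "\x0a\x03" "age"    entry key               (field 1, 3 bytes)
    //     "\x12\x05"          entry value: Feature    (field 2, 5 bytes)
    //       "\x1a\x03"        Feature.int64_list      (field 3, 3 bytes)
    //         "\x0a\x01\x0d"  Int64List.value: one length-delimited varint, 13
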
diff --git a/tensorflow/core/util/presized_cuckoo_map.h b/tensorflow/core/util/presized_cuckoo_map.h
index b488d32e03..5244ab693a 100644
--- a/tensorflow/core/util/presized_cuckoo_map.h
+++ b/tensorflow/core/util/presized_cuckoo_map.h
@@ -50,7 +50,10 @@ class PresizedCuckooMap {
// The key type is fixed as a pre-hashed key for this specialized use.
typedef uint64 key_type;
- explicit PresizedCuckooMap(uint64 num_entries) : cpq_(new CuckooPathQueue) {
+ explicit PresizedCuckooMap(uint64 num_entries) { Clear(num_entries); }
+
+ void Clear(uint64 num_entries) {
+ cpq_.reset(new CuckooPathQueue());
double n(num_entries);
n /= kLoadFactor;
num_buckets_ = (static_cast<uint64>(n) / kSlotsPerBucket);
@@ -62,6 +65,7 @@ class PresizedCuckooMap {
for (int i = 0; i < kSlotsPerBucket; i++) {
empty_bucket.keys[i] = kUnusedSlot;
}
+ buckets_.clear();
buckets_.resize(num_buckets_, empty_bucket);
#if !defined(__GCUDACC__) && !defined(__GCUDACC_HOST__)
buckets_divisor_ = Eigen::internal::TensorIntDivisor<uint64>(num_buckets_);
@@ -317,7 +321,7 @@ class PresizedCuckooMap {
std::vector<Bucket> buckets_;
Eigen::internal::TensorIntDivisor<uint64> buckets_divisor_; // for fast mod
- const std::unique_ptr<CuckooPathQueue> cpq_;
+ std::unique_ptr<CuckooPathQueue> cpq_;
CuckooPathEntry visited_[kVisitedListSize];
TF_DISALLOW_COPY_AND_ASSIGN(PresizedCuckooMap);
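
The constructor now delegates to the new Clear(num_entries), so a map can be emptied and reused in place; FastParseExample relies on this when it reseeds its hasher after a hash collision and rebuilds the feature-name index. The added buckets_.clear() before resize matters because std::vector::resize only value-initializes elements beyond the current size, so clearing first guarantees every bucket is reset to empty slots on reuse. A minimal reuse sketch (the keys are arbitrary pre-hashed values, not taken from this commit):

    PresizedCuckooMap<int> map(/*num_entries=*/2);
    CHECK(map.InsertUnique(/*pre-hashed key=*/11, 1));
    map.Clear(/*num_entries=*/2);    // empties the table, keeps capacity for 2
    CHECK(map.InsertUnique(11, 2));  // the same key can be inserted again
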
diff --git a/tensorflow/core/util/presized_cuckoo_map_test.cc b/tensorflow/core/util/presized_cuckoo_map_test.cc
index 64ad315518..fe8e5dcfbd 100644
--- a/tensorflow/core/util/presized_cuckoo_map_test.cc
+++ b/tensorflow/core/util/presized_cuckoo_map_test.cc
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/util/presized_cuckoo_map.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/util/presized_cuckoo_map.h"
namespace tensorflow {
namespace {
@@ -64,6 +64,22 @@ TEST(PresizedCuckooMapTest, ZeroSizeMap) {
}
}
+TEST(PresizedCuckooMapTest, RepeatedClear) {
+ PresizedCuckooMap<int> pscm(2);
+ int out;
+ for (int i = 0; i < 100; ++i) {
+ pscm.InsertUnique(0, 0);
+ pscm.InsertUnique(1, 1);
+ EXPECT_TRUE(pscm.Find(0, &out));
+ EXPECT_EQ(0, out);
+ EXPECT_TRUE(pscm.Find(1, &out));
+ EXPECT_EQ(1, out);
+ pscm.Clear(2);
+ EXPECT_FALSE(pscm.Find(0, &out));
+ EXPECT_FALSE(pscm.Find(1, &out));
+ }
+}
+
void RunFill(int64 table_size) {
PresizedCuckooMap<int> pscm(table_size);
for (int64 i = 0; i < table_size; i++) {