diff options
author | 2018-08-07 10:26:06 -0700 | |
---|---|---|
committer | 2018-08-07 10:37:21 -0700 | |
commit | b8886649c75ae864f2532bca044e2f44fb138c95 (patch) | |
tree | d053e1c2a91a3125a7ebb4f6d084a4be713febde /tensorflow/core/util | |
parent | 90bf05c0d147a7e0c6e48720e17e51233b2bcd3c (diff) |
[tf.data] Add feature statistics collection hooks to the tf.Example parsers.
PiperOrigin-RevId: 207737913
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.cc | 89 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.h | 24 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing_test.cc | 80 |
3 files changed, 187 insertions, 6 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index 418e97ac24..1fec0010a1 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -495,7 +495,8 @@ Status FastParseSerializedExample( const PresizedCuckooMap<std::pair<size_t, Type>>& config_index, SeededHasher hasher, std::vector<Tensor>* output_dense, std::vector<SparseBuffer>* output_varlen_dense, - std::vector<SparseBuffer>* output_sparse) { + std::vector<SparseBuffer>* output_sparse, + PerExampleFeatureStats* output_stats) { DCHECK(output_dense != nullptr); DCHECK(output_sparse != nullptr); parsed::Example parsed_example; @@ -508,6 +509,14 @@ Status FastParseSerializedExample( // Handle features present in the example. const size_t parsed_example_size = parsed_example.size(); + + if (output_stats) { + // TODO(b/111553342): This may over-count the number of features if there + // are duplicate keys in the feature map. Consider deduplicating the keys + // before computing the count. + output_stats->features_count = parsed_example_size; + } + for (size_t i = 0; i < parsed_example_size; ++i) { // This is a logic that standard protobuf parsing is implementing. // I.e. last entry in the map overwrites all the previous ones. @@ -567,6 +576,13 @@ Status FastParseSerializedExample( Tensor& out = (*output_dense)[d]; const std::size_t num_elements = config.dense[d].elements_per_stride; + if (output_stats) { + // TODO(b/111553342): If desirable, we could add support for counting + // elements in the features that aren't parsed, but this could add + // considerable runtime cost. + output_stats->feature_values_count += num_elements; + } + const std::size_t offset = example_index * num_elements; auto shape_error = [&](size_t size, StringPiece type_str) { @@ -669,6 +685,23 @@ Status FastParseSerializedExample( default: LOG(FATAL) << "Should not happen."; } + + if (output_stats) { + // Use `out.example_end_indices` to determine the feature-value count + // for this feature, because the preceding switch statement pushes + // the length of the appropriate feature list to that vector. + // TODO(b/111553342): If desirable, we could add support for counting + // elements in the features that aren't parsed, but this could add + // considerable runtime cost. + const size_t out_examples_count = out.example_end_indices.size(); + if (out_examples_count == 1) { + output_stats->feature_values_count += out.example_end_indices[0]; + } else { + output_stats->feature_values_count += + out.example_end_indices[out_examples_count - 1] - + out.example_end_indices[out_examples_count - 2]; + } + } } } else { // If feature was already visited, skip. @@ -720,6 +753,23 @@ Status FastParseSerializedExample( default: LOG(FATAL) << "Should not happen."; } + + if (output_stats) { + // Use `out.example_end_indices` to determine the feature-value count + // for this feature, because the preceding switch statement pushes + // the length of the appropriate feature list to that vector. + // TODO(b/111553342): If desirable, we could add support for counting + // elements in the features that aren't parsed, but this could add + // considerable runtime cost. + const size_t out_examples_count = out.example_end_indices.size(); + if (out_examples_count == 1) { + output_stats->feature_values_count += out.example_end_indices[0]; + } else { + output_stats->feature_values_count += + out.example_end_indices[out_examples_count - 1] - + out.example_end_indices[out_examples_count - 2]; + } + } } } @@ -877,6 +927,10 @@ Status FastParseExample(const Config& config, TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); } + if (config.collect_feature_stats) { + result->feature_stats.resize(serialized.size()); + } + size_t config_size = config.dense.size() + config.sparse.size(); SeededHasher hasher; // Build config index. @@ -962,11 +1016,15 @@ Status FastParseExample(const Config& config, size_t start = first_example_of_minibatch(minibatch); size_t end = first_example_of_minibatch(minibatch + 1); for (size_t e = start; e < end; ++e) { + PerExampleFeatureStats* stats = nullptr; + if (config.collect_feature_stats) { + stats = &result->feature_stats[e]; + } status_of_minibatch[minibatch] = FastParseSerializedExample( serialized[e], (!example_names.empty() ? example_names[e] : "<unknown>"), e, config, config_index, hasher, &fixed_dense_values, - &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch]); + &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch], stats); if (!status_of_minibatch[minibatch].ok()) break; } }; @@ -1079,7 +1137,7 @@ Status FastParseExample(const Config& config, const size_t stride_size = config.dense[d].elements_per_stride; const size_t max_num_elements = max_num_features / stride_size; TensorShape values_shape; - DCHECK(max_num_features % config.dense[d].elements_per_stride == 0); + DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0); const size_t batch_size = serialized.size(); values_shape.AddDim(batch_size); values_shape.AddDim(max_num_elements); @@ -1138,6 +1196,12 @@ Status FastParseSingleExample(const Config& config, const string& serialized, TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); } + PerExampleFeatureStats* stats = nullptr; + if (config.collect_feature_stats) { + result->feature_stats.emplace_back(); + stats = &result->feature_stats.back(); + } + // TODO(mrry): Cache the construction of this map at Op construction time. size_t config_size = config.dense.size() + config.sparse.size(); SeededHasher hasher; @@ -1196,6 +1260,13 @@ Status FastParseSingleExample(const Config& config, const string& serialized, std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false); std::vector<bool> dense_feature_already_seen(config.dense.size(), false); + if (stats) { + // TODO(b/111553342): This may over-count the number of features if there + // are duplicate keys in the feature map. Consider deduplicating the keys + // before computing the count. + stats->features_count = parsed_example.size(); + } + // Handle features present in the example. const size_t parsed_example_size = parsed_example.size(); for (size_t i = 0; i < parsed_example_size; ++i) { @@ -1254,7 +1325,12 @@ Status FastParseSingleExample(const Config& config, const string& serialized, Tensor* out = &result->dense_values[d]; const std::size_t num_elements = config.dense[d].elements_per_stride; - + if (stats) { + // TODO(b/111553342): If desirable, we could add support for counting + // elements in the features that aren't parsed, but this could add + // considerable runtime cost. + stats->feature_values_count += num_elements; + } switch (example_dtype) { case DT_INT64: { auto out_p = out->flat<int64>().data(); @@ -1362,6 +1438,10 @@ Status FastParseSingleExample(const Config& config, const string& serialized, return parse_error(); } + if (stats) { + stats->feature_values_count += num_elements; + } + Tensor* out; if (is_dense) { TensorShape values_shape; @@ -1636,6 +1716,7 @@ inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream, } // TODO(sundberg): Use the threadpool to parallelize example parsing. +// TODO(b/111553342): Support extracting feature statistics from the examples. Status FastParseSequenceExample( const FastParseExampleConfig& context_config, const FastParseExampleConfig& feature_list_config, diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h index 024a4518ee..db5b5ff929 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.h +++ b/tensorflow/core/util/example_proto_fast_parsing.h @@ -59,6 +59,26 @@ struct FastParseExampleConfig { std::vector<Dense> dense; std::vector<Sparse> sparse; + + // If `true`, `Result::feature_stats` will contain one + // `PerExampleFeatureStats` for each serialized example in the input. + bool collect_feature_stats = false; +}; + +// Statistics about the features in each example passed to +// `FastParse[Single]Example()`. +// +// TODO(b/111553342): The gathered statistics currently have two limitations: +// * Feature names that appear more than once will be counted multiple times. +// * The feature values count only represents the counts for features that were +// requested in the `FastParseExampleConfig`. +// These could be addressed with additional work at runtime. +struct PerExampleFeatureStats { + // The number of feature names in an example. + size_t features_count = 0; + + // The sum of the number of values in each feature that is parsed. + size_t feature_values_count = 0; }; // This is exactly the output of TF's ParseExample Op. @@ -68,6 +88,10 @@ struct Result { std::vector<Tensor> sparse_values; std::vector<Tensor> sparse_shapes; std::vector<Tensor> dense_values; + + // This vector will be populated with one element per example if + // `FastParseExampleConfig::collect_feature_stats` is set to `true`. + std::vector<PerExampleFeatureStats> feature_stats; }; // Parses a batch of serialized Example protos and converts them into result diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc index 1a804e154c..37faa927bf 100644 --- a/tensorflow/core/util/example_proto_fast_parsing_test.cc +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <utility> + #include "tensorflow/core/util/example_proto_fast_parsing.h" #include "tensorflow/core/example/example.pb.h" @@ -211,7 +213,7 @@ TEST(FastParse, SingleInt64) { TestCorrectness(Serialize(example)); } -TEST(FastParse, SomeFeatures) { +static string ExampleWithSomeFeatures() { Example example; (*example.mutable_features()->mutable_feature())[""]; @@ -242,7 +244,81 @@ TEST(FastParse, SomeFeatures) { int64_list->add_value(270); int64_list->add_value(86942); - TestCorrectness(Serialize(example)); + return Serialize(example); +} + +TEST(FastParse, SomeFeatures) { TestCorrectness(ExampleWithSomeFeatures()); } + +static void AddDenseFeature(const char* feature_name, DataType dtype, + PartialTensorShape shape, bool variable_length, + size_t elements_per_stride, + FastParseExampleConfig* out_config) { + out_config->dense.emplace_back(); + auto& new_feature = out_config->dense.back(); + new_feature.feature_name = feature_name; + new_feature.dtype = dtype; + new_feature.shape = std::move(shape); + new_feature.default_value = Tensor(dtype, {}); + new_feature.variable_length = variable_length; + new_feature.elements_per_stride = elements_per_stride; +} + +static void AddSparseFeature(const char* feature_name, DataType dtype, + FastParseExampleConfig* out_config) { + out_config->sparse.emplace_back(); + auto& new_feature = out_config->sparse.back(); + new_feature.feature_name = feature_name; + new_feature.dtype = dtype; +} + +TEST(FastParse, StatsCollection) { + const size_t kNumExamples = 13; + std::vector<string> serialized(kNumExamples, ExampleWithSomeFeatures()); + + FastParseExampleConfig config_dense; + AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense); + AddDenseFeature("float_list", DT_FLOAT, {2}, false, 2, &config_dense); + AddDenseFeature("int64_list", DT_INT64, {3}, false, 3, &config_dense); + config_dense.collect_feature_stats = true; + + FastParseExampleConfig config_varlen; + AddDenseFeature("bytes_list", DT_STRING, {-1}, true, 1, &config_varlen); + AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_varlen); + AddDenseFeature("int64_list", DT_INT64, {-1}, true, 1, &config_varlen); + config_varlen.collect_feature_stats = true; + + FastParseExampleConfig config_sparse; + AddSparseFeature("bytes_list", DT_STRING, &config_sparse); + AddSparseFeature("float_list", DT_FLOAT, &config_sparse); + AddSparseFeature("int64_list", DT_INT64, &config_sparse); + config_sparse.collect_feature_stats = true; + + FastParseExampleConfig config_mixed; + AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_mixed); + AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_mixed); + AddSparseFeature("int64_list", DT_INT64, &config_mixed); + config_mixed.collect_feature_stats = true; + + for (const FastParseExampleConfig& config : + {config_dense, config_varlen, config_sparse, config_mixed}) { + { + Result result; + TF_CHECK_OK(FastParseExample(config, serialized, {}, nullptr, &result)); + EXPECT_EQ(kNumExamples, result.feature_stats.size()); + for (const PerExampleFeatureStats& stats : result.feature_stats) { + EXPECT_EQ(7, stats.features_count); + EXPECT_EQ(7, stats.feature_values_count); + } + } + + { + Result result; + TF_CHECK_OK(FastParseSingleExample(config, serialized[0], &result)); + EXPECT_EQ(1, result.feature_stats.size()); + EXPECT_EQ(7, result.feature_stats[0].features_count); + EXPECT_EQ(7, result.feature_stats[0].feature_values_count); + } + } } string RandStr(random::SimplePhilox* rng) { |