aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-08-07 10:26:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 10:37:21 -0700
commitb8886649c75ae864f2532bca044e2f44fb138c95 (patch)
treed053e1c2a91a3125a7ebb4f6d084a4be713febde /tensorflow/core/util
parent90bf05c0d147a7e0c6e48720e17e51233b2bcd3c (diff)
[tf.data] Add feature statistics collection hooks to the tf.Example parsers.
PiperOrigin-RevId: 207737913
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.cc89
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.h24
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing_test.cc80
3 files changed, 187 insertions, 6 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index 418e97ac24..1fec0010a1 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -495,7 +495,8 @@ Status FastParseSerializedExample(
const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
SeededHasher hasher, std::vector<Tensor>* output_dense,
std::vector<SparseBuffer>* output_varlen_dense,
- std::vector<SparseBuffer>* output_sparse) {
+ std::vector<SparseBuffer>* output_sparse,
+ PerExampleFeatureStats* output_stats) {
DCHECK(output_dense != nullptr);
DCHECK(output_sparse != nullptr);
parsed::Example parsed_example;
@@ -508,6 +509,14 @@ Status FastParseSerializedExample(
// Handle features present in the example.
const size_t parsed_example_size = parsed_example.size();
+
+ if (output_stats) {
+ // TODO(b/111553342): This may over-count the number of features if there
+ // are duplicate keys in the feature map. Consider deduplicating the keys
+ // before computing the count.
+ output_stats->features_count = parsed_example_size;
+ }
+
for (size_t i = 0; i < parsed_example_size; ++i) {
// This is a logic that standard protobuf parsing is implementing.
// I.e. last entry in the map overwrites all the previous ones.
@@ -567,6 +576,13 @@ Status FastParseSerializedExample(
Tensor& out = (*output_dense)[d];
const std::size_t num_elements = config.dense[d].elements_per_stride;
+ if (output_stats) {
+ // TODO(b/111553342): If desirable, we could add support for counting
+ // elements in the features that aren't parsed, but this could add
+ // considerable runtime cost.
+ output_stats->feature_values_count += num_elements;
+ }
+
const std::size_t offset = example_index * num_elements;
auto shape_error = [&](size_t size, StringPiece type_str) {
@@ -669,6 +685,23 @@ Status FastParseSerializedExample(
default:
LOG(FATAL) << "Should not happen.";
}
+
+ if (output_stats) {
+ // Use `out.example_end_indices` to determine the feature-value count
+ // for this feature, because the preceding switch statement pushes
+ // the length of the appropriate feature list to that vector.
+ // TODO(b/111553342): If desirable, we could add support for counting
+ // elements in the features that aren't parsed, but this could add
+ // considerable runtime cost.
+ const size_t out_examples_count = out.example_end_indices.size();
+ if (out_examples_count == 1) {
+ output_stats->feature_values_count += out.example_end_indices[0];
+ } else {
+ output_stats->feature_values_count +=
+ out.example_end_indices[out_examples_count - 1] -
+ out.example_end_indices[out_examples_count - 2];
+ }
+ }
}
} else {
// If feature was already visited, skip.
@@ -720,6 +753,23 @@ Status FastParseSerializedExample(
default:
LOG(FATAL) << "Should not happen.";
}
+
+ if (output_stats) {
+ // Use `out.example_end_indices` to determine the feature-value count
+ // for this feature, because the preceding switch statement pushes
+ // the length of the appropriate feature list to that vector.
+ // TODO(b/111553342): If desirable, we could add support for counting
+ // elements in the features that aren't parsed, but this could add
+ // considerable runtime cost.
+ const size_t out_examples_count = out.example_end_indices.size();
+ if (out_examples_count == 1) {
+ output_stats->feature_values_count += out.example_end_indices[0];
+ } else {
+ output_stats->feature_values_count +=
+ out.example_end_indices[out_examples_count - 1] -
+ out.example_end_indices[out_examples_count - 2];
+ }
+ }
}
}
@@ -877,6 +927,10 @@ Status FastParseExample(const Config& config,
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
}
+ if (config.collect_feature_stats) {
+ result->feature_stats.resize(serialized.size());
+ }
+
size_t config_size = config.dense.size() + config.sparse.size();
SeededHasher hasher;
// Build config index.
@@ -962,11 +1016,15 @@ Status FastParseExample(const Config& config,
size_t start = first_example_of_minibatch(minibatch);
size_t end = first_example_of_minibatch(minibatch + 1);
for (size_t e = start; e < end; ++e) {
+ PerExampleFeatureStats* stats = nullptr;
+ if (config.collect_feature_stats) {
+ stats = &result->feature_stats[e];
+ }
status_of_minibatch[minibatch] = FastParseSerializedExample(
serialized[e],
(!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
config_index, hasher, &fixed_dense_values,
- &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch]);
+ &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch], stats);
if (!status_of_minibatch[minibatch].ok()) break;
}
};
@@ -1079,7 +1137,7 @@ Status FastParseExample(const Config& config,
const size_t stride_size = config.dense[d].elements_per_stride;
const size_t max_num_elements = max_num_features / stride_size;
TensorShape values_shape;
- DCHECK(max_num_features % config.dense[d].elements_per_stride == 0);
+ DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
const size_t batch_size = serialized.size();
values_shape.AddDim(batch_size);
values_shape.AddDim(max_num_elements);
@@ -1138,6 +1196,12 @@ Status FastParseSingleExample(const Config& config, const string& serialized,
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
}
+ PerExampleFeatureStats* stats = nullptr;
+ if (config.collect_feature_stats) {
+ result->feature_stats.emplace_back();
+ stats = &result->feature_stats.back();
+ }
+
// TODO(mrry): Cache the construction of this map at Op construction time.
size_t config_size = config.dense.size() + config.sparse.size();
SeededHasher hasher;
@@ -1196,6 +1260,13 @@ Status FastParseSingleExample(const Config& config, const string& serialized,
std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
+ if (stats) {
+ // TODO(b/111553342): This may over-count the number of features if there
+ // are duplicate keys in the feature map. Consider deduplicating the keys
+ // before computing the count.
+ stats->features_count = parsed_example.size();
+ }
+
// Handle features present in the example.
const size_t parsed_example_size = parsed_example.size();
for (size_t i = 0; i < parsed_example_size; ++i) {
@@ -1254,7 +1325,12 @@ Status FastParseSingleExample(const Config& config, const string& serialized,
Tensor* out = &result->dense_values[d];
const std::size_t num_elements = config.dense[d].elements_per_stride;
-
+ if (stats) {
+ // TODO(b/111553342): If desirable, we could add support for counting
+ // elements in the features that aren't parsed, but this could add
+ // considerable runtime cost.
+ stats->feature_values_count += num_elements;
+ }
switch (example_dtype) {
case DT_INT64: {
auto out_p = out->flat<int64>().data();
@@ -1362,6 +1438,10 @@ Status FastParseSingleExample(const Config& config, const string& serialized,
return parse_error();
}
+ if (stats) {
+ stats->feature_values_count += num_elements;
+ }
+
Tensor* out;
if (is_dense) {
TensorShape values_shape;
@@ -1636,6 +1716,7 @@ inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
}
// TODO(sundberg): Use the threadpool to parallelize example parsing.
+// TODO(b/111553342): Support extracting feature statistics from the examples.
Status FastParseSequenceExample(
const FastParseExampleConfig& context_config,
const FastParseExampleConfig& feature_list_config,
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index 024a4518ee..db5b5ff929 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -59,6 +59,26 @@ struct FastParseExampleConfig {
std::vector<Dense> dense;
std::vector<Sparse> sparse;
+
+ // If `true`, `Result::feature_stats` will contain one
+ // `PerExampleFeatureStats` for each serialized example in the input.
+ bool collect_feature_stats = false;
+};
+
+// Statistics about the features in each example passed to
+// `FastParse[Single]Example()`.
+//
+// TODO(b/111553342): The gathered statistics currently have two limitations:
+// * Feature names that appear more than once will be counted multiple times.
+// * The feature values count only represents the counts for features that were
+// requested in the `FastParseExampleConfig`.
+// These could be addressed with additional work at runtime.
+struct PerExampleFeatureStats {
+ // The number of feature names in an example.
+ size_t features_count = 0;
+
+ // The sum of the number of values in each feature that is parsed.
+ size_t feature_values_count = 0;
};
// This is exactly the output of TF's ParseExample Op.
@@ -68,6 +88,10 @@ struct Result {
std::vector<Tensor> sparse_values;
std::vector<Tensor> sparse_shapes;
std::vector<Tensor> dense_values;
+
+ // This vector will be populated with one element per example if
+ // `FastParseExampleConfig::collect_feature_stats` is set to `true`.
+ std::vector<PerExampleFeatureStats> feature_stats;
};
// Parses a batch of serialized Example protos and converts them into result
diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc
index 1a804e154c..37faa927bf 100644
--- a/tensorflow/core/util/example_proto_fast_parsing_test.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <utility>
+
#include "tensorflow/core/util/example_proto_fast_parsing.h"
#include "tensorflow/core/example/example.pb.h"
@@ -211,7 +213,7 @@ TEST(FastParse, SingleInt64) {
TestCorrectness(Serialize(example));
}
-TEST(FastParse, SomeFeatures) {
+static string ExampleWithSomeFeatures() {
Example example;
(*example.mutable_features()->mutable_feature())[""];
@@ -242,7 +244,81 @@ TEST(FastParse, SomeFeatures) {
int64_list->add_value(270);
int64_list->add_value(86942);
- TestCorrectness(Serialize(example));
+ return Serialize(example);
+}
+
+TEST(FastParse, SomeFeatures) { TestCorrectness(ExampleWithSomeFeatures()); }
+
+static void AddDenseFeature(const char* feature_name, DataType dtype,
+ PartialTensorShape shape, bool variable_length,
+ size_t elements_per_stride,
+ FastParseExampleConfig* out_config) {
+ out_config->dense.emplace_back();
+ auto& new_feature = out_config->dense.back();
+ new_feature.feature_name = feature_name;
+ new_feature.dtype = dtype;
+ new_feature.shape = std::move(shape);
+ new_feature.default_value = Tensor(dtype, {});
+ new_feature.variable_length = variable_length;
+ new_feature.elements_per_stride = elements_per_stride;
+}
+
+static void AddSparseFeature(const char* feature_name, DataType dtype,
+ FastParseExampleConfig* out_config) {
+ out_config->sparse.emplace_back();
+ auto& new_feature = out_config->sparse.back();
+ new_feature.feature_name = feature_name;
+ new_feature.dtype = dtype;
+}
+
+TEST(FastParse, StatsCollection) {
+ const size_t kNumExamples = 13;
+ std::vector<string> serialized(kNumExamples, ExampleWithSomeFeatures());
+
+ FastParseExampleConfig config_dense;
+ AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense);
+ AddDenseFeature("float_list", DT_FLOAT, {2}, false, 2, &config_dense);
+ AddDenseFeature("int64_list", DT_INT64, {3}, false, 3, &config_dense);
+ config_dense.collect_feature_stats = true;
+
+ FastParseExampleConfig config_varlen;
+ AddDenseFeature("bytes_list", DT_STRING, {-1}, true, 1, &config_varlen);
+ AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_varlen);
+ AddDenseFeature("int64_list", DT_INT64, {-1}, true, 1, &config_varlen);
+ config_varlen.collect_feature_stats = true;
+
+ FastParseExampleConfig config_sparse;
+ AddSparseFeature("bytes_list", DT_STRING, &config_sparse);
+ AddSparseFeature("float_list", DT_FLOAT, &config_sparse);
+ AddSparseFeature("int64_list", DT_INT64, &config_sparse);
+ config_sparse.collect_feature_stats = true;
+
+ FastParseExampleConfig config_mixed;
+ AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_mixed);
+ AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_mixed);
+ AddSparseFeature("int64_list", DT_INT64, &config_mixed);
+ config_mixed.collect_feature_stats = true;
+
+ for (const FastParseExampleConfig& config :
+ {config_dense, config_varlen, config_sparse, config_mixed}) {
+ {
+ Result result;
+ TF_CHECK_OK(FastParseExample(config, serialized, {}, nullptr, &result));
+ EXPECT_EQ(kNumExamples, result.feature_stats.size());
+ for (const PerExampleFeatureStats& stats : result.feature_stats) {
+ EXPECT_EQ(7, stats.features_count);
+ EXPECT_EQ(7, stats.feature_values_count);
+ }
+ }
+
+ {
+ Result result;
+ TF_CHECK_OK(FastParseSingleExample(config, serialized[0], &result));
+ EXPECT_EQ(1, result.feature_stats.size());
+ EXPECT_EQ(7, result.feature_stats[0].features_count);
+ EXPECT_EQ(7, result.feature_stats[0].feature_values_count);
+ }
+ }
}
string RandStr(random::SimplePhilox* rng) {