From 865684c15d572adcfec12c3fbb2236c0e0a49c2d Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Wed, 22 Aug 2018 15:27:46 -0700 Subject: [data-stats] Supports stats aggregation for ParseExampleDataset. PiperOrigin-RevId: 209840812 --- .../python/kernel_tests/stats_dataset_ops_test.py | 5 ++-- tensorflow/contrib/data/python/ops/readers.py | 2 -- .../core/kernels/data/parse_example_dataset_op.cc | 34 +++++++++++++++++++--- tensorflow/core/ops/dataset_ops.cc | 1 - 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index a41d21f8c1..53c22628c7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -190,7 +190,7 @@ class FeatureStatsDatasetTest( batch_size=batch_size, shuffle=True, shuffle_seed=5, - drop_final_batch=True).apply( + drop_final_batch=False).apply( stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() @@ -198,7 +198,8 @@ class FeatureStatsDatasetTest( with self.test_session() as sess: sess.run(iterator.initializer) - for _ in range(total_records // batch_size): + for _ in range(total_records // batch_size + 1 if total_records % + batch_size else total_records // batch_size): sess.run(next_element) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 151f12b082..cafe0a4091 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -778,8 +778,6 @@ def make_batched_features_dataset(file_pattern, dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - dataset = dataset.apply(stats_ops.feature_stats("record_stats")) - # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to # improve the shape inference, because it makes the batch dimension static. # It is safe to do this because in that case we are repeating the input diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 1ab2af3e92..cc5007ee92 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/util/example_proto_fast_parsing.h" @@ -188,7 +189,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { auto map_fn = [this](IteratorContext* ctx, std::vector input_element, std::vector* result, StatusCallback done) { - (*ctx->runner())([this, input_element, result, done]() { + (*ctx->runner())([this, ctx, input_element, result, done]() { std::vector slice_vec; for (Tensor t : input_element) { auto serialized_t = t.flat(); @@ -197,10 +198,15 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { for (auto it = slice.begin(); it != slice.end(); it++) slice_vec.push_back(*it); } + example::FastParseExampleConfig config = config_; + // local copy of config_ for modification. + auto stats_aggregator = ctx->stats_aggregator(); + if (stats_aggregator) { + config.collect_feature_stats = true; + } example::Result example_result; - // TODO(b/111553342): Add stats collection logic here. - Status s = FastParseExample(config_, slice_vec, {}, - device_threadpool_, &example_result); + Status s = FastParseExample(config, slice_vec, {}, device_threadpool_, + &example_result); if (s.ok()) { (*result).resize(key_to_output_index_.size()); for (int d = 0; d < dense_keys_.size(); ++d) { @@ -241,6 +247,26 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { << serialized_sparse.shape().DebugString() << ")."; (*result)[output_index] = serialized_sparse; } + // TODO(b/111553342): User provided tags instead of fixed tag. + if (stats_aggregator) { + stats_aggregator->IncrementCounter( + "examples_count", "trainer", + example_result.feature_stats.size()); + for (example::PerExampleFeatureStats feature_stats : + example_result.feature_stats) { + stats_aggregator->AddToHistogram( + strings::StrCat("record_stats", ":features"), + {static_cast(feature_stats.features_count)}); + stats_aggregator->IncrementCounter( + "features_count", "trainer", feature_stats.features_count); + stats_aggregator->IncrementCounter( + "feature_values_count", "trainer", + feature_stats.feature_values_count); + stats_aggregator->AddToHistogram( + strings::StrCat("record_stats", ":feature-values"), + {static_cast(feature_stats.feature_values_count)}); + } + } } done(s); }); diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 07e735c7cb..41f5f9aebe 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -170,7 +170,6 @@ REGISTER_OP("ParseExampleDataset") .Input("input_dataset: variant") .Input("num_parallel_calls: int64") .Input("dense_defaults: Tdense") - .Output("handle: variant") .Attr("sparse_keys: list(string) >= 0") .Attr("dense_keys: list(string) >= 0") -- cgit v1.2.3