aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-08-22 15:27:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 15:53:44 -0700
commit865684c15d572adcfec12c3fbb2236c0e0a49c2d (patch)
tree9fb73a03986a1ee8d49eb2199419e1beec1b6ad1
parent3facf91c0468bec1bb8151dd42816c2827a31e6d (diff)
[data-stats] Supports stats aggregation for ParseExampleDataset.
PiperOrigin-RevId: 209840812
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py5
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py2
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc34
-rw-r--r--tensorflow/core/ops/dataset_ops.cc1
4 files changed, 33 insertions, 9 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index a41d21f8c1..53c22628c7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -190,7 +190,7 @@ class FeatureStatsDatasetTest(
batch_size=batch_size,
shuffle=True,
shuffle_seed=5,
- drop_final_batch=True).apply(
+ drop_final_batch=False).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
@@ -198,7 +198,8 @@ class FeatureStatsDatasetTest(
with self.test_session() as sess:
sess.run(iterator.initializer)
- for _ in range(total_records // batch_size):
+ for _ in range(total_records // batch_size + 1 if total_records %
+ batch_size else total_records // batch_size):
sess.run(next_element)
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 151f12b082..cafe0a4091 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -778,8 +778,6 @@ def make_batched_features_dataset(file_pattern,
dataset = _maybe_shuffle_and_repeat(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
- dataset = dataset.apply(stats_ops.feature_stats("record_stats"))
-
# NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
# improve the shape inference, because it makes the batch dimension static.
# It is safe to do this because in that case we are repeating the input
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 1ab2af3e92..cc5007ee92 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <deque>
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/util/example_proto_fast_parsing.h"
@@ -188,7 +189,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
auto map_fn = [this](IteratorContext* ctx,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())([this, input_element, result, done]() {
+ (*ctx->runner())([this, ctx, input_element, result, done]() {
std::vector<string> slice_vec;
for (Tensor t : input_element) {
auto serialized_t = t.flat<string>();
@@ -197,10 +198,15 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
for (auto it = slice.begin(); it != slice.end(); it++)
slice_vec.push_back(*it);
}
+ example::FastParseExampleConfig config = config_;
+ // Local copy of config_ so it can be modified per invocation.
+ auto stats_aggregator = ctx->stats_aggregator();
+ if (stats_aggregator) {
+ config.collect_feature_stats = true;
+ }
example::Result example_result;
- // TODO(b/111553342): Add stats collection logic here.
- Status s = FastParseExample(config_, slice_vec, {},
- device_threadpool_, &example_result);
+ Status s = FastParseExample(config, slice_vec, {}, device_threadpool_,
+ &example_result);
if (s.ok()) {
(*result).resize(key_to_output_index_.size());
for (int d = 0; d < dense_keys_.size(); ++d) {
@@ -241,6 +247,26 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
<< serialized_sparse.shape().DebugString() << ").";
(*result)[output_index] = serialized_sparse;
}
+ // TODO(b/111553342): Use user-provided tags instead of a fixed tag.
+ if (stats_aggregator) {
+ stats_aggregator->IncrementCounter(
+ "examples_count", "trainer",
+ example_result.feature_stats.size());
+ for (example::PerExampleFeatureStats feature_stats :
+ example_result.feature_stats) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat("record_stats", ":features"),
+ {static_cast<double>(feature_stats.features_count)});
+ stats_aggregator->IncrementCounter(
+ "features_count", "trainer", feature_stats.features_count);
+ stats_aggregator->IncrementCounter(
+ "feature_values_count", "trainer",
+ feature_stats.feature_values_count);
+ stats_aggregator->AddToHistogram(
+ strings::StrCat("record_stats", ":feature-values"),
+ {static_cast<double>(feature_stats.feature_values_count)});
+ }
+ }
}
done(s);
});
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 07e735c7cb..41f5f9aebe 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -170,7 +170,6 @@ REGISTER_OP("ParseExampleDataset")
.Input("input_dataset: variant")
.Input("num_parallel_calls: int64")
.Input("dense_defaults: Tdense")
-
.Output("handle: variant")
.Attr("sparse_keys: list(string) >= 0")
.Attr("dense_keys: list(string) >= 0")