diff options
12 files changed, 179 insertions, 39 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 8c1151cb56..964a7d5f8c 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -278,15 +278,8 @@ class IteratorContext { // Function call support. std::function<void(std::function<void()>)> runner = nullptr; - // A function that returns the current `StatsAggregator` instance to be - // used when recording statistics about the iterator. - // - // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator` - // is a property of the `IteratorResource` (which this class does not know - // about), and (ii) it can change after the `IteratorContext` has been - // created. Better suggestions are welcome! - std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter = - nullptr; + // The `StatsAggregator` object to record statistics about the iterator. + std::shared_ptr<StatsAggregator> stats_aggregator = nullptr; // The FunctionLibraryRuntime object to be used to make function calls. 
FunctionLibraryRuntime* lib = nullptr; @@ -320,13 +313,6 @@ class IteratorContext { return ¶ms_.runner; } - std::shared_ptr<StatsAggregator> stats_aggregator() { - if (params_.stats_aggregator_getter) { - return params_.stats_aggregator_getter(); - } else { - return nullptr; - } - } std::shared_ptr<const FunctionLibraryDefinition> function_library() { return params_.function_library; @@ -344,8 +330,8 @@ class IteratorContext { return params_.allocator_getter; } - std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter() { - return params_.stats_aggregator_getter; + std::shared_ptr<StatsAggregator> stats_aggregator() { + return params_.stats_aggregator; } std::shared_ptr<model::Model> model() { return params_.model; } diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 6333853cdf..451f8c1a6c 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -458,6 +458,7 @@ tf_kernel_library( srcs = ["stats_aggregator_dataset_op.cc"], deps = [ ":dataset", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib_internal", ], diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index c80493d3a1..8d561ca0e3 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -191,7 +191,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { params.runner = [pool](std::function<void()> c) { pool->Schedule(std::move(c)); }; - params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + params.stats_aggregator = ctx->stats_aggregator(); params.lib = ctx->lib(); params.function_library = ctx->function_library(); params.allocator_getter = ctx->allocator_getter(); diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc 
b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index c28c06da62..1d1a717062 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -253,7 +253,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { for (example::PerExampleFeatureStats feature_stats : example_result.feature_stats) { stats_aggregator->AddToHistogram( - strings::StrCat("record_stats", ":features"), + "features", {static_cast<double>(feature_stats.features_count)}); stats_aggregator->IncrementCounter( "features_count", "trainer", feature_stats.features_count); @@ -261,7 +261,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { "feature_values_count", "trainer", feature_stats.feature_values_count); stats_aggregator->AddToHistogram( - strings::StrCat("record_stats", ":feature-values"), + "feature-values", {static_cast<double>(feature_stats.feature_values_count)}); } } diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index c8abfb9eb5..c09a73fff1 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include <memory> #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" @@ -22,6 +24,52 @@ namespace tensorflow { namespace data { namespace { +class StatsAggregatorWithTagAndPrefix : public StatsAggregator { + public: + StatsAggregatorWithTagAndPrefix( + std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag, + const string& prefix) + : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {} + + void AddToHistogram(const string& name, + gtl::ArraySlice<double> values) override { + if (!tag_.empty()) { + wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values); + } else { + wrapped_->AddToHistogram(name, values); + } + } + + void AddScalar(const string& name, float value) override { + if (!tag_.empty()) { + wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value); + } else { + wrapped_->AddScalar(name, value); + } + } + + void EncodeToProto(Summary* out_summary) override { + wrapped_->EncodeToProto(out_summary); + } + + void IncrementCounter(const string& name, const string& label, + int64 val) override { + if (!prefix_.empty()) { + wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label, + val); + } else { + wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label, + val); + } + } + + private: + std::shared_ptr<StatsAggregator> wrapped_; + string tag_; + string prefix_; + TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix); +}; + class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx) @@ -33,8 +81,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { 
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &stats_aggregator_resource)); core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); + string tag; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix)); - *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource); + *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource, + tag, prefix); } private: @@ -42,11 +95,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, const Tensor& resource_handle, - StatsAggregatorResource* stats_aggregator_resource) + StatsAggregatorResource* stats_aggregator_resource, + const string& tag, const string& prefix) : DatasetBase(DatasetContext(ctx)), input_(input), resource_handle_(resource_handle), - stats_aggregator_resource_(stats_aggregator_resource) { + stats_aggregator_resource_(stats_aggregator_resource), + tag_(tag), + prefix_(prefix) { input_->Ref(); stats_aggregator_resource_->Ref(); } @@ -81,8 +137,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* resource_handle_node = nullptr; TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); + Node* tag_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); + Node* prefix_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node)); TF_RETURN_IF_ERROR(b->AddDataset( - this, {input_graph_node, resource_handle_node}, output)); + this, {input_graph_node, resource_handle_node, tag_node, prefix_node}, + output)); return Status::OK(); } @@ -105,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { IteratorContext::Params params; params.env = ctx->env(); params.runner = *(ctx->runner()); - params.stats_aggregator_getter 
= [stats_aggregator_resource]() { - return stats_aggregator_resource->stats_aggregator(); - }; + params.stats_aggregator = std::shared_ptr<StatsAggregator>( + new StatsAggregatorWithTagAndPrefix( + stats_aggregator_resource->stats_aggregator(), dataset()->tag_, + dataset()->prefix_)); params.lib = ctx->lib(); params.function_library = ctx->function_library(); params.allocator_getter = ctx->allocator_getter(); @@ -136,6 +198,8 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { const DatasetBase* const input_; const Tensor resource_handle_; StatsAggregatorResource* stats_aggregator_resource_; + string tag_; + string prefix_; }; }; diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index a7ded67876..2d51467616 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -82,11 +82,12 @@ class StatsAggregatorImpl : public StatsAggregator { auto counters_map = get_counters_map(); if (counters_map->find(name) == counters_map->end()) { counters_map->emplace( - name, monitoring::Counter<1>::New( - /*streamz name*/ "/tensorflow/" + name, - /*streamz description*/ - name + " generated or consumed by the component.", - /*streamz label name*/ "component_descriptor")); + name, + monitoring::Counter<1>::New( + /*streamz name*/ name, + /*streamz description*/ + strings::StrCat(name, " generated or consumed by the component."), + /*streamz label name*/ "component_descriptor")); } counters_map->at(name)->GetCell(label)->IncrementBy(val); } diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 4845767405..33f18ae13f 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -59785,6 +59785,14 @@ op { name: "stats_aggregator" type: DT_RESOURCE } + input_arg { + name: "tag" + type: DT_STRING + } + input_arg { + 
name: "counter_prefix" + type: DT_STRING + } output_arg { name: "handle" type: DT_VARIANT diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 71f4cc3c4c..889a6a4640 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -185,6 +185,8 @@ REGISTER_OP("ParseExampleDataset") REGISTER_OP("SetStatsAggregatorDataset") .Input("input_dataset: variant") .Input("stats_aggregator: resource") + .Input("tag: string") + .Input("counter_prefix: string") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py index 6761fbd16b..19f5a62d45 100644 --- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base from tensorflow.python.data.experimental.ops import stats_ops from tensorflow.python.data.ops import dataset_ops @@ -248,6 +249,74 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + def testMultipleDatasetWithTags(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator, "dataset1")) + dataset2 = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator, "dataset2")) + iterator_0 = 
dataset.make_initializable_iterator() + iterator_1 = dataset2.make_initializable_iterator() + next_element = iterator_0.get_next() + iterator_1.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator_0.initializer, iterator_1.initializer]) + for i in range(100): + self.assertEqual(i * 2, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "dataset1_record_latency", float(i + 1)) + self._assertSummaryHasCount( + sess.run(summary_t), "dataset2_record_latency", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "dataset1_record_latency", 100.0) + self._assertSummaryHasCount( + sess.run(summary_t), "dataset2_record_latency", 100.0) + + +class FeatureStatsDatasetTest( + stats_dataset_test_base.StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): + + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=False).apply( + stats_ops.set_stats_aggregator(stats_aggregator, "record_stats")) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(total_records // batch_size + 1 if total_records % + batch_size else total_records // batch_size): + sess.run(next_element) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats_features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats_feature-values", total_records) + 
self._assertSummaryHasSum( + sess.run(summary_t), "record_stats_features", total_records * 4) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats_feature-values", + self._sum_keywords(1) * num_epochs + 3 * total_records) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py index c918d223e8..54ef6fc3e8 100644 --- a/tensorflow/python/data/experimental/ops/stats_ops.py +++ b/tensorflow/python/data/experimental/ops/stats_ops.py @@ -89,15 +89,19 @@ class StatsAggregator(object): class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset): """A `Dataset` that acts as an identity, and sets given stats_aggregator.""" - def __init__(self, input_dataset, stats_aggregator): + def __init__(self, input_dataset, stats_aggregator, tag, prefix): super(_SetStatsAggregatorDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._stats_aggregator = stats_aggregator + self._tag = tag + self._prefix = prefix def _as_variant_tensor(self): return gen_dataset_ops.set_stats_aggregator_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._stats_aggregator._resource, # pylint: disable=protected-access + self._tag, + self._prefix, **dataset_ops.flat_structure(self)) @property @@ -114,11 +118,15 @@ class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset): @tf_export("data.experimental.set_stats_aggregator") -def set_stats_aggregator(stats_aggregator): +def set_stats_aggregator(stats_aggregator, tag="", counter_prefix=""): """Set the given `stats_aggregator` for aggregating the input dataset stats. Args: - stats_aggregator: A `tf.data.experimental.StatsAggregator` object. + stats_aggregator: A `tf.contrib.data.StatsAggregator` object. + tag: (Optional) String, all statistics recorded for the input `dataset` + will have the given `tag` prepended to the name. 
+ counter_prefix: (Optional) String, all statistics recorded as `counters` + will have the given `prefix` for the counter. Defaults to "/tensorflow". Returns: A `Dataset` transformation function, which can be passed to @@ -126,7 +134,8 @@ def set_stats_aggregator(stats_aggregator): """ def _apply_fn(dataset): - return _SetStatsAggregatorDataset(dataset, stats_aggregator) + return _SetStatsAggregatorDataset(dataset, stats_aggregator, tag, + counter_prefix) return _apply_fn diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt index b14585f8d7..2a1f899dc0 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt @@ -122,7 +122,7 @@ tf_module { } member_method { name: "set_stats_aggregator" - argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'stats_aggregator\', \'tag\', \'counter_prefix\'], varargs=None, keywords=None, defaults=[\'\', \'\'], " } member_method { name: "shuffle_and_repeat" } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt index b14585f8d7..2a1f899dc0 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt @@ -122,7 +122,7 @@ tf_module { } member_method { name: "set_stats_aggregator" - argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=[\'\', \'\'], " + argspec: "args=[\'stats_aggregator\', \'tag\', \'counter_prefix\'], varargs=None, keywords=None, defaults=[\'\', \'\'], " } member_method { name: "shuffle_and_repeat" |