diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-10-03 12:44:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 12:51:54 -0700 |
commit | 808b1dcb318b1feb5a8c9fed5558f95cd05728e4 (patch) | |
tree | a2286241b6a0c8cba24b1da629fa6e7db475d7dc /tensorflow/core/kernels | |
parent | 19833284cc8fa555115aacde350ad66652b250dc (diff) |
[data-stats] Sets user given `tag` and `counter_prefix` with `set_stats_aggregator`. `tag` would get prep-end with all the statistics recorded as summary and `counter_prefix` would set the prefix for the statistics recorded as counter.
Note: `counter` defaults to `\tensorflow`, and `tag` and `prefix` gets associated with the dataset (not the stats_aggregator).
PiperOrigin-RevId: 215609159
Diffstat (limited to 'tensorflow/core/kernels')
5 files changed, 81 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 6333853cdf..451f8c1a6c 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -458,6 +458,7 @@ tf_kernel_library( srcs = ["stats_aggregator_dataset_op.cc"], deps = [ ":dataset", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib_internal", ], diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc index c80493d3a1..8d561ca0e3 100644 --- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -191,7 +191,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { params.runner = [pool](std::function<void()> c) { pool->Schedule(std::move(c)); }; - params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + params.stats_aggregator = ctx->stats_aggregator(); params.lib = ctx->lib(); params.function_library = ctx->function_library(); params.allocator_getter = ctx->allocator_getter(); diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index c28c06da62..1d1a717062 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -253,7 +253,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { for (example::PerExampleFeatureStats feature_stats : example_result.feature_stats) { stats_aggregator->AddToHistogram( - strings::StrCat("record_stats", ":features"), + "features", {static_cast<double>(feature_stats.features_count)}); stats_aggregator->IncrementCounter( "features_count", "trainer", feature_stats.features_count); @@ -261,7 +261,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { "feature_values_count", "trainer", feature_stats.feature_values_count); stats_aggregator->AddToHistogram( - strings::StrCat("record_stats", ":feature-values"), + "feature-values", {static_cast<double>(feature_stats.feature_values_count)}); } } diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index c8abfb9eb5..c09a73fff1 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <memory> #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" @@ -22,6 +24,52 @@ namespace tensorflow { namespace data { namespace { +class StatsAggregatorWithTagAndPrefix : public StatsAggregator { + public: + StatsAggregatorWithTagAndPrefix( + std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag, + const string& prefix) + : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {} + + void AddToHistogram(const string& name, + gtl::ArraySlice<double> values) override { + if (!tag_.empty()) { + wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values); + } else { + wrapped_->AddToHistogram(name, values); + } + } + + void AddScalar(const string& name, float value) override { + if (!tag_.empty()) { + wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value); + } else { + wrapped_->AddScalar(name, value); + } + } + + void EncodeToProto(Summary* out_summary) override { + wrapped_->EncodeToProto(out_summary); + } + + void IncrementCounter(const string& name, const string& label, + int64 val) override { + if (!prefix_.empty()) { + wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label, + val); + } else { + wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label, + val); + } + } + + private: + std::shared_ptr<StatsAggregator> wrapped_; + string tag_; + string prefix_; + TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix); +}; + class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx) @@ -33,8 +81,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &stats_aggregator_resource)); core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); + string tag; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix)); - *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource); + *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource, + tag, prefix); } private: @@ -42,11 +95,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, const Tensor& resource_handle, - StatsAggregatorResource* stats_aggregator_resource) + StatsAggregatorResource* stats_aggregator_resource, + const string& tag, const string& prefix) : DatasetBase(DatasetContext(ctx)), input_(input), resource_handle_(resource_handle), - stats_aggregator_resource_(stats_aggregator_resource) { + stats_aggregator_resource_(stats_aggregator_resource), + tag_(tag), + prefix_(prefix) { input_->Ref(); stats_aggregator_resource_->Ref(); } @@ -81,8 +137,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* resource_handle_node = nullptr; TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); + Node* tag_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); + Node* prefix_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node)); TF_RETURN_IF_ERROR(b->AddDataset( - this, {input_graph_node, resource_handle_node}, output)); + this, {input_graph_node, resource_handle_node, tag_node, prefix_node}, + output)); return Status::OK(); } @@ -105,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { IteratorContext::Params params; params.env = ctx->env(); params.runner = *(ctx->runner()); - params.stats_aggregator_getter = [stats_aggregator_resource]() { - return stats_aggregator_resource->stats_aggregator(); - }; + params.stats_aggregator = std::shared_ptr<StatsAggregator>( + new StatsAggregatorWithTagAndPrefix( + stats_aggregator_resource->stats_aggregator(), dataset()->tag_, + dataset()->prefix_)); params.lib = ctx->lib(); params.function_library = ctx->function_library(); params.allocator_getter = ctx->allocator_getter(); @@ -136,6 +198,8 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { const DatasetBase* const input_; const Tensor resource_handle_; StatsAggregatorResource* stats_aggregator_resource_; + string tag_; + string prefix_; }; }; diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index a7ded67876..2d51467616 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -82,11 +82,12 @@ class StatsAggregatorImpl : public StatsAggregator { auto counters_map = get_counters_map(); if (counters_map->find(name) == counters_map->end()) { counters_map->emplace( - name, monitoring::Counter<1>::New( - /*streamz name*/ "/tensorflow/" + name, - /*streamz description*/ - name + " generated or consumed by the component.", - /*streamz label name*/ "component_descriptor")); + name, + monitoring::Counter<1>::New( + /*streamz name*/ name, + /*streamz description*/ + strings::StrCat(name, " generated or consumed by the component."), + /*streamz label name*/ "component_descriptor")); } counters_map->at(name)->GetCell(label)->IncrementBy(val); } |