diff options
author | 2018-10-03 12:44:47 -0700 | |
---|---|---|
committer | 2018-10-03 12:51:54 -0700 | |
commit | 808b1dcb318b1feb5a8c9fed5558f95cd05728e4 (patch) | |
tree | a2286241b6a0c8cba24b1da629fa6e7db475d7dc /tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | |
parent | 19833284cc8fa555115aacde350ad66652b250dc (diff) |
[data-stats] Sets user given `tag` and `counter_prefix` with `set_stats_aggregator`. `tag` would get prep-end with all the statistics recorded as summary and `counter_prefix` would set the prefix for the statistics recorded as counter.
Note: `counter` defaults to `\tensorflow`, and `tag` and `prefix` gets associated with the dataset (not the stats_aggregator).
PiperOrigin-RevId: 215609159
Diffstat (limited to 'tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 78 |
1 files changed, 71 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index c8abfb9eb5..c09a73fff1 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <memory> #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" @@ -22,6 +24,52 @@ namespace tensorflow { namespace data { namespace { +class StatsAggregatorWithTagAndPrefix : public StatsAggregator { + public: + StatsAggregatorWithTagAndPrefix( + std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag, + const string& prefix) + : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {} + + void AddToHistogram(const string& name, + gtl::ArraySlice<double> values) override { + if (!tag_.empty()) { + wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values); + } else { + wrapped_->AddToHistogram(name, values); + } + } + + void AddScalar(const string& name, float value) override { + if (!tag_.empty()) { + wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value); + } else { + wrapped_->AddScalar(name, value); + } + } + + void EncodeToProto(Summary* out_summary) override { + wrapped_->EncodeToProto(out_summary); + } + + void IncrementCounter(const string& name, const string& label, + int64 val) override { + if (!prefix_.empty()) { + wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label, + val); + } else { + wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label, + val); + } + } + + private: + std::shared_ptr<StatsAggregator> wrapped_; + string tag_; + string prefix_; + TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix); +}; + class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx) @@ -33,8 +81,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &stats_aggregator_resource)); core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); + string tag; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix)); - *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource); + *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource, + tag, prefix); } private: @@ -42,11 +95,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, const Tensor& resource_handle, - StatsAggregatorResource* stats_aggregator_resource) + StatsAggregatorResource* stats_aggregator_resource, + const string& tag, const string& prefix) : DatasetBase(DatasetContext(ctx)), input_(input), resource_handle_(resource_handle), - stats_aggregator_resource_(stats_aggregator_resource) { + stats_aggregator_resource_(stats_aggregator_resource), + tag_(tag), + prefix_(prefix) { input_->Ref(); stats_aggregator_resource_->Ref(); } @@ -81,8 +137,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* resource_handle_node = nullptr; TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); + Node* tag_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); + Node* prefix_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node)); TF_RETURN_IF_ERROR(b->AddDataset( - this, {input_graph_node, resource_handle_node}, output)); + this, {input_graph_node, resource_handle_node, tag_node, prefix_node}, + output)); return Status::OK(); } @@ -105,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { IteratorContext::Params params; params.env = ctx->env(); params.runner = *(ctx->runner()); - params.stats_aggregator_getter = [stats_aggregator_resource]() { - return stats_aggregator_resource->stats_aggregator(); - }; + params.stats_aggregator = std::shared_ptr<StatsAggregator>( + new StatsAggregatorWithTagAndPrefix( + stats_aggregator_resource->stats_aggregator(), dataset()->tag_, + dataset()->prefix_)); params.lib = ctx->lib(); params.function_library = ctx->function_library(); params.allocator_getter = ctx->allocator_getter(); @@ -136,6 +198,8 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { const DatasetBase* const input_; const Tensor resource_handle_; StatsAggregatorResource* stats_aggregator_resource_; + string tag_; + string prefix_; }; }; |