aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-10-03 12:44:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 12:51:54 -0700
commit808b1dcb318b1feb5a8c9fed5558f95cd05728e4 (patch)
treea2286241b6a0c8cba24b1da629fa6e7db475d7dc /tensorflow/core/kernels
parent19833284cc8fa555115aacde350ad66652b250dc (diff)
[data-stats] Sets user given `tag` and `counter_prefix` with `set_stats_aggregator`. `tag` would get prep-end with all the statistics recorded as summary and `counter_prefix` would set the prefix for the statistics recorded as counter.
Note: `counter` defaults to `\tensorflow`, and `tag` and `prefix` gets associated with the dataset (not the stats_aggregator). PiperOrigin-RevId: 215609159
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/data/BUILD1
-rw-r--r--tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc78
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_ops.cc11
5 files changed, 81 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 6333853cdf..451f8c1a6c 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -458,6 +458,7 @@ tf_kernel_library(
srcs = ["stats_aggregator_dataset_op.cc"],
deps = [
":dataset",
+ "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
],
diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index c80493d3a1..8d561ca0e3 100644
--- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -191,7 +191,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
params.runner = [pool](std::function<void()> c) {
pool->Schedule(std::move(c));
};
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.stats_aggregator = ctx->stats_aggregator();
params.lib = ctx->lib();
params.function_library = ctx->function_library();
params.allocator_getter = ctx->allocator_getter();
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index c28c06da62..1d1a717062 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -253,7 +253,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
for (example::PerExampleFeatureStats feature_stats :
example_result.feature_stats) {
stats_aggregator->AddToHistogram(
- strings::StrCat("record_stats", ":features"),
+ "features",
{static_cast<double>(feature_stats.features_count)});
stats_aggregator->IncrementCounter(
"features_count", "trainer", feature_stats.features_count);
@@ -261,7 +261,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
"feature_values_count", "trainer",
feature_stats.feature_values_count);
stats_aggregator->AddToHistogram(
- strings::StrCat("record_stats", ":feature-values"),
+ "feature-values",
{static_cast<double>(feature_stats.feature_values_count)});
}
}
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index c8abfb9eb5..c09a73fff1 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <memory>
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
@@ -22,6 +24,52 @@ namespace tensorflow {
namespace data {
namespace {
+class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
+ public:
+ StatsAggregatorWithTagAndPrefix(
+ std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag,
+ const string& prefix)
+ : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {}
+
+ void AddToHistogram(const string& name,
+ gtl::ArraySlice<double> values) override {
+ if (!tag_.empty()) {
+ wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values);
+ } else {
+ wrapped_->AddToHistogram(name, values);
+ }
+ }
+
+ void AddScalar(const string& name, float value) override {
+ if (!tag_.empty()) {
+ wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value);
+ } else {
+ wrapped_->AddScalar(name, value);
+ }
+ }
+
+ void EncodeToProto(Summary* out_summary) override {
+ wrapped_->EncodeToProto(out_summary);
+ }
+
+ void IncrementCounter(const string& name, const string& label,
+ int64 val) override {
+ if (!prefix_.empty()) {
+ wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label,
+ val);
+ } else {
+ wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label,
+ val);
+ }
+ }
+
+ private:
+ std::shared_ptr<StatsAggregator> wrapped_;
+ string tag_;
+ string prefix_;
+ TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix);
+};
+
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx)
@@ -33,8 +81,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
+ string tag;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
- *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
+ tag, prefix);
}
private:
@@ -42,11 +95,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
const Tensor& resource_handle,
- StatsAggregatorResource* stats_aggregator_resource)
+ StatsAggregatorResource* stats_aggregator_resource,
+ const string& tag, const string& prefix)
: DatasetBase(DatasetContext(ctx)),
input_(input),
resource_handle_(resource_handle),
- stats_aggregator_resource_(stats_aggregator_resource) {
+ stats_aggregator_resource_(stats_aggregator_resource),
+ tag_(tag),
+ prefix_(prefix) {
input_->Ref();
stats_aggregator_resource_->Ref();
}
@@ -81,8 +137,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* resource_handle_node = nullptr;
TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ Node* tag_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+ Node* prefix_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node));
TF_RETURN_IF_ERROR(b->AddDataset(
- this, {input_graph_node, resource_handle_node}, output));
+ this, {input_graph_node, resource_handle_node, tag_node, prefix_node},
+ output));
return Status::OK();
}
@@ -105,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
- params.stats_aggregator_getter = [stats_aggregator_resource]() {
- return stats_aggregator_resource->stats_aggregator();
- };
+ params.stats_aggregator = std::shared_ptr<StatsAggregator>(
+ new StatsAggregatorWithTagAndPrefix(
+ stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
+ dataset()->prefix_));
params.lib = ctx->lib();
params.function_library = ctx->function_library();
params.allocator_getter = ctx->allocator_getter();
@@ -136,6 +198,8 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
const DatasetBase* const input_;
const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
+ string tag_;
+ string prefix_;
};
};
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index a7ded67876..2d51467616 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -82,11 +82,12 @@ class StatsAggregatorImpl : public StatsAggregator {
auto counters_map = get_counters_map();
if (counters_map->find(name) == counters_map->end()) {
counters_map->emplace(
- name, monitoring::Counter<1>::New(
- /*streamz name*/ "/tensorflow/" + name,
- /*streamz description*/
- name + " generated or consumed by the component.",
- /*streamz label name*/ "component_descriptor"));
+ name,
+ monitoring::Counter<1>::New(
+ /*streamz name*/ name,
+ /*streamz description*/
+ strings::StrCat(name, " generated or consumed by the component."),
+ /*streamz label name*/ "component_descriptor"));
}
counters_map->at(name)->GetCell(label)->IncrementBy(val);
}