aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/dataset.h22
-rw-r--r--tensorflow/core/kernels/data/BUILD1
-rw-r--r--tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc78
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_ops.cc11
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt8
-rw-r--r--tensorflow/core/ops/dataset_ops.cc2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py69
-rw-r--r--tensorflow/python/data/experimental/ops/stats_ops.py17
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt2
12 files changed, 179 insertions, 39 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 8c1151cb56..964a7d5f8c 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -278,15 +278,8 @@ class IteratorContext {
// Function call support.
std::function<void(std::function<void()>)> runner = nullptr;
- // A function that returns the current `StatsAggregator` instance to be
- // used when recording statistics about the iterator.
- //
- // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
- // is a property of the `IteratorResource` (which this class does not know
- // about), and (ii) it can change after the `IteratorContext` has been
- // created. Better suggestions are welcome!
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
- nullptr;
+ // The `StatsAggregator` object to record statistics about the iterator.
+ std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
// The FunctionLibraryRuntime object to be used to make function calls.
FunctionLibraryRuntime* lib = nullptr;
@@ -320,13 +313,6 @@ class IteratorContext {
return &params_.runner;
}
- std::shared_ptr<StatsAggregator> stats_aggregator() {
- if (params_.stats_aggregator_getter) {
- return params_.stats_aggregator_getter();
- } else {
- return nullptr;
- }
- }
std::shared_ptr<const FunctionLibraryDefinition> function_library() {
return params_.function_library;
@@ -344,8 +330,8 @@ class IteratorContext {
return params_.allocator_getter;
}
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter() {
- return params_.stats_aggregator_getter;
+ std::shared_ptr<StatsAggregator> stats_aggregator() {
+ return params_.stats_aggregator;
}
std::shared_ptr<model::Model> model() { return params_.model; }
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 6333853cdf..451f8c1a6c 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -458,6 +458,7 @@ tf_kernel_library(
srcs = ["stats_aggregator_dataset_op.cc"],
deps = [
":dataset",
+ "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
],
diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index c80493d3a1..8d561ca0e3 100644
--- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -191,7 +191,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
params.runner = [pool](std::function<void()> c) {
pool->Schedule(std::move(c));
};
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.stats_aggregator = ctx->stats_aggregator();
params.lib = ctx->lib();
params.function_library = ctx->function_library();
params.allocator_getter = ctx->allocator_getter();
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index c28c06da62..1d1a717062 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -253,7 +253,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
for (example::PerExampleFeatureStats feature_stats :
example_result.feature_stats) {
stats_aggregator->AddToHistogram(
- strings::StrCat("record_stats", ":features"),
+ "features",
{static_cast<double>(feature_stats.features_count)});
stats_aggregator->IncrementCounter(
"features_count", "trainer", feature_stats.features_count);
@@ -261,7 +261,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
"feature_values_count", "trainer",
feature_stats.feature_values_count);
stats_aggregator->AddToHistogram(
- strings::StrCat("record_stats", ":feature-values"),
+ "feature-values",
{static_cast<double>(feature_stats.feature_values_count)});
}
}
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index c8abfb9eb5..c09a73fff1 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <memory>
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
@@ -22,6 +24,52 @@ namespace tensorflow {
namespace data {
namespace {
+class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
+ public:
+ StatsAggregatorWithTagAndPrefix(
+ std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag,
+ const string& prefix)
+ : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {}
+
+ void AddToHistogram(const string& name,
+ gtl::ArraySlice<double> values) override {
+ if (!tag_.empty()) {
+ wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values);
+ } else {
+ wrapped_->AddToHistogram(name, values);
+ }
+ }
+
+ void AddScalar(const string& name, float value) override {
+ if (!tag_.empty()) {
+ wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value);
+ } else {
+ wrapped_->AddScalar(name, value);
+ }
+ }
+
+ void EncodeToProto(Summary* out_summary) override {
+ wrapped_->EncodeToProto(out_summary);
+ }
+
+ void IncrementCounter(const string& name, const string& label,
+ int64 val) override {
+ if (!prefix_.empty()) {
+ wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label,
+ val);
+ } else {
+ wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label,
+ val);
+ }
+ }
+
+ private:
+ std::shared_ptr<StatsAggregator> wrapped_;
+ string tag_;
+ string prefix_;
+ TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix);
+};
+
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx)
@@ -33,8 +81,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
+ string tag;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
- *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
+ tag, prefix);
}
private:
@@ -42,11 +95,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
const Tensor& resource_handle,
- StatsAggregatorResource* stats_aggregator_resource)
+ StatsAggregatorResource* stats_aggregator_resource,
+ const string& tag, const string& prefix)
: DatasetBase(DatasetContext(ctx)),
input_(input),
resource_handle_(resource_handle),
- stats_aggregator_resource_(stats_aggregator_resource) {
+ stats_aggregator_resource_(stats_aggregator_resource),
+ tag_(tag),
+ prefix_(prefix) {
input_->Ref();
stats_aggregator_resource_->Ref();
}
@@ -81,8 +137,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* resource_handle_node = nullptr;
TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ Node* tag_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+ Node* prefix_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node));
TF_RETURN_IF_ERROR(b->AddDataset(
- this, {input_graph_node, resource_handle_node}, output));
+ this, {input_graph_node, resource_handle_node, tag_node, prefix_node},
+ output));
return Status::OK();
}
@@ -105,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
- params.stats_aggregator_getter = [stats_aggregator_resource]() {
- return stats_aggregator_resource->stats_aggregator();
- };
+ params.stats_aggregator = std::shared_ptr<StatsAggregator>(
+ new StatsAggregatorWithTagAndPrefix(
+ stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
+ dataset()->prefix_));
params.lib = ctx->lib();
params.function_library = ctx->function_library();
params.allocator_getter = ctx->allocator_getter();
@@ -136,6 +198,8 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
const DatasetBase* const input_;
const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
+ string tag_;
+ string prefix_;
};
};
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index a7ded67876..2d51467616 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -82,11 +82,12 @@ class StatsAggregatorImpl : public StatsAggregator {
auto counters_map = get_counters_map();
if (counters_map->find(name) == counters_map->end()) {
counters_map->emplace(
- name, monitoring::Counter<1>::New(
- /*streamz name*/ "/tensorflow/" + name,
- /*streamz description*/
- name + " generated or consumed by the component.",
- /*streamz label name*/ "component_descriptor"));
+ name,
+ monitoring::Counter<1>::New(
+ /*streamz name*/ name,
+ /*streamz description*/
+ strings::StrCat(name, " generated or consumed by the component."),
+ /*streamz label name*/ "component_descriptor"));
}
counters_map->at(name)->GetCell(label)->IncrementBy(val);
}
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 4845767405..33f18ae13f 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -59785,6 +59785,14 @@ op {
name: "stats_aggregator"
type: DT_RESOURCE
}
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "counter_prefix"
+ type: DT_STRING
+ }
output_arg {
name: "handle"
type: DT_VARIANT
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 71f4cc3c4c..889a6a4640 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -185,6 +185,8 @@ REGISTER_OP("ParseExampleDataset")
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
+ .Input("tag: string")
+ .Input("counter_prefix: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index 6761fbd16b..19f5a62d45 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
@@ -248,6 +249,74 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
+ def testMultipleDatasetWithTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "dataset1"))
+ dataset2 = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "dataset2"))
+ iterator_0 = dataset.make_initializable_iterator()
+ iterator_1 = dataset2.make_initializable_iterator()
+ next_element = iterator_0.get_next() + iterator_1.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run([iterator_0.initializer, iterator_1.initializer])
+ for i in range(100):
+ self.assertEqual(i * 2, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset1_record_latency", float(i + 1))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset2_record_latency", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset1_record_latency", 100.0)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset2_record_latency", 100.0)
+
+
+class FeatureStatsDatasetTest(
+ stats_dataset_test_base.StatsDatasetTestBase,
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
+
+ def testFeaturesStats(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ batch_size = 2
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5,
+ drop_final_batch=False).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "record_stats"))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(total_records // batch_size + 1 if total_records %
+ batch_size else total_records // batch_size):
+ sess.run(next_element)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats_features", total_records)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats_feature-values", total_records)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats_features", total_records * 4)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats_feature-values",
+ self._sum_keywords(1) * num_epochs + 3 * total_records)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py
index c918d223e8..54ef6fc3e8 100644
--- a/tensorflow/python/data/experimental/ops/stats_ops.py
+++ b/tensorflow/python/data/experimental/ops/stats_ops.py
@@ -89,15 +89,19 @@ class StatsAggregator(object):
class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
- def __init__(self, input_dataset, stats_aggregator):
+ def __init__(self, input_dataset, stats_aggregator, tag, prefix):
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
+ self._tag = tag
+ self._prefix = prefix
def _as_variant_tensor(self):
return gen_dataset_ops.set_stats_aggregator_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._stats_aggregator._resource, # pylint: disable=protected-access
+ self._tag,
+ self._prefix,
**dataset_ops.flat_structure(self))
@property
@@ -114,11 +118,15 @@ class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
@tf_export("data.experimental.set_stats_aggregator")
-def set_stats_aggregator(stats_aggregator):
+def set_stats_aggregator(stats_aggregator, tag="", counter_prefix=""):
"""Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
- stats_aggregator: A `tf.data.experimental.StatsAggregator` object.
+ stats_aggregator: A `tf.data.experimental.StatsAggregator` object.
+ tag: (Optional) String, all statistics recorded for the input `dataset`
+ will have the given `tag` prepended to the name.
+ counter_prefix: (Optional) String, all statistics recorded as `counters`
+ will have the given `prefix` for the counter. Defaults to "/tensorflow".
Returns:
A `Dataset` transformation function, which can be passed to
@@ -126,7 +134,8 @@ def set_stats_aggregator(stats_aggregator):
"""
def _apply_fn(dataset):
- return _SetStatsAggregatorDataset(dataset, stats_aggregator)
+ return _SetStatsAggregatorDataset(dataset, stats_aggregator, tag,
+ counter_prefix)
return _apply_fn
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
index b14585f8d7..2a1f899dc0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
@@ -122,7 +122,7 @@ tf_module {
}
member_method {
name: "set_stats_aggregator"
- argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'stats_aggregator\', \'tag\', \'counter_prefix\'], varargs=None, keywords=None, defaults=[\'\', \'\'], "
}
member_method {
name: "shuffle_and_repeat"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
index b14585f8d7..2a1f899dc0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
@@ -122,7 +122,7 @@ tf_module {
}
member_method {
name: "set_stats_aggregator"
- argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'stats_aggregator\', \'tag\', \'counter_prefix\'], varargs=None, keywords=None, defaults=[\'\', \'\'], "
}
member_method {
name: "shuffle_and_repeat"