aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/dataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/dataset.h')
-rw-r--r--tensorflow/core/framework/dataset.h34
1 files changed, 10 insertions, 24 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 697e0604bf..964a7d5f8c 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -278,15 +278,8 @@ class IteratorContext {
// Function call support.
std::function<void(std::function<void()>)> runner = nullptr;
- // A function that returns the current `StatsAggregator` instance to be
- // used when recording statistics about the iterator.
- //
- // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
- // is a property of the `IteratorResource` (which this class does not know
- // about), and (ii) it can change after the `IteratorContext` has been
- // created. Better suggestions are welcome!
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
- nullptr;
+ // The `StatsAggregator` object to record statistics about the iterator.
+ std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
// The FunctionLibraryRuntime object to be used to make function calls.
FunctionLibraryRuntime* lib = nullptr;
@@ -320,13 +313,6 @@ class IteratorContext {
return &params_.runner;
}
- std::shared_ptr<StatsAggregator> stats_aggregator() {
- if (params_.stats_aggregator_getter) {
- return params_.stats_aggregator_getter();
- } else {
- return nullptr;
- }
- }
std::shared_ptr<const FunctionLibraryDefinition> function_library() {
return params_.function_library;
@@ -344,8 +330,8 @@ class IteratorContext {
return params_.allocator_getter;
}
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter() {
- return params_.stats_aggregator_getter;
+ std::shared_ptr<StatsAggregator> stats_aggregator() {
+ return params_.stats_aggregator;
}
std::shared_ptr<model::Model> model() { return params_.model; }
@@ -657,15 +643,15 @@ class DatasetBaseIterator : public IteratorBase {
// When performance modeling is enabled, this method adds a tunable parameter
// to the model node corresponding to this iterator.
//
- // The performance modeling logic may use `value` to set the value of the
+ // The performance modeling logic may use `state` to set the value of the
// tunable parameter at any point during the lifetime of this iterator. When
- // it does, it notifies `cond_var`.
+ // it does, it acquires `state->mu` and notifies `state->cond_var`.
void AddTunableParameter(IteratorContext* ctx, const string& name,
- std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var) {
+ std::shared_ptr<model::SharedState> state, int64 min,
+ int64 max) {
if (ctx->model()) {
- ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
- cond_var);
+ ctx->model()->AddTunableParameter(prefix(), name, std::move(state), min,
+ max);
}
}