diff options
Diffstat (limited to 'tensorflow/core/framework/dataset.h')
-rw-r--r-- | tensorflow/core/framework/dataset.h | 34 |
1 files changed, 10 insertions, 24 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 697e0604bf..964a7d5f8c 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -278,15 +278,8 @@ class IteratorContext { // Function call support. std::function<void(std::function<void()>)> runner = nullptr; - // A function that returns the current `StatsAggregator` instance to be - // used when recording statistics about the iterator. - // - // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator` - // is a property of the `IteratorResource` (which this class does not know - // about), and (ii) it can change after the `IteratorContext` has been - // created. Better suggestions are welcome! - std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter = - nullptr; + // The `StatsAggregator` object to record statistics about the iterator. + std::shared_ptr<StatsAggregator> stats_aggregator = nullptr; // The FunctionLibraryRuntime object to be used to make function calls. FunctionLibraryRuntime* lib = nullptr; @@ -320,13 +313,6 @@ class IteratorContext { return ¶ms_.runner; } - std::shared_ptr<StatsAggregator> stats_aggregator() { - if (params_.stats_aggregator_getter) { - return params_.stats_aggregator_getter(); - } else { - return nullptr; - } - } std::shared_ptr<const FunctionLibraryDefinition> function_library() { return params_.function_library; @@ -344,8 +330,8 @@ class IteratorContext { return params_.allocator_getter; } - std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter() { - return params_.stats_aggregator_getter; + std::shared_ptr<StatsAggregator> stats_aggregator() { + return params_.stats_aggregator; } std::shared_ptr<model::Model> model() { return params_.model; } @@ -657,15 +643,15 @@ class DatasetBaseIterator : public IteratorBase { // When performance modeling is enabled, this method adds a tunable parameter // to the model node corresponding to this iterator. // - // The performance modeling logic may use `value` to set the value of the + // The performance modeling logic may use `state` to set the value of the // tunable parameter at any point during the lifetime of this iterator. When - // it does, it notifies `cond_var`. + // it does, it acquires `state->mu` and notifies `state->cond_var`. void AddTunableParameter(IteratorContext* ctx, const string& name, - std::atomic<int64>* value, int64 min, int64 max, - condition_variable* cond_var) { + std::shared_ptr<model::SharedState> state, int64 min, + int64 max) { if (ctx->model()) { - ctx->model()->AddTunableParameter(prefix(), name, value, min, max, - cond_var); + ctx->model()->AddTunableParameter(prefix(), name, std::move(state), min, + max); } } |