aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-20 16:48:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 16:53:24 -0700
commit0e1efc3d9129c740a16081fdc53bdc482f8f0c11 (patch)
treefd153014209d962e8fccf90c31a372bc59ec38a9 /tensorflow/core/framework
parent863b3bd6bb2214065f95fa20f551a8fc8568e55d (diff)
[tf.data] Moving auto-tuning optimizations into a background thread, refactoring the API for exposing tunable parameters, and removing `model::Node` from the public API.
PiperOrigin-RevId: 213907565
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/dataset.h87
-rw-r--r--tensorflow/core/framework/model.cc204
-rw-r--r--tensorflow/core/framework/model.h613
3 files changed, 475 insertions, 429 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 91b1e61d3c..697e0604bf 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -529,25 +529,11 @@ class DatasetBase : public core::RefCounted {
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
if (ctx->model()) {
- // The prefix might contain an index. We need to strip it to make it
- // possible for the model to successfully identify the output node.
- string sanitized_prefix = prefix;
- if (str_util::EndsWith(prefix, "]")) {
- sanitized_prefix = prefix.substr(0, prefix.rfind('['));
- }
- std::shared_ptr<model::Node> node =
- ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
- std::vector<string> tokens =
- str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
- node->set_name(tokens[tokens.size() - 1]);
+ ctx->model()->AddNode((*iterator)->prefix(), prefix);
std::shared_ptr<model::Model> model = ctx->model();
const string& prefix = (*iterator)->prefix();
- (*iterator)->AddCleanupFunction([model, node, prefix]() {
- if (node->output()) {
- node->output()->remove_input(node);
- }
- model->RemoveNode(prefix);
- });
+ (*iterator)->AddCleanupFunction(
+ [model, prefix]() { model->RemoveNode(prefix); });
}
return (*iterator)->Initialize(ctx);
}
@@ -629,23 +615,10 @@ class DatasetBaseIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
- Status s;
- if (ctx->model()) {
- std::shared_ptr<model::Node> node =
- ctx->model()->LookupNode(params_.prefix);
- if (node->output()) {
- node->output()->stop_work();
- }
- node->start_work();
- s = GetNextInternal(ctx, out_tensors, end_of_sequence);
- node->stop_work();
- node->add_element();
- if (node->output()) {
- node->output()->start_work();
- }
- } else {
- s = GetNextInternal(ctx, out_tensors, end_of_sequence);
- }
+ RecordStart(ctx, true /* stop_output */);
+ Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ if (s.ok() && !*end_of_sequence) RecordElement(ctx);
+ RecordStop(ctx, true /* start_output */);
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -677,52 +650,46 @@ class DatasetBaseIterator : public IteratorBase {
void AddConstantParameter(IteratorContext* ctx, const string& name,
int64 value) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->add_constant_param(name, value);
- }
+ ctx->model()->AddConstantParameter(prefix(), name, value);
}
}
// When performance modeling is enabled, this method adds a tunable parameter
// to the model node corresponding to this iterator.
//
- // The `set_fn` function should set the tunable parameter to the value of
- // its input argument. The function should be thread-safe; in particular, the
- // state it updates should be protected by a lock as the function can be
- // invoked asynchronously. It is guaranteed that this function will not be
- // invoked after the iterator is deleted because the model node that owns
- // the function is deleted when the iterator is deleted.
+ // The performance modeling logic may use `value` to set the value of the
+ // tunable parameter at any point during the lifetime of this iterator. When
+ // it does, it notifies `cond_var`.
void AddTunableParameter(IteratorContext* ctx, const string& name,
- int64 value, int64 min, int64 max,
- std::function<void(int64)>&& set_fn) {
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->add_tunable_param(name, value, min, max, std::move(set_fn));
- }
+ ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
+ cond_var);
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // this iterator has produced an element.
+ void RecordElement(IteratorContext* ctx) {
+ if (ctx->model()) {
+ ctx->model()->RecordElement(prefix());
}
}
// When performance modeling is enabled, this method records the fact that
// a thread of this iterator has started work.
- void StartWork(IteratorContext* ctx) {
+ void RecordStart(IteratorContext* ctx, bool stop_output = false) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->start_work();
- }
+ ctx->model()->RecordStart(prefix(), stop_output);
}
}
// When performance modeling is enabled, this method records the fact that
// a thread of this iterator has stopped work.
- void StopWork(IteratorContext* ctx) {
+ void RecordStop(IteratorContext* ctx, bool start_output = false) {
if (ctx->model()) {
- std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
- if (node) {
- node->stop_work();
- }
+ ctx->model()->RecordStop(prefix(), start_output);
}
}
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 112298c344..b0330ec990 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,16 +17,14 @@ limitations under the License.
#include <memory>
-#include "tensorflow/core/lib/gtl/map_util.h"
-
namespace tensorflow {
namespace data {
namespace model {
// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
-void Node::CollectTunables(
+void Model::Node::CollectTunables(
std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (auto input : inputs_) {
input->CollectTunables(tunables);
}
@@ -45,14 +43,14 @@ void Node::CollectTunables(
}
}
-int64 Node::GetParameterValue(const string& name) {
+int64 Model::Node::GetParameterValue(const string& name) {
if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
return (*tunable_param)->value;
}
return constant_params_[name];
}
-int64 Node::ProcessingTimeLocked() {
+int64 Model::Node::ProcessingTimeLocked() {
switch (type_) {
case Type::BATCH:
case Type::MAP_AND_BATCH:
@@ -101,7 +99,7 @@ int64 Node::ProcessingTimeLocked() {
}
}
-int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
switch (type_) {
case Type::BATCH:
case Type::PADDED_BATCH: {
@@ -251,15 +249,34 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
}
}
-std::shared_ptr<Node> Model::AddNode(const string& name,
- const string& output_name) {
- mutex_lock l(mu_);
+void Model::AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, node_name);
+ if (node) {
+ (*node)->add_constant_param(parameter_name, value);
+ }
+}
+
+void Model::AddNode(const string& name, const string& output_name) {
+ // The name captures the sequence of iterators joined by `::`. We use the full
+ // sequence as the key in the lookup table, but only the last element of the
+ // sequence as the name node.
+ std::vector<string> tokens =
+ str_util::Split(name, ':', str_util::SkipEmpty());
+ // The output name might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_output_name = output_name;
+ if (str_util::EndsWith(output_name, "]")) {
+ sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+ }
std::shared_ptr<Node> output;
- auto it = lookup_table_.find(output_name);
+ mutex_lock l(mu_);
+ auto it = lookup_table_.find(sanitized_output_name);
if (it != lookup_table_.end()) {
output = it->second;
}
- std::shared_ptr<Node> node(new Node(id_counter_++, output));
+ std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
if (!output_) {
output_ = node;
}
@@ -267,88 +284,125 @@ std::shared_ptr<Node> Model::AddNode(const string& name,
output->add_input(node);
}
lookup_table_.insert(std::make_pair(name, node));
- return node;
}
-std::shared_ptr<Node> Model::LookupNode(const string& name) {
+void Model::AddProcessingTime(const string& name, int64 delta) {
tf_shared_lock l(mu_);
- std::shared_ptr<Node> result;
- auto it = lookup_table_.find(name);
- if (it != lookup_table_.end()) {
- result = it->second;
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->add_processing_time(delta);
}
- return result;
+}
+
+void Model::AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
+ tf_shared_lock l(mu_);
+ auto node = *gtl::FindOrNull(lookup_table_, node_name);
+ DCHECK(node);
+ node->add_tunable_param(parameter_name, value, min, max, cond_var);
}
// The optimization algorithm starts by setting all tunable parallelism
-// parameters to 1. It then repeatedly identifies the parameter that whose
-// increase in parallelism decreases the output time the most. This process is
-// repeated until all parameters reach their maximum values or the
-// projected output time is less than or equal to the processing time needed to
-// produce an element divided by CPU budget.
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by CPU budget.
void Model::Optimize(int64 cpu_budget) {
- mutex_lock l(optimization_mu_);
- std::vector<std::shared_ptr<Node::Tunable>> tunables;
- {
- mutex_lock l2(mu_);
- const int64 processing_time = ProcessingTime();
- tunables = CollectTunables();
- for (auto tunable : tunables) {
- tunable->value = 1;
- }
- while (true) {
- const int64 output_time = OutputTime();
- bool all_tunables = true;
- for (auto& tunable : tunables) {
- if (tunable->value < tunable->max) {
- all_tunables = false;
- break;
- }
- }
- if (output_time < processing_time / cpu_budget || all_tunables) {
+ tf_shared_lock lock(mu_);
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
break;
}
- int64 best_delta = -1;
- Node::Tunable* best_tunable = nullptr;
- for (auto& tunable : tunables) {
- if (tunable->value == tunable->max) {
- continue;
- }
- tunable->value++;
- int64 delta = output_time - OutputTime();
- if (delta > best_delta) {
- best_delta = delta;
- best_tunable = tunable.get();
- }
- tunable->value--;
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
+ break;
+ }
+ int64 best_delta = -1;
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
}
- if (!best_tunable) {
- // NOTE: This can happen because we are performing the optimization
- // while the model data is changing. If this becomes an issue, we should
- // look into performing the optimization using a model snapshot.
- break;
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
}
- best_tunable->value++;
+ tunable->value--;
+ }
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
}
+ best_tunable->value++;
}
- // The `set_fn` functions should be invoked without holding a lock to avoid a
- // potential deadlock.
+ VLOG(2) << "Number of knobs: " << tunables.size();
for (auto& tunable : tunables) {
- tunable->set_fn(tunable->value);
+ VLOG(2) << "Setting tunable parameter: " << tunable->value;
+ tunable->value_ptr->store(tunable->value);
+ if (tunable->cond_var) {
+ tunable->cond_var->notify_all();
+ }
+ }
+}
+
+void Model::RecordElement(const string& name) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_element();
}
}
-void Model::RemoveNode(const string& prefix) {
- // Nodes are not allowed to be removed when optimization is in progress to
- // prevent the optimization from trying to access an iterator that was
- // concurrently deleted.
- mutex_lock l(optimization_mu_);
- mutex_lock l2(mu_);
- lookup_table_.erase(prefix);
+void Model::RecordStart(const string& name, bool stop_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ if (stop_output && (*node)->output()) {
+ (*node)->output()->record_stop();
+ }
+ (*node)->record_start();
+ }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_stop();
+ if (start_output && (*node)->output()) {
+ (*node)->output()->record_start();
+ }
+ }
+}
+
+void Model::RemoveNode(const string& name) {
+ mutex_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node && (*node)->output()) {
+ (*node)->output()->remove_input(*node);
+ }
+ lookup_table_.erase(name);
}
-std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
- std::vector<std::shared_ptr<Node::Tunable>> tunables;
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
output_->CollectTunables(&tunables);
return tunables;
}
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index f88ec06ef3..26402f5cd3 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
@@ -32,341 +33,365 @@ namespace tensorflow {
namespace data {
namespace model {
-class Model;
-class Node;
-
-// Abstract representation of a TensorFlow input pipeline node. It collects
-// information about inputs to this node, processing time spent executing the
-// node logic, number of elements produced by the node, various other
-// information (e.g. batch size or execution parallelism).
+// Abstract representation of a TensorFlow input pipeline that can be used
+// for collecting runtime information and optimizing performance. It collects
+// runtime information about execution of the input pipeline that is used to
+// create a performance model, which is in turn used to identify optimal values
+// of tunable parameters.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boiler plate code for creating the abstract representation of
-// the input pipeline and collecting common information has been added to the
+// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-//
-// In addition, `DatasetBaseIterator` provides wrappers that can be used for
-// transformation-specific information collection. The `SetMetadata` wrapper can
-// be used to pass arbitrary metadata to the modeling framework, while the
-// `StartWork` and `StopWork` wrappers should be used to correctly account for
-// processing time of multi-threaded transformation that yield the CPU; such
-// transformations should invoke `StartWork()` when a transformation thread
-// starts executing (e.g. when created or woken up) and `StopWork()` when a
-// transformation thread stops executing (e.g. when returning or waiting).
-//
-// TODO(jsimsa): Create an API to capture the abstract semantics of each
-// tf.data transformation and replace switch-case blocks with inheritance.
-class Node {
+class Model {
public:
- Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
-
- // Adds a constant parameter.
- void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- constant_params_[name] = value;
- }
-
- // Records that the node produced an element.
- void add_element() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- num_elements_++;
- }
-
- // Adds an input.
- void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- inputs_.push_back(node);
- }
-
- // Increments the aggregate processing time by the given delta.
- void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- processing_time_ += delta;
- }
-
- // Adds a tunable parameter.
- void add_tunable_param(const string& name, int64 value, int64 min, int64 max,
- std::function<void(int64)>&& set_fn)
- LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- tunable_params_[name] =
- std::make_shared<Tunable>(value, min, max, std::move(set_fn));
- }
-
- // Returns the unique node ID.
- int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
-
- // Returns the node inputs.
- std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return inputs_;
- }
-
- // Returns the node name.
- const string& name() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return name_;
- }
-
- // Returns the number of elements produced by the node.
- int64 num_elements() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return num_elements_;
- }
-
- // Returns the node output.
- std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return output_;
- }
-
- // Removes an input.
- void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- inputs_.remove(input);
- }
-
- // Sets the node name.
- void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- name_ = name;
- type_ = TypeFromName(name);
- }
-
- // Set the node output.
- void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- output_ = output;
- }
-
- // Records that a node thread has started work.
- void start_work() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
- }
-
- // Records that a node thread has stopped work.
- void stop_work() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- auto iter = work_start_.find(std::this_thread::get_id());
- CHECK(work_start_.end() != iter)
- << "Encountered a stop event that was not preceded by a start event.";
- processing_time_ += Env::Default()->NowNanos() - iter->second;
- work_start_.erase(iter);
- }
-
- private:
- // Represents a tunable parameter.
- struct Tunable {
- Tunable(int64 value, int64 min, int64 max,
- std::function<void(int64)> set_fn)
- : value(value), min(min), max(max), set_fn(std::move(set_fn)) {}
-
- int64 value;
- int64 min;
- int64 max;
- std::function<void(int64)> set_fn;
- };
+ Model() = default;
- enum class Type {
- BATCH = 0,
- CACHE,
- CONCATENATE,
- FILTER,
- FLAT_MAP,
- INTERLEAVE,
- MAP,
- MAP_AND_BATCH,
- PADDED_BATCH,
- PARALLEL_INTERLEAVE,
- PARALLEL_INTERLEAVE_V2,
- PARALLEL_MAP,
- PREFETCH,
- REPEAT,
- SHUFFLE,
- SKIP,
- TAKE,
- ZIP,
- UNKNOWN,
- };
+ // Adds a constant parameter for the given node.
+ void AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value)
+ LOCKS_EXCLUDED(mu_);
- // Collects tunable parameters in the subtree rooted in this node.
- void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables)
+ // Adds a node with the given name and given output (identified by name).
+ void AddNode(const string& name, const string& output_name)
LOCKS_EXCLUDED(mu_);
- // Gets a value of the given parameter (tunable or constant).
- int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Increments the processing time for the given node..
+ void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);
- // Returns the per-element processing time spent in this node.
- int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return NanosPerElementLocked();
- }
+ // Adds a tunable parameter for the given node.
+ void AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
- int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (num_elements_ == 0) {
- return 0;
- }
- return (int64)((double)processing_time_ / (double)num_elements_);
- }
-
- // Returns the per-element output time for this node.
- int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return OutputTimeLocked(input_times);
- }
-
- int64 OutputTimeLocked(std::vector<int64>* input_times)
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- int64 OutputTimeForInputs(std::vector<int64>* input_times)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 sum = 0;
- for (auto input : inputs_) {
- sum += input->OutputTime(input_times);
- }
- return sum;
- }
-
- // Returns the per-element processing time spent in the subtree rooted in this
- // node.
- int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return ProcessingTimeLocked();
- }
-
- int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
- // Returns the per-element processing time spent in the inputs of this node.
- int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 sum = 0;
- for (auto input : inputs_) {
- sum += input->ProcessingTimeLocked();
- }
- return sum;
- }
+ // Runs optimization.
+ void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
- Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (name_ == "Batch") {
- return Type::BATCH;
- }
- if (str_util::EndsWith(name_, "Cache")) {
- return Type::CACHE;
- }
- if (name_ == "Concatenate") {
- return Type::CONCATENATE;
- }
- if (name_ == "Filter") {
- return Type::FILTER;
+ // Records that a node has produced an element.
+ void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has started work. If `stop_output` is set, it
+ // also records that the output of the given node has stopped work.
+ void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has stopped work. If `stop_output` is set, it
+ // also records that the output of the given node has started work.
+ void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);
+
+ // Removes the given node.
+ void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ private:
+ // Abstract representation of a TensorFlow input pipeline node. It collects
+ // information about inputs to this node, processing time spent executing the
+ // node logic, number of elements produced by the node, various other
+ // information (e.g. batch size or execution parallelism).
+ //
+ // Developers of tf.data transformations are not expected to interact with
+ // this class directly. Boiler plate code for creating the abstract
+ // representation of the input pipeline and collecting common information has
+ // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
+ // respectively.
+ //
+ // In addition, `DatasetBaseIterator` provides wrappers that can be used for
+ // transformation-specific information collection. The `SetMetadata` wrapper
+ // can be used to pass arbitrary metadata to the modeling framework, while the
+ // `StartWork` and `StopWork` wrappers should be used to correctly account for
+ // processing time of multi-threaded transformation that yield the CPU; such
+ // transformations should invoke `StartWork()` when a transformation thread
+ // starts executing (e.g. when created or woken up) and `StopWork()` when a
+ // transformation thread stops executing (e.g. when returning or waiting).
+ //
+ // TODO(jsimsa): Create an API to capture the abstract semantics of each
+ // tf.data transformation and replace switch-case blocks with inheritance.
+ class Node {
+ public:
+ // Represents a tunable parameter.
+ struct Tunable {
+ Tunable(std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var)
+ : value(*value),
+ min(min),
+ max(max),
+ value_ptr(value),
+ cond_var(cond_var) {}
+
+ // Identifies the model value of the parameter. This can be different from
+ // the actual value (e.g. during optimization search).
+ int64 value;
+
+ // Identifies the minimum value of the parameter.
+ int64 min;
+
+ // Identifies the maximum value of the parameter.
+ int64 max;
+
+ // Points to the actual value of the parameter. Not owned.
+ std::atomic<int64>* value_ptr;
+
+ // If non-null, this condition variable is notified when the model updates
+ // the actual value of the parameter (via `value_ptr`). Not owned.
+ condition_variable* cond_var;
+ };
+
+ Node(int64 id, const string& name, std::shared_ptr<Node> output)
+ : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {}
+
+ // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ constant_params_[name] = value;
}
- if (name_ == "FlatMap") {
- return Type::FLAT_MAP;
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
}
- if (name_ == "Interleave") {
- return Type::INTERLEAVE;
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
}
- if (name_ == "Map") {
- return Type::MAP;
+
+ // Adds a tunable parameter.
+ void add_tunable_param(const string& name, std::atomic<int64>* value,
+ int64 min, int64 max, condition_variable* cond_var)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ tunable_params_[name] =
+ std::make_shared<Tunable>(value, min, max, cond_var);
}
- if (name_ == "MapAndBatch") {
- return Type::MAP_AND_BATCH;
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
}
- if (name_ == "PaddedBatch") {
- return Type::PADDED_BATCH;
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
}
- if (name_ == "ParallelInterleave") {
- return Type::PARALLEL_INTERLEAVE;
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
}
- if (name_ == "ParallelInterleaveV2") {
- return Type::PARALLEL_INTERLEAVE_V2;
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
}
- if (name_ == "ParallelMap") {
- return Type::PARALLEL_MAP;
+
+ // Records that the node produced an element.
+ void record_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
}
- if (name_ == "Prefetch") {
- return Type::PREFETCH;
+
+ // Records that a node thread has started executing.
+ void record_start() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
}
- if (str_util::EndsWith(name_, "Repeat")) {
- return Type::REPEAT;
+
+ // Records that a node thread has stopped executing.
+ void record_stop() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ std::thread::id tid = std::this_thread::get_id();
+ auto start_time = gtl::FindOrNull(work_start_, tid);
+ DCHECK(start_time)
+ << "Encountered a stop event that was not preceded by a start event.";
+ if (start_time) {
+ processing_time_ += Env::Default()->NowNanos() - *start_time;
+ work_start_.erase(tid);
+ }
}
- if (name_ == "Shuffle") {
- return Type::SHUFFLE;
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
}
- if (str_util::EndsWith(name_, "Skip")) {
- return Type::SKIP;
+
+ // Set the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
}
- if (str_util::EndsWith(name_, "Take")) {
- return Type::TAKE;
+
+ // Collects tunable parameters in the subtree rooted in this node.
+ void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables)
+ LOCKS_EXCLUDED(mu_);
+
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return OutputTimeLocked(input_times);
}
- if (name_ == "Zip") {
- return Type::ZIP;
+
+ // Returns the per-element processing time spent in the subtree rooted in
+ // this node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return ProcessingTimeLocked();
}
- return Type::UNKNOWN;
- }
- mutex mu_;
- const int64 id_;
- Type type_ GUARDED_BY(mu_);
- string name_ GUARDED_BY(mu_);
- int64 processing_time_ GUARDED_BY(mu_) = 0;
- int64 num_elements_ GUARDED_BY(mu_) = 0;
- std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
- std::map<string, int64> constant_params_ GUARDED_BY(mu_);
- // Tunables are shared with the model during optimization.
- std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
- std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
- std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ private:
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Gets a value of the given parameter (tunable or constant).
+ int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return NanosPerElementLocked();
+ }
- friend class Model;
-};
+ int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return (int64)((double)processing_time_ / (double)num_elements_);
+ }
-// Abstract representation of a TensorFlow input pipeline that can be used
-// for collecting runtime information and optimizing performance. It collects
-// runtime information about execution of the input pipeline that is used to
-// create a performance model, which is in turn used to identify optimal values
-// of tunable parameters.
-//
-// Developers of tf.data transformations are not expected to interact with this
-// class directly. Boiler plate code for creating the abstract representation of
-// the input pipeline and collecting runtime information has been added to the
-// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-class Model {
- public:
- Model() = default;
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_);
- // Returns the model output node.
- std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
- tf_shared_lock l(mu_);
- return output_;
- }
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
- // Adds a node with the given name and given output (identified by name).
- std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
- LOCKS_EXCLUDED(mu_);
+ int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_);
- // Looks up the node using the given name.
- std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTime();
+ }
+ return sum;
+ }
- // Runs optimization.
- void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
+ Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) {
+ if (name_ == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name_, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name_ == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name_ == "Filter") {
+ return Type::FILTER;
+ }
+ if (name_ == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name_ == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name_ == "Map") {
+ return Type::MAP;
+ }
+ if (name_ == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name_ == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name_ == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name_ == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name_ == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name_ == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name_, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name_ == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name_, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name_, "Take")) {
+ return Type::TAKE;
+ }
+ if (name_ == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
- // Removes the node identified by the given name.
- void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+ mutex mu_;
+ const int64 id_;
+ const string name_;
+ const Type type_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+ // Tunables are shared with the model during optimization.
+ std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ };
- private:
std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
- EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ SHARED_LOCKS_REQUIRED(mu_);
- int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_);
- int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_);
- // Used for coordination between different input pipeline threads.
+ // Used for coordination between different input pipeline threads. Exclusive
+ // access is required only when adding or removing nodes. Concurrent access to
+ // existing nodes is protected by a node mutex.
mutex mu_;
- // Used for preventing iterator deletion when optimization is in progress
- // because the optimization may try to update the values of tunable
- // parameters.
- mutex optimization_mu_ ACQUIRED_BEFORE(mu_);
int64 id_counter_ GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ GUARDED_BY(mu_);
std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);