diff options
author | 2018-09-17 09:21:14 -0700 | |
---|---|---|
committer | 2018-09-17 09:24:34 -0700 | |
commit | c8a0dfc741736a59f8fd1776b71f38619d66da56 (patch) | |
tree | 0a3ff87aed44e895ca7b3a09a93653f8eea7da59 /tensorflow/core/framework | |
parent | 07bc3696135483612c727ca7687342922ff0d5de (diff) |
[tf.data] Adding support for `tf.data.AUTOTUNE` as a special value for the `num_parallel_calls` argument of `tf.data.Dataset.map()`, `tf.data.Dataset.interleave()`, and `tf.contrib.data.map_and_batch()`.
When `tf.data.AUTOTUNE` is specified, the level of parallelism is determined at runtime. The underlying mechanism instruments the input pipeline to build a performance model and then uses the model to find the optimal values for the parallelism knobs.
PiperOrigin-RevId: 213283297
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/dataset.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/dataset.h | 31 | ||||
-rw-r--r-- | tensorflow/core/framework/model.cc | 251 | ||||
-rw-r--r-- | tensorflow/core/framework/model.h | 97 | ||||
-rw-r--r-- | tensorflow/core/framework/model.proto | 30 |
5 files changed, 177 insertions, 233 deletions
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 5281c56f04..284dafb886 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -20,7 +20,6 @@ limitations under the License. namespace tensorflow { namespace data { - namespace { // A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor. diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 4ee6749eea..91b1e61d3c 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -47,6 +47,8 @@ class GraphDefBuilder; class Node; namespace data { +// A constant that can be used to enable auto-tuning. +constexpr int kAutoTune = -1; class DatasetBase; class SerializationContext; @@ -670,13 +672,34 @@ class DatasetBaseIterator : public IteratorBase { return strings::StrCat(params_.prefix, ":", name); } - // When performance modeling is enabled, this method sets metadata entry for - // the model node corresponding to this iterator. - void SetMetadata(IteratorContext* ctx, const string& key, int64 value) { + // When performance modeling is enabled, this method adds a constant parameter + // to the model node corresponding to this iterator. + void AddConstantParameter(IteratorContext* ctx, const string& name, + int64 value) { if (ctx->model()) { std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); if (node) { - node->set_metadata(key, value); + node->add_constant_param(name, value); + } + } + } + + // When performance modeling is enabled, this method adds a tunable parameter + // to the model node corresponding to this iterator. + // + // The `set_fn` function should set the tunable parameter to the value of + // its input argument. The function should be thread-safe; in particular, the + // state it updates should be protected by a lock as the function can be + // invoked asynchronously. It is guaranteed that this function will not be + // invoked after the iterator is deleted because the model node that owns + // the function is deleted when the iterator is deleted. + void AddTunableParameter(IteratorContext* ctx, const string& name, + int64 value, int64 min, int64 max, + std::function<void(int64)>&& set_fn) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->add_tunable_param(name, value, min, max, std::move(set_fn)); } } } diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 250b006641..b3fe357ea1 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -15,52 +15,28 @@ limitations under the License. #include "tensorflow/core/framework/model.h" +#include <memory> + +#include "tensorflow/core/lib/gtl/map_util.h" + namespace tensorflow { namespace data { namespace model { // TODO(jsimsa): Use `Node` subclassing instead of types and node statements. -void Node::CollectKnobs(std::vector<Node::Knob>* knobs) { +void Node::CollectTunables( + std::vector<std::shared_ptr<Node::Tunable>>* tunables) { mutex_lock l(mu_); + for (auto input : inputs_) { + input->CollectTunables(tunables); + } switch (type_) { - case Type::PARALLEL_INTERLEAVE_V2: { - for (auto input : inputs_) { - input->CollectKnobs(knobs); - } - int64 processing_time = static_cast<int64>( - static_cast<double>(ProcessingTimeLocked() - - inputs_.front()->ProcessingTime()) / - static_cast<double>(inputs_.size() - 1)); - knobs->emplace_back( - Node::Knob{this, processing_time, metadata_["parallelism"]}); - return; - } case Type::MAP_AND_BATCH: + case Type::PARALLEL_INTERLEAVE_V2: case Type::PARALLEL_MAP: { - for (auto input : inputs_) { - input->CollectKnobs(knobs); - } - knobs->emplace_back( - Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]}); - return; - } - case Type::BATCH: - case Type::CACHE: - case Type::CONCATENATE: - case Type::FILTER: - case Type::FLAT_MAP: - case Type::INTERLEAVE: - case Type::MAP: - case Type::PADDED_BATCH: - case Type::PARALLEL_INTERLEAVE: - case Type::PREFETCH: - case Type::REPEAT: - case Type::SHUFFLE: - case Type::SKIP: - case Type::TAKE: - case Type::ZIP: { - for (auto input : inputs_) { - input->CollectKnobs(knobs); + if (auto* tunable_param = + gtl::FindOrNull(tunable_params_, "parallelism")) { + tunables->push_back(*tunable_param); } return; } @@ -69,12 +45,19 @@ void Node::CollectKnobs(std::vector<Node::Knob>* knobs) { } } +int64 Node::GetParameterValue(const string& name) { + if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) { + return (*tunable_param)->value; + } + return constant_params_[name]; +} + int64 Node::ProcessingTimeLocked() { switch (type_) { case Type::BATCH: case Type::MAP_AND_BATCH: case Type::PADDED_BATCH: { - int64 batch_size = metadata_["batch_size"]; + int64 batch_size = GetParameterValue("batch_size"); return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs(); } case Type::FILTER: { @@ -122,7 +105,7 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { switch (type_) { case Type::BATCH: case Type::PADDED_BATCH: { - double batch_size = metadata_["batch_size"]; + double batch_size = GetParameterValue("batch_size"); int64 old_value = (*input_times)[input_times->size() - 1]; (*input_times)[input_times->size() - 1] = static_cast<int64>( static_cast<double>(old_value + NanosPerElementLocked()) / @@ -168,8 +151,8 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { static_cast<double>(inputs_.size() - 1); } case Type::MAP_AND_BATCH: { - double batch_size = metadata_["batch_size"]; - double parallelism = metadata_["parallelism"]; + double batch_size = GetParameterValue("batch_size"); + double parallelism = GetParameterValue("parallelism"); int64 delta = static_cast<int64>(static_cast<double>(NanosPerElementLocked()) / (batch_size * parallelism)); @@ -182,22 +165,41 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { return std::max(0LL, output_time - input_times->at(input_times->size() - 2)); } - case Type::PARALLEL_INTERLEAVE: + case Type::PARALLEL_INTERLEAVE: { + // TODO(jsimsa): model the first input + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 inputs_output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + double parallelism = GetParameterValue("parallelism"); + int64 output_time = + NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / + static_cast<double>(inputs_.size() - 1)) / + parallelism); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } case Type::PARALLEL_INTERLEAVE_V2: { // TODO(jsimsa): model the first input if (inputs_.size() <= 1) { return NanosPerElementLocked(); } - int64 delta = - static_cast<int64>(static_cast<double>(NanosPerElementLocked()) * - static_cast<double>(inputs_.size() - 1)); + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); input_times->push_back(delta); auto cleanup = gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); int64 inputs_output_time = OutputTimeForInputs(input_times) - inputs_.front()->OutputTime(input_times); - double parallelism = std::min(port::NumSchedulableCPUs(), - static_cast<int>(metadata_["parallelism"])); + double parallelism = + std::min(static_cast<int>(GetParameterValue("cycle_length")), + static_cast<int>(GetParameterValue("parallelism"))); int64 output_time = NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / static_cast<double>(inputs_.size() - 1)) / @@ -206,8 +208,9 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { output_time - input_times->at(input_times->size() - 2)); } case Type::PARALLEL_MAP: { - double parallelism = std::min(port::NumSchedulableCPUs(), - static_cast<int>(metadata_["parallelism"])); + double parallelism = + std::min(port::NumSchedulableCPUs(), + static_cast<int>(GetParameterValue("parallelism"))); int64 delta = static_cast<int64>( static_cast<double>(NanosPerElementLocked()) / parallelism); input_times->push_back(delta); @@ -248,23 +251,6 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { } } -Model::Model(const proto::Model& model_proto) { - id_counter_ = model_proto.id_counter(); - std::map<int64, std::shared_ptr<Node>> lookup_table; - for (auto node_proto : model_proto.node()) { - std::shared_ptr<Node> node(new Node(node_proto)); - lookup_table[node_proto.id()] = node; - } - for (auto node_proto : model_proto.node()) { - std::shared_ptr<Node> node = lookup_table[node_proto.id()]; - for (int64 id : node_proto.input()) { - node->add_input(lookup_table[id]); - } - node->set_output(lookup_table[node_proto.output()]); - } - output_ = lookup_table[model_proto.output()]; -} - std::shared_ptr<Node> Model::AddNode(const string& name, const string& output_name) { mutex_lock l(mu_); @@ -294,94 +280,77 @@ std::shared_ptr<Node> Model::LookupNode(const string& name) { return result; } -void Model::Optimize() { - mutex_lock l(mu_); - int64 processing_time = ProcessingTime(); - int64 num_cpus = port::NumSchedulableCPUs(); - std::vector<Node::Knob> knobs = CollectKnobs(); - // The optimization algorithm starts by setting all parallelism knobs to 1. It - // then repeatedly identifies the knob that, when turned up by 1, decreases - // the output time the most. This process is repeated until all knobs reach - // the number of schedulable CPUs or the projected output time is less than or - // equal to the processing time needed to produce an element divided by the - // number of schedulable CPUs. - for (auto& knob : knobs) { - LOG(INFO) << knob.node->name() << " " << knob.processing_time; - knob.value = 1; - knob.node->set_metadata("parallelism", knob.value); - } - while (true) { - int64 output_time = OutputTime(); - bool all_knobs = true; - for (auto knob : knobs) { - if (knob.value < num_cpus) { - all_knobs = false; +// The optimization algorithm starts by setting all tunable parallelism +// parameters to 1. It then repeatedly identifies the parameter that whose +// increase in parallelism decreases the output time the most. This process is +// repeated until all parameters reach their maximum values or the +// projected output time is less than or equal to the processing time needed to +// produce an element divided by CPU budget. +void Model::Optimize(int64 cpu_budget) { + mutex_lock l(optimization_mu_); + std::vector<std::shared_ptr<Node::Tunable>> tunables; + { + mutex_lock l2(mu_); + const int64 processing_time = ProcessingTime(); + tunables = CollectTunables(); + for (auto tunable : tunables) { + tunable->value = 1; + } + while (true) { + const int64 output_time = OutputTime(); + bool all_tunables = true; + for (auto& tunable : tunables) { + if (tunable->value < tunable->max) { + all_tunables = false; + break; + } + } + if (output_time < processing_time / cpu_budget || all_tunables) { break; } - } - if (output_time < processing_time / num_cpus || all_knobs) { - break; - } - int64 best_delta = -1; - int best_knob = -1; - for (int i = 0; i < knobs.size(); ++i) { - if (knobs[i].value == num_cpus) { - continue; + int64 best_delta = -1; + Node::Tunable* best_tunable = nullptr; + for (auto& tunable : tunables) { + if (tunable->value == tunable->max) { + continue; + } + tunable->value++; + int64 delta = output_time - OutputTime(); + if (delta > best_delta) { + best_delta = delta; + best_tunable = tunable.get(); + } + tunable->value--; } - knobs[i].node->set_metadata("parallelism", knobs[i].value + 1); - int64 delta = output_time - OutputTime(); - if (delta > best_delta) { - best_delta = delta; - best_knob = i; + if (best_tunable) { + // NOTE: This can happen because we are performing the optimization + // while the model data is changing. If this becomes an issue, we should + // look into performing the optimization using a model snapshot. + break; } - knobs[i].node->set_metadata("parallelism", knobs[i].value); + best_tunable->value++; } - knobs[best_knob].value++; - knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value); } - for (auto knob : knobs) { - LOG(INFO) << knob.node->name() << " " << knob.value; + // The `set_fn` functions should be invoked without holding a lock to avoid a + // potential deadlock. + for (auto& tunable : tunables) { + tunable->set_fn(tunable->value); } - LOG(INFO) << "output time: " << OutputTime(); - LOG(INFO) << "processing time: " << ProcessingTime(); -} - -void Model::OutputToFile() { - proto::Model model_proto; - ToProto(&model_proto); - string filename; - Env::Default()->LocalTempFilename(&filename); - TF_CHECK_OK(WriteStringToFile(Env::Default(), filename, - model_proto.SerializeAsString())); - LOG(INFO) << filename; } void Model::RemoveNode(const string& prefix) { - mutex_lock l(mu_); + // Nodes are not allowed to be removed when optimization is in progress to + // prevent the optimization from trying to access an iterator that was + // concurrently deleted. + mutex_lock l(optimization_mu_); + mutex_lock l2(mu_); lookup_table_.erase(prefix); } -void Model::ToProto(proto::Model* model_proto) { - mutex_lock l(mu_); - model_proto->set_id_counter(id_counter_); - model_proto->set_output(output_->id()); - AddNodeToProto(output_, model_proto); -} - -// static -void Model::AddNodeToProto(const std::shared_ptr<Node>& node, - proto::Model* model_proto) { - proto::Node* node_proto = model_proto->add_node(); - node->ToProto(node_proto); - for (const std::shared_ptr<Node>& input : node->inputs()) { - AddNodeToProto(input, model_proto); - } -} - -std::vector<Node::Knob> Model::CollectKnobs() { - std::vector<Node::Knob> knobs; - output_->CollectKnobs(&knobs); - return knobs; +std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() { + std::vector<std::shared_ptr<Node::Tunable>> tunables; + output_->CollectTunables(&tunables); + return tunables; } int64 Model::OutputTime() { diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 98172909bf..f88ec06ef3 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -22,7 +22,6 @@ limitations under the License. #include <utility> #include <vector> -#include "tensorflow/core/framework/model.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" @@ -61,13 +60,10 @@ class Node { public: Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {} - explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) { - name_ = node_proto.name(); - type_ = TypeFromName(node_proto.name()); - processing_time_ = node_proto.processing_time(); - num_elements_ = node_proto.num_elements(); - metadata_.insert(node_proto.metadata().begin(), - node_proto.metadata().end()); + // Adds a constant parameter. + void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + constant_params_[name] = value; } // Records that the node produced an element. @@ -88,6 +84,15 @@ class Node { processing_time_ += delta; } + // Adds a tunable parameter. + void add_tunable_param(const string& name, int64 value, int64 min, int64 max, + std::function<void(int64)>&& set_fn) + LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + tunable_params_[name] = + std::make_shared<Tunable>(value, min, max, std::move(set_fn)); + } + // Returns the unique node ID. int64 id() LOCKS_EXCLUDED(mu_) { return id_; } @@ -121,12 +126,6 @@ class Node { inputs_.remove(input); } - // Adds the given key-value pair to the node metadata. - void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - metadata_[key] = value; - } - // Sets the node name. void set_name(const string& name) LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); @@ -157,11 +156,16 @@ class Node { } private: - // Represents a performance knob. - struct Knob { - Node* node; - int64 processing_time; + // Represents a tunable parameter. + struct Tunable { + Tunable(int64 value, int64 min, int64 max, + std::function<void(int64)> set_fn) + : value(value), min(min), max(max), set_fn(std::move(set_fn)) {} + int64 value; + int64 min; + int64 max; + std::function<void(int64)> set_fn; }; enum class Type { @@ -186,8 +190,12 @@ class Node { UNKNOWN, }; - // Collects performance knobs in the subtree rooted in this node. - void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_); + // Collects tunable parameters in the subtree rooted in this node. + void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables) + LOCKS_EXCLUDED(mu_); + + // Gets a value of the given parameter (tunable or constant). + int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns the per-element processing time spent in this node. int64 NanosPerElement() LOCKS_EXCLUDED(mu_) { @@ -238,22 +246,6 @@ class Node { return sum; } - // Serializes the node state into the given proto. - void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - node_proto->set_id(id_); - node_proto->set_name(name_); - node_proto->set_num_elements(num_elements_); - node_proto->set_processing_time(processing_time_); - for (const std::shared_ptr<Node>& input : inputs_) { - node_proto->add_input(input->id()); - } - if (output_) { - node_proto->set_output(output_->id()); - } - node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end()); - } - Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (name_ == "Batch") { return Type::BATCH; @@ -319,7 +311,9 @@ class Node { int64 processing_time_ GUARDED_BY(mu_) = 0; int64 num_elements_ GUARDED_BY(mu_) = 0; std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_); - std::map<string, int64> metadata_ GUARDED_BY(mu_); + std::map<string, int64> constant_params_ GUARDED_BY(mu_); + // Tunables are shared with the model during optimization. + std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_); std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_); std::shared_ptr<Node> output_ GUARDED_BY(mu_); @@ -330,21 +324,15 @@ class Node { // for collecting runtime information and optimizing performance. It collects // runtime information about execution of the input pipeline that is used to // create a performance model, which is in turn used to identify optimal values -// of performance knobs. +// of tunable parameters. // // Developers of tf.data transformations are not expected to interact with this // class directly. Boiler plate code for creating the abstract representation of // the input pipeline and collecting runtime information has been added to the // implementation of `DatasetBase` and `DatasetBaseIterator` respectively. -// -// TODO(jsimsa): Add a mechanism for feeding the result of the optimization -// into the input pipeline. class Model { public: Model() = default; - explicit Model(const proto::Model& model_proto); - - ~Model() {} // Returns the model output node. std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) { @@ -360,30 +348,25 @@ class Model { std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_); // Runs optimization. - void Optimize() LOCKS_EXCLUDED(mu_); - - // Outputs the state of a model to a file. - // - // TODO(jsimsa): Remove this method once the optimization loop is closed. - void OutputToFile() LOCKS_EXCLUDED(mu_); + void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_); // Removes the node identified by the given name. void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_); - // Serializes the model state to the given proto. - void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_); - private: - static void AddNodeToProto(const std::shared_ptr<Node>& node, - proto::Model* model_proto); - - std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_); + std::vector<std::shared_ptr<Node::Tunable>> CollectTunables() + EXCLUSIVE_LOCKS_REQUIRED(mu_); int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Used for coordination between different input pipeline threads. mutex mu_; + // Used for preventing iterator deletion when optimization is in progress + // because the optimization may try to update the values of tunable + // parameters. + mutex optimization_mu_ ACQUIRED_BEFORE(mu_); int64 id_counter_ GUARDED_BY(mu_) = 1; std::shared_ptr<Node> output_ GUARDED_BY(mu_); std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto deleted file mode 100644 index 26000007af..0000000000 --- a/tensorflow/core/framework/model.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -package tensorflow.data.model.proto; -option cc_enable_arenas = true; - -message Model { - // Counter used for generating new node IDs. - int64 id_counter = 1; - // Nodes of this model. - repeated Node node = 2; - // The ID of the output node. - int64 output = 3; -}; - -message Node { - // The node ID. - int64 id = 1; - // The node name. - string name = 2; - // Input node IDs. - repeated int64 input = 3; - // Output node ID. - int64 output = 4; - // Number of elements produced by the node. - int64 num_elements = 5; - // The CPU time spent by running threads of this node. - int64 processing_time = 6; - // Key-value store for node metadata (e.g. batch size or parallelism). - map<string, int32> metadata = 7; -}; |