diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-17 09:21:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 09:24:34 -0700 |
commit | c8a0dfc741736a59f8fd1776b71f38619d66da56 (patch) | |
tree | 0a3ff87aed44e895ca7b3a09a93653f8eea7da59 /tensorflow | |
parent | 07bc3696135483612c727ca7687342922ff0d5de (diff) |
[tf.data] Adding support for `tf.data.AUTOTUNE` as a special value for the `num_parallel_calls` argument of `tf.data.Dataset.map()`, `tf.data.Dataset.interleave()`, and `tf.contrib.data.map_and_batch()`.
When `tf.data.AUTOTUNE` is specified, the level of parallelism is determined at runtime. The underlying mechanism instruments the input pipeline to build a performance model and then uses the model to find the optimal values for the parallelism knobs.
PiperOrigin-RevId: 213283297
Diffstat (limited to 'tensorflow')
18 files changed, 299 insertions, 273 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py index 0a87d3e905..2b3ac85924 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py @@ -58,7 +58,8 @@ class ModelDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))).repeat() - dataset = dataset.map(math_ops.matmul, num_parallel_calls=56) + dataset = dataset.map( + math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE) iterator = dataset.apply(optimization.model()).make_one_shot_iterator() get_next = iterator.get_next() @@ -84,7 +85,9 @@ class ModelDatasetTest(test.TestCase): 1))).repeat() dataset = dataset.apply( batching.map_and_batch( - math_ops.matmul, num_parallel_calls=28, batch_size=batch_size)) + math_ops.matmul, + num_parallel_calls=optimization.AUTOTUNE, + batch_size=batch_size)) iterator = dataset.apply(optimization.model()).make_one_shot_iterator() get_next = iterator.get_next() @@ -109,7 +112,9 @@ class ModelDatasetTest(test.TestCase): 1))).repeat() dataset = dataset.map(math_ops.matmul) dataset = dataset_ops.Dataset.range(1).repeat().interleave( - lambda _: dataset, cycle_length=56, num_parallel_calls=56) + lambda _: dataset, + cycle_length=10, + num_parallel_calls=optimization.AUTOTUNE) iterator = dataset.apply(optimization.model()).make_one_shot_iterator() get_next = iterator.get_next() @@ -146,15 +151,15 @@ class ModelDatasetTest(test.TestCase): x, y = c return a, b, math_ops.matmul(x, y) - dataset = dataset.map(f1, num_parallel_calls=32) + dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE) dataset = dataset_ops.Dataset.range(1).repeat().interleave( lambda _: dataset, cycle_length=2) - dataset = dataset.map(f2, num_parallel_calls=16) + dataset = 
dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE) dataset = dataset_ops.Dataset.range(1).repeat().interleave( lambda _: dataset, cycle_length=2) - dataset = dataset.map(f3, num_parallel_calls=10) + dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE) iterator = dataset.apply(optimization.model()).make_one_shot_iterator() get_next = iterator.get_next() diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 1d6d9a60e5..0d8df93d11 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.cc tensorflow/core/framework/graph_transfer_info.pb.cc tensorflow/core/framework/kernel_def.pb.cc tensorflow/core/framework/log_memory.pb.cc -tensorflow/core/framework/model.pb.cc tensorflow/core/framework/node_def.pb.cc tensorflow/core/framework/op_def.pb.cc tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 884461ecae..d982df9319 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.h tensorflow/core/framework/graph_transfer_info.pb.h tensorflow/core/framework/kernel_def.pb.h tensorflow/core/framework/log_memory.pb.h -tensorflow/core/framework/model.pb.h tensorflow/core/framework/node_def.pb.h tensorflow/core/framework/op_def.pb.h tensorflow/core/framework/remote_fused_graph_execute_info.pb.h diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index e23f499214..f94d70db90 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -10,7 +10,6 @@ 
tensorflow/core/framework/graph.pb_text.cc tensorflow/core/framework/graph_transfer_info.pb_text.cc tensorflow/core/framework/kernel_def.pb_text.cc tensorflow/core/framework/log_memory.pb_text.cc -tensorflow/core/framework/model.pb_text.cc tensorflow/core/framework/node_def.pb_text.cc tensorflow/core/framework/op_def.pb_text.cc tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 5eae845d9b..8bec3e3e01 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -14,7 +14,6 @@ tensorflow/core/framework/graph.proto tensorflow/core/framework/graph_transfer_info.proto tensorflow/core/framework/kernel_def.proto tensorflow/core/framework/log_memory.proto -tensorflow/core/framework/model.proto tensorflow/core/framework/node_def.proto tensorflow/core/framework/op_def.proto tensorflow/core/framework/reader_base.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 55715bb3a6..4074232c93 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -178,7 +178,6 @@ COMMON_PROTO_SRCS = [ "framework/iterator.proto", "framework/kernel_def.proto", "framework/log_memory.proto", - "framework/model.proto", "framework/node_def.proto", "framework/op_def.proto", "framework/reader_base.proto", @@ -842,7 +841,6 @@ tf_cuda_library( "framework/log_memory.h", "framework/lookup_interface.h", "framework/memory_types.h", - "framework/model.h", "framework/node_def_builder.h", "framework/node_def_util.h", "framework/numeric_op.h", diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 5281c56f04..284dafb886 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -20,7 +20,6 @@ limitations under the License. 
namespace tensorflow { namespace data { - namespace { // A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor. diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 4ee6749eea..91b1e61d3c 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -47,6 +47,8 @@ class GraphDefBuilder; class Node; namespace data { +// A constant that can be used to enable auto-tuning. +constexpr int kAutoTune = -1; class DatasetBase; class SerializationContext; @@ -670,13 +672,34 @@ class DatasetBaseIterator : public IteratorBase { return strings::StrCat(params_.prefix, ":", name); } - // When performance modeling is enabled, this method sets metadata entry for - // the model node corresponding to this iterator. - void SetMetadata(IteratorContext* ctx, const string& key, int64 value) { + // When performance modeling is enabled, this method adds a constant parameter + // to the model node corresponding to this iterator. + void AddConstantParameter(IteratorContext* ctx, const string& name, + int64 value) { if (ctx->model()) { std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); if (node) { - node->set_metadata(key, value); + node->add_constant_param(name, value); + } + } + } + + // When performance modeling is enabled, this method adds a tunable parameter + // to the model node corresponding to this iterator. + // + // The `set_fn` function should set the tunable parameter to the value of + // its input argument. The function should be thread-safe; in particular, the + // state it updates should be protected by a lock as the function can be + // invoked asynchronously. It is guaranteed that this function will not be + // invoked after the iterator is deleted because the model node that owns + // the function is deleted when the iterator is deleted. 
+ void AddTunableParameter(IteratorContext* ctx, const string& name, + int64 value, int64 min, int64 max, + std::function<void(int64)>&& set_fn) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->add_tunable_param(name, value, min, max, std::move(set_fn)); } } } diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 250b006641..b3fe357ea1 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -15,52 +15,28 @@ limitations under the License. #include "tensorflow/core/framework/model.h" +#include <memory> + +#include "tensorflow/core/lib/gtl/map_util.h" + namespace tensorflow { namespace data { namespace model { // TODO(jsimsa): Use `Node` subclassing instead of types and node statements. -void Node::CollectKnobs(std::vector<Node::Knob>* knobs) { +void Node::CollectTunables( + std::vector<std::shared_ptr<Node::Tunable>>* tunables) { mutex_lock l(mu_); + for (auto input : inputs_) { + input->CollectTunables(tunables); + } switch (type_) { - case Type::PARALLEL_INTERLEAVE_V2: { - for (auto input : inputs_) { - input->CollectKnobs(knobs); - } - int64 processing_time = static_cast<int64>( - static_cast<double>(ProcessingTimeLocked() - - inputs_.front()->ProcessingTime()) / - static_cast<double>(inputs_.size() - 1)); - knobs->emplace_back( - Node::Knob{this, processing_time, metadata_["parallelism"]}); - return; - } case Type::MAP_AND_BATCH: + case Type::PARALLEL_INTERLEAVE_V2: case Type::PARALLEL_MAP: { - for (auto input : inputs_) { - input->CollectKnobs(knobs); - } - knobs->emplace_back( - Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]}); - return; - } - case Type::BATCH: - case Type::CACHE: - case Type::CONCATENATE: - case Type::FILTER: - case Type::FLAT_MAP: - case Type::INTERLEAVE: - case Type::MAP: - case Type::PADDED_BATCH: - case Type::PARALLEL_INTERLEAVE: - case Type::PREFETCH: - case Type::REPEAT: - 
case Type::SHUFFLE: - case Type::SKIP: - case Type::TAKE: - case Type::ZIP: { - for (auto input : inputs_) { - input->CollectKnobs(knobs); + if (auto* tunable_param = + gtl::FindOrNull(tunable_params_, "parallelism")) { + tunables->push_back(*tunable_param); } return; } @@ -69,12 +45,19 @@ void Node::CollectKnobs(std::vector<Node::Knob>* knobs) { } } +int64 Node::GetParameterValue(const string& name) { + if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) { + return (*tunable_param)->value; + } + return constant_params_[name]; +} + int64 Node::ProcessingTimeLocked() { switch (type_) { case Type::BATCH: case Type::MAP_AND_BATCH: case Type::PADDED_BATCH: { - int64 batch_size = metadata_["batch_size"]; + int64 batch_size = GetParameterValue("batch_size"); return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs(); } case Type::FILTER: { @@ -122,7 +105,7 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { switch (type_) { case Type::BATCH: case Type::PADDED_BATCH: { - double batch_size = metadata_["batch_size"]; + double batch_size = GetParameterValue("batch_size"); int64 old_value = (*input_times)[input_times->size() - 1]; (*input_times)[input_times->size() - 1] = static_cast<int64>( static_cast<double>(old_value + NanosPerElementLocked()) / @@ -168,8 +151,8 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { static_cast<double>(inputs_.size() - 1); } case Type::MAP_AND_BATCH: { - double batch_size = metadata_["batch_size"]; - double parallelism = metadata_["parallelism"]; + double batch_size = GetParameterValue("batch_size"); + double parallelism = GetParameterValue("parallelism"); int64 delta = static_cast<int64>(static_cast<double>(NanosPerElementLocked()) / (batch_size * parallelism)); @@ -182,22 +165,41 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { return std::max(0LL, output_time - input_times->at(input_times->size() - 2)); } - case Type::PARALLEL_INTERLEAVE: + case 
Type::PARALLEL_INTERLEAVE: { + // TODO(jsimsa): model the first input + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 inputs_output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + double parallelism = GetParameterValue("parallelism"); + int64 output_time = + NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / + static_cast<double>(inputs_.size() - 1)) / + parallelism); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } case Type::PARALLEL_INTERLEAVE_V2: { // TODO(jsimsa): model the first input if (inputs_.size() <= 1) { return NanosPerElementLocked(); } - int64 delta = - static_cast<int64>(static_cast<double>(NanosPerElementLocked()) * - static_cast<double>(inputs_.size() - 1)); + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); input_times->push_back(delta); auto cleanup = gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); int64 inputs_output_time = OutputTimeForInputs(input_times) - inputs_.front()->OutputTime(input_times); - double parallelism = std::min(port::NumSchedulableCPUs(), - static_cast<int>(metadata_["parallelism"])); + double parallelism = + std::min(static_cast<int>(GetParameterValue("cycle_length")), + static_cast<int>(GetParameterValue("parallelism"))); int64 output_time = NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / static_cast<double>(inputs_.size() - 1)) / @@ -206,8 +208,9 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { output_time - input_times->at(input_times->size() - 2)); } case Type::PARALLEL_MAP: { - double parallelism = std::min(port::NumSchedulableCPUs(), - 
static_cast<int>(metadata_["parallelism"])); + double parallelism = + std::min(port::NumSchedulableCPUs(), + static_cast<int>(GetParameterValue("parallelism"))); int64 delta = static_cast<int64>( static_cast<double>(NanosPerElementLocked()) / parallelism); input_times->push_back(delta); @@ -248,23 +251,6 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { } } -Model::Model(const proto::Model& model_proto) { - id_counter_ = model_proto.id_counter(); - std::map<int64, std::shared_ptr<Node>> lookup_table; - for (auto node_proto : model_proto.node()) { - std::shared_ptr<Node> node(new Node(node_proto)); - lookup_table[node_proto.id()] = node; - } - for (auto node_proto : model_proto.node()) { - std::shared_ptr<Node> node = lookup_table[node_proto.id()]; - for (int64 id : node_proto.input()) { - node->add_input(lookup_table[id]); - } - node->set_output(lookup_table[node_proto.output()]); - } - output_ = lookup_table[model_proto.output()]; -} - std::shared_ptr<Node> Model::AddNode(const string& name, const string& output_name) { mutex_lock l(mu_); @@ -294,94 +280,77 @@ std::shared_ptr<Node> Model::LookupNode(const string& name) { return result; } -void Model::Optimize() { - mutex_lock l(mu_); - int64 processing_time = ProcessingTime(); - int64 num_cpus = port::NumSchedulableCPUs(); - std::vector<Node::Knob> knobs = CollectKnobs(); - // The optimization algorithm starts by setting all parallelism knobs to 1. It - // then repeatedly identifies the knob that, when turned up by 1, decreases - // the output time the most. This process is repeated until all knobs reach - // the number of schedulable CPUs or the projected output time is less than or - // equal to the processing time needed to produce an element divided by the - // number of schedulable CPUs. 
- for (auto& knob : knobs) { - LOG(INFO) << knob.node->name() << " " << knob.processing_time; - knob.value = 1; - knob.node->set_metadata("parallelism", knob.value); - } - while (true) { - int64 output_time = OutputTime(); - bool all_knobs = true; - for (auto knob : knobs) { - if (knob.value < num_cpus) { - all_knobs = false; +// The optimization algorithm starts by setting all tunable parallelism +// parameters to 1. It then repeatedly identifies the parameter whose +// increase in parallelism decreases the output time the most. This process is +// repeated until all parameters reach their maximum values or the +// projected output time is less than or equal to the processing time needed to +// produce an element divided by CPU budget. +void Model::Optimize(int64 cpu_budget) { + mutex_lock l(optimization_mu_); + std::vector<std::shared_ptr<Node::Tunable>> tunables; + { + mutex_lock l2(mu_); + const int64 processing_time = ProcessingTime(); + tunables = CollectTunables(); + for (auto tunable : tunables) { + tunable->value = 1; + } + while (true) { + const int64 output_time = OutputTime(); + bool all_tunables = true; + for (auto& tunable : tunables) { + if (tunable->value < tunable->max) { + all_tunables = false; + break; + } + } + if (output_time < processing_time / cpu_budget || all_tunables) { break; } - } - if (output_time < processing_time / num_cpus || all_knobs) { - break; - } - int64 best_delta = -1; - int best_knob = -1; - for (int i = 0; i < knobs.size(); ++i) { - if (knobs[i].value == num_cpus) { - continue; + int64 best_delta = -1; + Node::Tunable* best_tunable = nullptr; + for (auto& tunable : tunables) { + if (tunable->value == tunable->max) { + continue; + } + tunable->value++; + int64 delta = output_time - OutputTime(); + if (delta > best_delta) { + best_delta = delta; + best_tunable = tunable.get(); + } + tunable->value--; } - knobs[i].node->set_metadata("parallelism", knobs[i].value + 1); - int64 delta = output_time - OutputTime(); - if (delta 
> best_delta) { - best_delta = delta; - best_knob = i; + if (!best_tunable) { + // NOTE: This can happen because we are performing the optimization + // while the model data is changing. If this becomes an issue, we should + // look into performing the optimization using a model snapshot. + break; } - knobs[i].node->set_metadata("parallelism", knobs[i].value); + best_tunable->value++; } - knobs[best_knob].value++; - knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value); } - for (auto knob : knobs) { - LOG(INFO) << knob.node->name() << " " << knob.value; + // The `set_fn` functions should be invoked without holding a lock to avoid a + // potential deadlock. + for (auto& tunable : tunables) { + tunable->set_fn(tunable->value); } - LOG(INFO) << "output time: " << OutputTime(); - LOG(INFO) << "processing time: " << ProcessingTime(); -} - -void Model::OutputToFile() { - proto::Model model_proto; - ToProto(&model_proto); - string filename; - Env::Default()->LocalTempFilename(&filename); - TF_CHECK_OK(WriteStringToFile(Env::Default(), filename, - model_proto.SerializeAsString())); - LOG(INFO) << filename; } void Model::RemoveNode(const string& prefix) { - mutex_lock l(mu_); + // Nodes are not allowed to be removed when optimization is in progress to + // prevent the optimization from trying to access an iterator that was + // concurrently deleted. 
+ mutex_lock l(optimization_mu_); + mutex_lock l2(mu_); lookup_table_.erase(prefix); } -void Model::ToProto(proto::Model* model_proto) { - mutex_lock l(mu_); - model_proto->set_id_counter(id_counter_); - model_proto->set_output(output_->id()); - AddNodeToProto(output_, model_proto); -} - -// static -void Model::AddNodeToProto(const std::shared_ptr<Node>& node, - proto::Model* model_proto) { - proto::Node* node_proto = model_proto->add_node(); - node->ToProto(node_proto); - for (const std::shared_ptr<Node>& input : node->inputs()) { - AddNodeToProto(input, model_proto); - } -} - -std::vector<Node::Knob> Model::CollectKnobs() { - std::vector<Node::Knob> knobs; - output_->CollectKnobs(&knobs); - return knobs; +std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() { + std::vector<std::shared_ptr<Node::Tunable>> tunables; + output_->CollectTunables(&tunables); + return tunables; } int64 Model::OutputTime() { diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 98172909bf..f88ec06ef3 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -22,7 +22,6 @@ limitations under the License. #include <utility> #include <vector> -#include "tensorflow/core/framework/model.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" @@ -61,13 +60,10 @@ class Node { public: Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {} - explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) { - name_ = node_proto.name(); - type_ = TypeFromName(node_proto.name()); - processing_time_ = node_proto.processing_time(); - num_elements_ = node_proto.num_elements(); - metadata_.insert(node_proto.metadata().begin(), - node_proto.metadata().end()); + // Adds a constant parameter. 
+ void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + constant_params_[name] = value; } // Records that the node produced an element. @@ -88,6 +84,15 @@ class Node { processing_time_ += delta; } + // Adds a tunable parameter. + void add_tunable_param(const string& name, int64 value, int64 min, int64 max, + std::function<void(int64)>&& set_fn) + LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + tunable_params_[name] = + std::make_shared<Tunable>(value, min, max, std::move(set_fn)); + } + // Returns the unique node ID. int64 id() LOCKS_EXCLUDED(mu_) { return id_; } @@ -121,12 +126,6 @@ class Node { inputs_.remove(input); } - // Adds the given key-value pair to the node metadata. - void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - metadata_[key] = value; - } - // Sets the node name. void set_name(const string& name) LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); @@ -157,11 +156,16 @@ class Node { } private: - // Represents a performance knob. - struct Knob { - Node* node; - int64 processing_time; + // Represents a tunable parameter. + struct Tunable { + Tunable(int64 value, int64 min, int64 max, + std::function<void(int64)> set_fn) + : value(value), min(min), max(max), set_fn(std::move(set_fn)) {} + int64 value; + int64 min; + int64 max; + std::function<void(int64)> set_fn; }; enum class Type { @@ -186,8 +190,12 @@ class Node { UNKNOWN, }; - // Collects performance knobs in the subtree rooted in this node. - void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_); + // Collects tunable parameters in the subtree rooted in this node. + void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables) + LOCKS_EXCLUDED(mu_); + + // Gets a value of the given parameter (tunable or constant). + int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns the per-element processing time spent in this node. 
int64 NanosPerElement() LOCKS_EXCLUDED(mu_) { @@ -238,22 +246,6 @@ class Node { return sum; } - // Serializes the node state into the given proto. - void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - node_proto->set_id(id_); - node_proto->set_name(name_); - node_proto->set_num_elements(num_elements_); - node_proto->set_processing_time(processing_time_); - for (const std::shared_ptr<Node>& input : inputs_) { - node_proto->add_input(input->id()); - } - if (output_) { - node_proto->set_output(output_->id()); - } - node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end()); - } - Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (name_ == "Batch") { return Type::BATCH; @@ -319,7 +311,9 @@ class Node { int64 processing_time_ GUARDED_BY(mu_) = 0; int64 num_elements_ GUARDED_BY(mu_) = 0; std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_); - std::map<string, int64> metadata_ GUARDED_BY(mu_); + std::map<string, int64> constant_params_ GUARDED_BY(mu_); + // Tunables are shared with the model during optimization. + std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_); std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_); std::shared_ptr<Node> output_ GUARDED_BY(mu_); @@ -330,21 +324,15 @@ class Node { // for collecting runtime information and optimizing performance. It collects // runtime information about execution of the input pipeline that is used to // create a performance model, which is in turn used to identify optimal values -// of performance knobs. +// of tunable parameters. // // Developers of tf.data transformations are not expected to interact with this // class directly. Boiler plate code for creating the abstract representation of // the input pipeline and collecting runtime information has been added to the // implementation of `DatasetBase` and `DatasetBaseIterator` respectively. 
-// -// TODO(jsimsa): Add a mechanism for feeding the result of the optimization -// into the input pipeline. class Model { public: Model() = default; - explicit Model(const proto::Model& model_proto); - - ~Model() {} // Returns the model output node. std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) { @@ -360,30 +348,25 @@ class Model { std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_); // Runs optimization. - void Optimize() LOCKS_EXCLUDED(mu_); - - // Outputs the state of a model to a file. - // - // TODO(jsimsa): Remove this method once the optimization loop is closed. - void OutputToFile() LOCKS_EXCLUDED(mu_); + void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_); // Removes the node identified by the given name. void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_); - // Serializes the model state to the given proto. - void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_); - private: - static void AddNodeToProto(const std::shared_ptr<Node>& node, - proto::Model* model_proto); - - std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_); + std::vector<std::shared_ptr<Node::Tunable>> CollectTunables() + EXCLUSIVE_LOCKS_REQUIRED(mu_); int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Used for coordination between different input pipeline threads. mutex mu_; + // Used for preventing iterator deletion when optimization is in progress + // because the optimization may try to update the values of tunable + // parameters. 
+ mutex optimization_mu_ ACQUIRED_BEFORE(mu_); int64 id_counter_ GUARDED_BY(mu_) = 1; std::shared_ptr<Node> output_ GUARDED_BY(mu_); std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto deleted file mode 100644 index 26000007af..0000000000 --- a/tensorflow/core/framework/model.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -package tensorflow.data.model.proto; -option cc_enable_arenas = true; - -message Model { - // Counter used for generating new node IDs. - int64 id_counter = 1; - // Nodes of this model. - repeated Node node = 2; - // The ID of the output node. - int64 output = 3; -}; - -message Node { - // The node ID. - int64 id = 1; - // The node name. - string name = 2; - // Input node IDs. - repeated int64 input = 3; - // Output node ID. - int64 output = 4; - // Number of elements produced by the node. - int64 num_elements = 5; - // The CPU time spent by running threads of this node. - int64 processing_time = 6; - // Key-value store for node metadata (e.g. batch size or parallelism). 
- map<string, int32> metadata = 7; -}; diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 887b8c8365..d1db1d7bec 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -117,7 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { - SetMetadata(ctx, "batch_size", dataset()->batch_size_); + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 85e49355d3..80efac5d4b 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" namespace tensorflow { @@ -39,7 +40,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 
1 : 2) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -77,7 +77,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { case 2: OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, + num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); break; @@ -190,7 +191,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + : DatasetIterator<Dataset>(params), + num_parallel_calls_(params.dataset->num_parallel_calls_) {} ~Iterator() override { mutex_lock l(mu_); @@ -204,8 +206,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - SetMetadata(ctx, "batch_size", dataset()->batch_size_); - SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_); + mutex_lock l(mu_); + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + std::function<void(int64)> set_fn = [this](int64 value) { + { + mutex_lock l(mu_); + num_parallel_calls_ = value; + } + VLOG(2) << "setting parallelism knob to " << value; + cond_var_.notify_all(); + }; + AddTunableParameter( + ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */, + port::NumSchedulableCPUs() /* max */, std::move(set_fn)); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -428,7 +446,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return 
(dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) / + return (num_parallel_calls_ + dataset()->batch_size_ - 1) / dataset()->batch_size_; } @@ -480,15 +498,18 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) LOCKS_EXCLUDED(mu_) { std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls; - new_calls.reserve(dataset()->num_parallel_calls_); StartWork(ctx.get()); auto stop_cleanup = gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); }); + { + tf_shared_lock l(mu_); + new_calls.reserve(num_parallel_calls_); + } while (true) { { mutex_lock l(mu_); while (!cancelled_ && - (num_calls_ >= dataset()->num_parallel_calls_ || + (num_calls_ >= num_parallel_calls_ || batch_results_.size() > MaxBatchResults() || (batch_results_.size() == MaxBatchResults() && call_counter_ % dataset()->batch_size_ == 0))) { @@ -501,7 +522,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return; } - while (num_calls_ < dataset()->num_parallel_calls_ && + while (num_calls_ < num_parallel_calls_ && (batch_results_.size() < MaxBatchResults() || (batch_results_.size() == MaxBatchResults() && call_counter_ % dataset()->batch_size_ != 0))) { @@ -648,6 +669,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // user specified level of parallelism and there are slots available in // the `batch_results_` buffer. condition_variable cond_var_; + // Identifies the maximum number of parallel calls. + int64 num_parallel_calls_ GUARDED_BY(mu_) = 0; // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(mu_) = 0; // Counts the total number of calls. 
@@ -671,7 +694,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const Eigen::ThreadPoolDevice* device_; // not owned }; - const int graph_def_version_; const int op_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index c7f929dbc1..63025d3371 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -17,11 +17,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/cpu_info.h" namespace tensorflow { namespace data { namespace { +const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros; + class ModelDatasetOp : public UnaryDatasetOpKernel { public: explicit ModelDatasetOp(OpKernelConstruction* ctx) @@ -71,9 +74,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params), model_(new model::Model()) {} - - ~Iterator() override { model_->OutputToFile(); } + : DatasetIterator<Dataset>(params), + model_(std::make_shared<model::Model>()) {} Status Initialize(IteratorContext* ctx) override { IteratorContext ctx_with_model(CreateParams(ctx)); @@ -85,6 +87,21 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + int64 now = ctx->env()->NowMicros() / EnvTime::kMillisToMicros; + if (last_optimization_ms_ + optimization_period_ms_ < now) { + model_->Optimize(port::NumSchedulableCPUs()); + // Exponentially increase the period of running the optimization until + // a threshold is reached. 
+ if (optimization_period_ms_ < kOptimizationPeriodThresholdMs) { + if (optimization_period_ms_ << 1 < kOptimizationPeriodThresholdMs) { + optimization_period_ms_ <<= 1; + } else { + optimization_period_ms_ = kOptimizationPeriodThresholdMs; + } + } + last_optimization_ms_ = + ctx->env()->NowMicros() / EnvTime::kMillisToMicros; + } IteratorContext ctx_with_model(CreateParams(ctx)); return input_impl_->GetNext(&ctx_with_model, out_tensors, end_of_sequence); @@ -113,6 +130,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { private: mutex mu_; std::shared_ptr<model::Model> model_; + int64 last_optimization_ms_ GUARDED_BY(mu_) = 0; + int64 optimization_period_ms_ GUARDED_BY(mu_) = 10; std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index 73eeafd797..7b01c3b4e0 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -207,7 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { - SetMetadata(ctx, "batch_size", dataset()->batch_size_); + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index aa5e613e24..2f2db09508 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -252,7 +252,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - SetMetadata(ctx, "parallelism", dataset()->cycle_length_); + AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_); 
TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -1120,7 +1120,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { int64 num_parallel_calls; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); OP_REQUIRES( @@ -1233,6 +1233,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { args_list_(params.dataset->cycle_length_), current_elements_(params.dataset->cycle_length_), element_in_use_(params.dataset->cycle_length_, false), + num_parallel_calls_(params.dataset->num_parallel_calls_), thread_pool_(new thread::ThreadPool( Env::Default(), ThreadOptions(), "parallel_interleave", dataset()->cycle_length_ /* num_threads */, @@ -1250,7 +1251,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { - SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_); + mutex_lock l(mu_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + auto set_fn = [this](int64 value) { + { + mutex_lock l(mu_); + num_parallel_calls_ = value; + } + VLOG(2) << "setting parallelism knob to " << value; + cond_var_.notify_all(); + }; + AddTunableParameter( + ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */, + dataset()->cycle_length_ /* max */, std::move(set_fn)); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } + AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -1459,7 +1477,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // 
not in use and there is space in the `invocation_results_` queue. while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && (element_in_use_[cycle_index_] || - num_calls_ >= dataset()->num_parallel_calls_ || + num_calls_ >= num_parallel_calls_ || invocation_results_.size() >= MaxInvocationResults())) { StopWork(ctx.get()); cond_var_.wait(l); @@ -1472,7 +1490,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { while (!element_in_use_[cycle_index_] && (!end_of_input_ || num_open_ > 0) && - num_calls_ < dataset()->num_parallel_calls_ && + num_calls_ < num_parallel_calls_ && invocation_results_.size() < MaxInvocationResults()) { if (!current_elements_[cycle_index_]) { // Try to create a new iterator from the next input element. @@ -1647,6 +1665,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // Identifies the number of open iterators. int64 num_open_ GUARDED_BY(mu_) = 0; + // Identifies the maximum number of parallel calls. + int64 num_parallel_calls_ GUARDED_BY(mu_) = 0; + // Identifies the number of outstanding calls. 
int64 num_calls_ GUARDED_BY(mu_) = 0; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 0795987431..b584316d69 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -55,7 +55,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { int32 num_parallel_calls; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 0b6e587881..5f6052ce83 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/cpu_info.h" namespace tensorflow { namespace data { @@ -55,7 +56,25 @@ class ParallelMapIterator : public DatasetBaseIterator { } Status Initialize(IteratorContext* ctx) override { - SetMetadata(ctx, "parallelism", num_parallel_calls_); + mutex_lock l(mu_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + auto set_fn = [this](int64 value) { + { + mutex_lock l(mu_); + num_parallel_calls_ = value; + } + VLOG(2) << "setting parallelism knob to " << value; + cond_var_.notify_all(); + }; + // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and + // use it here for the maximum. 
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_ /* value */, + 1 /* min */, port::NumSchedulableCPUs() /* max */, + std::move(set_fn)); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } TF_RETURN_IF_ERROR( input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); if (init_func_) { @@ -211,8 +230,6 @@ class ParallelMapIterator : public DatasetBaseIterator { std::move(done)); } - int64 MaxInvocationResults() { return num_parallel_calls_; } - Status ProcessResult(const std::shared_ptr<InvocationResult>& result, std::vector<Tensor>* out_tensors, bool* end_of_sequence) { @@ -235,13 +252,16 @@ class ParallelMapIterator : public DatasetBaseIterator { StartWork(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); std::vector<std::shared_ptr<InvocationResult>> new_calls; - new_calls.reserve(num_parallel_calls_); + { + tf_shared_lock l(mu_); + new_calls.reserve(num_parallel_calls_); + } while (true) { { mutex_lock l(mu_); while (!cancelled_ && (num_calls_ >= num_parallel_calls_ || - invocation_results_.size() >= MaxInvocationResults())) { + invocation_results_.size() >= num_parallel_calls_)) { StopWork(ctx.get()); cond_var_.wait(l); StartWork(ctx.get()); @@ -250,7 +270,7 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } while (num_calls_ < num_parallel_calls_ && - invocation_results_.size() < MaxInvocationResults()) { + invocation_results_.size() < num_parallel_calls_) { invocation_results_.emplace_back(new InvocationResult()); new_calls.push_back(invocation_results_.back()); num_calls_++; @@ -305,7 +325,6 @@ class ParallelMapIterator : public DatasetBaseIterator { const DatasetBase* const input_dataset_; // Not owned. const std::function<Status(IteratorContext*)> init_func_; const ParallelMapIteratorFunction map_func_; - const int32 num_parallel_calls_; // Used for coordination between the main thread and the runner thread. 
mutex mu_; // Used for coordination between the main thread and the runner thread. In @@ -314,6 +333,8 @@ class ParallelMapIterator : public DatasetBaseIterator { // parallelism and there are slots available in the `invocation_results_` // buffer. condition_variable cond_var_; + // Identifies the maximum number of parallel calls. + int64 num_parallel_calls_ GUARDED_BY(mu_) = 0; // Counts the number of outstanding calls. int64 num_calls_ GUARDED_BY(mu_) = 0; std::unique_ptr<IteratorBase> input_impl_; |