aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-17 09:21:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 09:24:34 -0700
commitc8a0dfc741736a59f8fd1776b71f38619d66da56 (patch)
tree0a3ff87aed44e895ca7b3a09a93653f8eea7da59 /tensorflow
parent07bc3696135483612c727ca7687342922ff0d5de (diff)
[tf.data] Adding support for `tf.data.AUTOTUNE` as a special value for the `num_parallel_calls` argument of `tf.data.Dataset.map()`, `tf.data.Dataset.interleave()`, and `tf.contrib.data.map_and_batch()`.
When `tf.data.AUTOTUNE` is specified, the level of parallelism is determined at runtime. The underlying mechanism instruments the input pipeline to build a performance model and then uses the model to find the optimal values for the parallelism knobs. PiperOrigin-RevId: 213283297
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py17
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt1
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/framework/dataset.cc1
-rw-r--r--tensorflow/core/framework/dataset.h31
-rw-r--r--tensorflow/core/framework/model.cc251
-rw-r--r--tensorflow/core/framework/model.h97
-rw-r--r--tensorflow/core/framework/model.proto30
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc42
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc25
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc31
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc35
18 files changed, 299 insertions, 273 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
index 0a87d3e905..2b3ac85924 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -58,7 +58,8 @@ class ModelDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
np.random.rand(4 * k,
1))).repeat()
- dataset = dataset.map(math_ops.matmul, num_parallel_calls=56)
+ dataset = dataset.map(
+ math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
@@ -84,7 +85,9 @@ class ModelDatasetTest(test.TestCase):
1))).repeat()
dataset = dataset.apply(
batching.map_and_batch(
- math_ops.matmul, num_parallel_calls=28, batch_size=batch_size))
+ math_ops.matmul,
+ num_parallel_calls=optimization.AUTOTUNE,
+ batch_size=batch_size))
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
@@ -109,7 +112,9 @@ class ModelDatasetTest(test.TestCase):
1))).repeat()
dataset = dataset.map(math_ops.matmul)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=56, num_parallel_calls=56)
+ lambda _: dataset,
+ cycle_length=10,
+ num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
@@ -146,15 +151,15 @@ class ModelDatasetTest(test.TestCase):
x, y = c
return a, b, math_ops.matmul(x, y)
- dataset = dataset.map(f1, num_parallel_calls=32)
+ dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset, cycle_length=2)
- dataset = dataset.map(f2, num_parallel_calls=16)
+ dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset, cycle_length=2)
- dataset = dataset.map(f3, num_parallel_calls=10)
+ dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 1d6d9a60e5..0d8df93d11 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.cc
tensorflow/core/framework/graph_transfer_info.pb.cc
tensorflow/core/framework/kernel_def.pb.cc
tensorflow/core/framework/log_memory.pb.cc
-tensorflow/core/framework/model.pb.cc
tensorflow/core/framework/node_def.pb.cc
tensorflow/core/framework/op_def.pb.cc
tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index 884461ecae..d982df9319 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.h
tensorflow/core/framework/graph_transfer_info.pb.h
tensorflow/core/framework/kernel_def.pb.h
tensorflow/core/framework/log_memory.pb.h
-tensorflow/core/framework/model.pb.h
tensorflow/core/framework/node_def.pb.h
tensorflow/core/framework/op_def.pb.h
tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index e23f499214..f94d70db90 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb_text.cc
tensorflow/core/framework/graph_transfer_info.pb_text.cc
tensorflow/core/framework/kernel_def.pb_text.cc
tensorflow/core/framework/log_memory.pb_text.cc
-tensorflow/core/framework/model.pb_text.cc
tensorflow/core/framework/node_def.pb_text.cc
tensorflow/core/framework/op_def.pb_text.cc
tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 5eae845d9b..8bec3e3e01 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -14,7 +14,6 @@ tensorflow/core/framework/graph.proto
tensorflow/core/framework/graph_transfer_info.proto
tensorflow/core/framework/kernel_def.proto
tensorflow/core/framework/log_memory.proto
-tensorflow/core/framework/model.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/reader_base.proto
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 55715bb3a6..4074232c93 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -178,7 +178,6 @@ COMMON_PROTO_SRCS = [
"framework/iterator.proto",
"framework/kernel_def.proto",
"framework/log_memory.proto",
- "framework/model.proto",
"framework/node_def.proto",
"framework/op_def.proto",
"framework/reader_base.proto",
@@ -842,7 +841,6 @@ tf_cuda_library(
"framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
- "framework/model.h",
"framework/node_def_builder.h",
"framework/node_def_util.h",
"framework/numeric_op.h",
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 5281c56f04..284dafb886 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -20,7 +20,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
-
namespace {
// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 4ee6749eea..91b1e61d3c 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -47,6 +47,8 @@ class GraphDefBuilder;
class Node;
namespace data {
+// A constant that can be used to enable auto-tuning.
+constexpr int kAutoTune = -1;
class DatasetBase;
class SerializationContext;
@@ -670,13 +672,34 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
- // When performance modeling is enabled, this method sets metadata entry for
- // the model node corresponding to this iterator.
- void SetMetadata(IteratorContext* ctx, const string& key, int64 value) {
+ // When performance modeling is enabled, this method adds a constant parameter
+ // to the model node corresponding to this iterator.
+ void AddConstantParameter(IteratorContext* ctx, const string& name,
+ int64 value) {
if (ctx->model()) {
std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
if (node) {
- node->set_metadata(key, value);
+ node->add_constant_param(name, value);
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method adds a tunable parameter
+ // to the model node corresponding to this iterator.
+ //
+ // The `set_fn` function should set the tunable parameter to the value of
+ // its input argument. The function should be thread-safe; in particular, the
+ // state it updates should be protected by a lock as the function can be
+ // invoked asynchronously. It is guaranteed that this function will not be
+ // invoked after the iterator is deleted because the model node that owns
+ // the function is deleted when the iterator is deleted.
+ void AddTunableParameter(IteratorContext* ctx, const string& name,
+ int64 value, int64 min, int64 max,
+ std::function<void(int64)>&& set_fn) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->add_tunable_param(name, value, min, max, std::move(set_fn));
}
}
}
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 250b006641..b3fe357ea1 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -15,52 +15,28 @@ limitations under the License.
#include "tensorflow/core/framework/model.h"
+#include <memory>
+
+#include "tensorflow/core/lib/gtl/map_util.h"
+
namespace tensorflow {
namespace data {
namespace model {
// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
-void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
+void Node::CollectTunables(
+ std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
mutex_lock l(mu_);
+ for (auto input : inputs_) {
+ input->CollectTunables(tunables);
+ }
switch (type_) {
- case Type::PARALLEL_INTERLEAVE_V2: {
- for (auto input : inputs_) {
- input->CollectKnobs(knobs);
- }
- int64 processing_time = static_cast<int64>(
- static_cast<double>(ProcessingTimeLocked() -
- inputs_.front()->ProcessingTime()) /
- static_cast<double>(inputs_.size() - 1));
- knobs->emplace_back(
- Node::Knob{this, processing_time, metadata_["parallelism"]});
- return;
- }
case Type::MAP_AND_BATCH:
+ case Type::PARALLEL_INTERLEAVE_V2:
case Type::PARALLEL_MAP: {
- for (auto input : inputs_) {
- input->CollectKnobs(knobs);
- }
- knobs->emplace_back(
- Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]});
- return;
- }
- case Type::BATCH:
- case Type::CACHE:
- case Type::CONCATENATE:
- case Type::FILTER:
- case Type::FLAT_MAP:
- case Type::INTERLEAVE:
- case Type::MAP:
- case Type::PADDED_BATCH:
- case Type::PARALLEL_INTERLEAVE:
- case Type::PREFETCH:
- case Type::REPEAT:
- case Type::SHUFFLE:
- case Type::SKIP:
- case Type::TAKE:
- case Type::ZIP: {
- for (auto input : inputs_) {
- input->CollectKnobs(knobs);
+ if (auto* tunable_param =
+ gtl::FindOrNull(tunable_params_, "parallelism")) {
+ tunables->push_back(*tunable_param);
}
return;
}
@@ -69,12 +45,19 @@ void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
}
}
+int64 Node::GetParameterValue(const string& name) {
+ if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
+ return (*tunable_param)->value;
+ }
+ return constant_params_[name];
+}
+
int64 Node::ProcessingTimeLocked() {
switch (type_) {
case Type::BATCH:
case Type::MAP_AND_BATCH:
case Type::PADDED_BATCH: {
- int64 batch_size = metadata_["batch_size"];
+ int64 batch_size = GetParameterValue("batch_size");
return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
}
case Type::FILTER: {
@@ -122,7 +105,7 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
switch (type_) {
case Type::BATCH:
case Type::PADDED_BATCH: {
- double batch_size = metadata_["batch_size"];
+ double batch_size = GetParameterValue("batch_size");
int64 old_value = (*input_times)[input_times->size() - 1];
(*input_times)[input_times->size() - 1] = static_cast<int64>(
static_cast<double>(old_value + NanosPerElementLocked()) /
@@ -168,8 +151,8 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
static_cast<double>(inputs_.size() - 1);
}
case Type::MAP_AND_BATCH: {
- double batch_size = metadata_["batch_size"];
- double parallelism = metadata_["parallelism"];
+ double batch_size = GetParameterValue("batch_size");
+ double parallelism = GetParameterValue("parallelism");
int64 delta =
static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
(batch_size * parallelism));
@@ -182,22 +165,41 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
return std::max(0LL,
output_time - input_times->at(input_times->size() - 2));
}
- case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism = GetParameterValue("parallelism");
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
case Type::PARALLEL_INTERLEAVE_V2: {
// TODO(jsimsa): model the first input
if (inputs_.size() <= 1) {
return NanosPerElementLocked();
}
- int64 delta =
- static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
- static_cast<double>(inputs_.size() - 1));
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
input_times->push_back(delta);
auto cleanup =
gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
int64 inputs_output_time = OutputTimeForInputs(input_times) -
inputs_.front()->OutputTime(input_times);
- double parallelism = std::min(port::NumSchedulableCPUs(),
- static_cast<int>(metadata_["parallelism"]));
+ double parallelism =
+ std::min(static_cast<int>(GetParameterValue("cycle_length")),
+ static_cast<int>(GetParameterValue("parallelism")));
int64 output_time =
NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
static_cast<double>(inputs_.size() - 1)) /
@@ -206,8 +208,9 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
output_time - input_times->at(input_times->size() - 2));
}
case Type::PARALLEL_MAP: {
- double parallelism = std::min(port::NumSchedulableCPUs(),
- static_cast<int>(metadata_["parallelism"]));
+ double parallelism =
+ std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(GetParameterValue("parallelism")));
int64 delta = static_cast<int64>(
static_cast<double>(NanosPerElementLocked()) / parallelism);
input_times->push_back(delta);
@@ -248,23 +251,6 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
}
}
-Model::Model(const proto::Model& model_proto) {
- id_counter_ = model_proto.id_counter();
- std::map<int64, std::shared_ptr<Node>> lookup_table;
- for (auto node_proto : model_proto.node()) {
- std::shared_ptr<Node> node(new Node(node_proto));
- lookup_table[node_proto.id()] = node;
- }
- for (auto node_proto : model_proto.node()) {
- std::shared_ptr<Node> node = lookup_table[node_proto.id()];
- for (int64 id : node_proto.input()) {
- node->add_input(lookup_table[id]);
- }
- node->set_output(lookup_table[node_proto.output()]);
- }
- output_ = lookup_table[model_proto.output()];
-}
-
std::shared_ptr<Node> Model::AddNode(const string& name,
const string& output_name) {
mutex_lock l(mu_);
@@ -294,94 +280,77 @@ std::shared_ptr<Node> Model::LookupNode(const string& name) {
return result;
}
-void Model::Optimize() {
- mutex_lock l(mu_);
- int64 processing_time = ProcessingTime();
- int64 num_cpus = port::NumSchedulableCPUs();
- std::vector<Node::Knob> knobs = CollectKnobs();
- // The optimization algorithm starts by setting all parallelism knobs to 1. It
- // then repeatedly identifies the knob that, when turned up by 1, decreases
- // the output time the most. This process is repeated until all knobs reach
- // the number of schedulable CPUs or the projected output time is less than or
- // equal to the processing time needed to produce an element divided by the
- // number of schedulable CPUs.
- for (auto& knob : knobs) {
- LOG(INFO) << knob.node->name() << " " << knob.processing_time;
- knob.value = 1;
- knob.node->set_metadata("parallelism", knob.value);
- }
- while (true) {
- int64 output_time = OutputTime();
- bool all_knobs = true;
- for (auto knob : knobs) {
- if (knob.value < num_cpus) {
- all_knobs = false;
+// The optimization algorithm starts by setting all tunable parallelism
+// parameters to 1. It then repeatedly identifies the parameter that whose
+// increase in parallelism decreases the output time the most. This process is
+// repeated until all parameters reach their maximum values or the
+// projected output time is less than or equal to the processing time needed to
+// produce an element divided by CPU budget.
+void Model::Optimize(int64 cpu_budget) {
+ mutex_lock l(optimization_mu_);
+ std::vector<std::shared_ptr<Node::Tunable>> tunables;
+ {
+ mutex_lock l2(mu_);
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
break;
}
- }
- if (output_time < processing_time / num_cpus || all_knobs) {
- break;
- }
- int64 best_delta = -1;
- int best_knob = -1;
- for (int i = 0; i < knobs.size(); ++i) {
- if (knobs[i].value == num_cpus) {
- continue;
+ int64 best_delta = -1;
+ Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
+ }
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
+ }
+ tunable->value--;
}
- knobs[i].node->set_metadata("parallelism", knobs[i].value + 1);
- int64 delta = output_time - OutputTime();
- if (delta > best_delta) {
- best_delta = delta;
- best_knob = i;
+ if (best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
}
- knobs[i].node->set_metadata("parallelism", knobs[i].value);
+ best_tunable->value++;
}
- knobs[best_knob].value++;
- knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value);
}
- for (auto knob : knobs) {
- LOG(INFO) << knob.node->name() << " " << knob.value;
+ // The `set_fn` functions should be invoked without holding a lock to avoid a
+ // potential deadlock.
+ for (auto& tunable : tunables) {
+ tunable->set_fn(tunable->value);
}
- LOG(INFO) << "output time: " << OutputTime();
- LOG(INFO) << "processing time: " << ProcessingTime();
-}
-
-void Model::OutputToFile() {
- proto::Model model_proto;
- ToProto(&model_proto);
- string filename;
- Env::Default()->LocalTempFilename(&filename);
- TF_CHECK_OK(WriteStringToFile(Env::Default(), filename,
- model_proto.SerializeAsString()));
- LOG(INFO) << filename;
}
void Model::RemoveNode(const string& prefix) {
- mutex_lock l(mu_);
+ // Nodes are not allowed to be removed when optimization is in progress to
+ // prevent the optimization from trying to access an iterator that was
+ // concurrently deleted.
+ mutex_lock l(optimization_mu_);
+ mutex_lock l2(mu_);
lookup_table_.erase(prefix);
}
-void Model::ToProto(proto::Model* model_proto) {
- mutex_lock l(mu_);
- model_proto->set_id_counter(id_counter_);
- model_proto->set_output(output_->id());
- AddNodeToProto(output_, model_proto);
-}
-
-// static
-void Model::AddNodeToProto(const std::shared_ptr<Node>& node,
- proto::Model* model_proto) {
- proto::Node* node_proto = model_proto->add_node();
- node->ToProto(node_proto);
- for (const std::shared_ptr<Node>& input : node->inputs()) {
- AddNodeToProto(input, model_proto);
- }
-}
-
-std::vector<Node::Knob> Model::CollectKnobs() {
- std::vector<Node::Knob> knobs;
- output_->CollectKnobs(&knobs);
- return knobs;
+std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Node::Tunable>> tunables;
+ output_->CollectTunables(&tunables);
+ return tunables;
}
int64 Model::OutputTime() {
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 98172909bf..f88ec06ef3 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -22,7 +22,6 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
@@ -61,13 +60,10 @@ class Node {
public:
Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
- explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
- name_ = node_proto.name();
- type_ = TypeFromName(node_proto.name());
- processing_time_ = node_proto.processing_time();
- num_elements_ = node_proto.num_elements();
- metadata_.insert(node_proto.metadata().begin(),
- node_proto.metadata().end());
+ // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ constant_params_[name] = value;
}
// Records that the node produced an element.
@@ -88,6 +84,15 @@ class Node {
processing_time_ += delta;
}
+ // Adds a tunable parameter.
+ void add_tunable_param(const string& name, int64 value, int64 min, int64 max,
+ std::function<void(int64)>&& set_fn)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ tunable_params_[name] =
+ std::make_shared<Tunable>(value, min, max, std::move(set_fn));
+ }
+
// Returns the unique node ID.
int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
@@ -121,12 +126,6 @@ class Node {
inputs_.remove(input);
}
- // Adds the given key-value pair to the node metadata.
- void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- metadata_[key] = value;
- }
-
// Sets the node name.
void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
@@ -157,11 +156,16 @@ class Node {
}
private:
- // Represents a performance knob.
- struct Knob {
- Node* node;
- int64 processing_time;
+ // Represents a tunable parameter.
+ struct Tunable {
+ Tunable(int64 value, int64 min, int64 max,
+ std::function<void(int64)> set_fn)
+ : value(value), min(min), max(max), set_fn(std::move(set_fn)) {}
+
int64 value;
+ int64 min;
+ int64 max;
+ std::function<void(int64)> set_fn;
};
enum class Type {
@@ -186,8 +190,12 @@ class Node {
UNKNOWN,
};
- // Collects performance knobs in the subtree rooted in this node.
- void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);
+ // Collects tunable parameters in the subtree rooted in this node.
+ void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables)
+ LOCKS_EXCLUDED(mu_);
+
+ // Gets a value of the given parameter (tunable or constant).
+ int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the per-element processing time spent in this node.
int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
@@ -238,22 +246,6 @@ class Node {
return sum;
}
- // Serializes the node state into the given proto.
- void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- node_proto->set_id(id_);
- node_proto->set_name(name_);
- node_proto->set_num_elements(num_elements_);
- node_proto->set_processing_time(processing_time_);
- for (const std::shared_ptr<Node>& input : inputs_) {
- node_proto->add_input(input->id());
- }
- if (output_) {
- node_proto->set_output(output_->id());
- }
- node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
- }
-
Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (name_ == "Batch") {
return Type::BATCH;
@@ -319,7 +311,9 @@ class Node {
int64 processing_time_ GUARDED_BY(mu_) = 0;
int64 num_elements_ GUARDED_BY(mu_) = 0;
std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
- std::map<string, int64> metadata_ GUARDED_BY(mu_);
+ std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+ // Tunables are shared with the model during optimization.
+ std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
std::shared_ptr<Node> output_ GUARDED_BY(mu_);
@@ -330,21 +324,15 @@ class Node {
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
// create a performance model, which is in turn used to identify optimal values
-// of performance knobs.
+// of tunable parameters.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boiler plate code for creating the abstract representation of
// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-//
-// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
-// into the input pipeline.
class Model {
public:
Model() = default;
- explicit Model(const proto::Model& model_proto);
-
- ~Model() {}
// Returns the model output node.
std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
@@ -360,30 +348,25 @@ class Model {
std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
// Runs optimization.
- void Optimize() LOCKS_EXCLUDED(mu_);
-
- // Outputs the state of a model to a file.
- //
- // TODO(jsimsa): Remove this method once the optimization loop is closed.
- void OutputToFile() LOCKS_EXCLUDED(mu_);
+ void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
// Removes the node identified by the given name.
void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
- // Serializes the model state to the given proto.
- void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);
-
private:
- static void AddNodeToProto(const std::shared_ptr<Node>& node,
- proto::Model* model_proto);
-
- std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Used for coordination between different input pipeline threads.
mutex mu_;
+ // Used for preventing iterator deletion when optimization is in progress
+ // because the optimization may try to update the values of tunable
+ // parameters.
+ mutex optimization_mu_ ACQUIRED_BEFORE(mu_);
int64 id_counter_ GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ GUARDED_BY(mu_);
std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
deleted file mode 100644
index 26000007af..0000000000
--- a/tensorflow/core/framework/model.proto
+++ /dev/null
@@ -1,30 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.data.model.proto;
-option cc_enable_arenas = true;
-
-message Model {
- // Counter used for generating new node IDs.
- int64 id_counter = 1;
- // Nodes of this model.
- repeated Node node = 2;
- // The ID of the output node.
- int64 output = 3;
-};
-
-message Node {
- // The node ID.
- int64 id = 1;
- // The node name.
- string name = 2;
- // Input node IDs.
- repeated int64 input = 3;
- // Output node ID.
- int64 output = 4;
- // Number of elements produced by the node.
- int64 num_elements = 5;
- // The CPU time spent by running threads of this node.
- int64 processing_time = 6;
- // Key-value store for node metadata (e.g. batch size or parallelism).
- map<string, int32> metadata = 7;
-};
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 887b8c8365..d1db1d7bec 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -117,7 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 85e49355d3..80efac5d4b 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
@@ -39,7 +40,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -77,7 +77,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx,
+ num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
@@ -190,7 +191,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ : DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -204,8 +206,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "batch_size", dataset()->batch_size_);
- SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
+ mutex_lock l(mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ std::function<void(int64)> set_fn = [this](int64 value) {
+ {
+ mutex_lock l(mu_);
+ num_parallel_calls_ = value;
+ }
+ VLOG(2) << "setting parallelism knob to " << value;
+ cond_var_.notify_all();
+ };
+ AddTunableParameter(
+ ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
+ port::NumSchedulableCPUs() /* max */, std::move(set_fn));
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -428,7 +446,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
+ return (num_parallel_calls_ + dataset()->batch_size_ - 1) /
dataset()->batch_size_;
}
@@ -480,15 +498,18 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
StartWork(ctx.get());
auto stop_cleanup =
gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
+ {
+ tf_shared_lock l(mu_);
+ new_calls.reserve(num_parallel_calls_);
+ }
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
+ (num_calls_ >= num_parallel_calls_ ||
batch_results_.size() > MaxBatchResults() ||
(batch_results_.size() == MaxBatchResults() &&
call_counter_ % dataset()->batch_size_ == 0))) {
@@ -501,7 +522,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
- while (num_calls_ < dataset()->num_parallel_calls_ &&
+ while (num_calls_ < num_parallel_calls_ &&
(batch_results_.size() < MaxBatchResults() ||
(batch_results_.size() == MaxBatchResults() &&
call_counter_ % dataset()->batch_size_ != 0))) {
@@ -648,6 +669,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
@@ -671,7 +694,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
- const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index c7f929dbc1..63025d3371 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -17,11 +17,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
namespace {
+const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros;
+
class ModelDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ModelDatasetOp(OpKernelConstruction* ctx)
@@ -71,9 +74,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params), model_(new model::Model()) {}
-
- ~Iterator() override { model_->OutputToFile(); }
+ : DatasetIterator<Dataset>(params),
+ model_(std::make_shared<model::Model>()) {}
Status Initialize(IteratorContext* ctx) override {
IteratorContext ctx_with_model(CreateParams(ctx));
@@ -85,6 +87,21 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ int64 now = ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ if (last_optimization_ms_ + optimization_period_ms_ < now) {
+ model_->Optimize(port::NumSchedulableCPUs());
+ // Exponentially increase the period of running the optimization until
+ // a threshold is reached.
+ if (optimization_period_ms_ < kOptimizationPeriodThresholdMs) {
+ if (optimization_period_ms_ << 1 < kOptimizationPeriodThresholdMs) {
+ optimization_period_ms_ <<= 1;
+ } else {
+ optimization_period_ms_ = kOptimizationPeriodThresholdMs;
+ }
+ }
+ last_optimization_ms_ =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ }
IteratorContext ctx_with_model(CreateParams(ctx));
return input_impl_->GetNext(&ctx_with_model, out_tensors,
end_of_sequence);
@@ -113,6 +130,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
private:
mutex mu_;
std::shared_ptr<model::Model> model_;
+ int64 last_optimization_ms_ GUARDED_BY(mu_) = 0;
+ int64 optimization_period_ms_ GUARDED_BY(mu_) = 10;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index 73eeafd797..7b01c3b4e0 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -207,7 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index aa5e613e24..2f2db09508 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -252,7 +252,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "parallelism", dataset()->cycle_length_);
+ AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -1120,7 +1120,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
OP_REQUIRES(
@@ -1233,6 +1233,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
+ num_parallel_calls_(params.dataset->num_parallel_calls_),
thread_pool_(new thread::ThreadPool(
Env::Default(), ThreadOptions(), "parallel_interleave",
dataset()->cycle_length_ /* num_threads */,
@@ -1250,7 +1251,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ auto set_fn = [this](int64 value) {
+ {
+ mutex_lock l(mu_);
+ num_parallel_calls_ = value;
+ }
+ VLOG(2) << "setting parallelism knob to " << value;
+ cond_var_.notify_all();
+ };
+ AddTunableParameter(
+ ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
+ dataset()->cycle_length_ /* max */, std::move(set_fn));
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
+ AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -1459,7 +1477,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// not in use and there is space in the `invocation_results_` queue.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
(element_in_use_[cycle_index_] ||
- num_calls_ >= dataset()->num_parallel_calls_ ||
+ num_calls_ >= num_parallel_calls_ ||
invocation_results_.size() >= MaxInvocationResults())) {
StopWork(ctx.get());
cond_var_.wait(l);
@@ -1472,7 +1490,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
while (!element_in_use_[cycle_index_] &&
(!end_of_input_ || num_open_ > 0) &&
- num_calls_ < dataset()->num_parallel_calls_ &&
+ num_calls_ < num_parallel_calls_ &&
invocation_results_.size() < MaxInvocationResults()) {
if (!current_elements_[cycle_index_]) {
// Try to create a new iterator from the next input element.
@@ -1647,6 +1665,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Identifies the number of open iterators.
int64 num_open_ GUARDED_BY(mu_) = 0;
+ // Identifies the maximum number of parallel calls.
+ int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
+
// Identifies the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 0795987431..b584316d69 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -55,7 +55,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
int32 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 0b6e587881..5f6052ce83 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
@@ -55,7 +56,25 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
- SetMetadata(ctx, "parallelism", num_parallel_calls_);
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ auto set_fn = [this](int64 value) {
+ {
+ mutex_lock l(mu_);
+ num_parallel_calls_ = value;
+ }
+ VLOG(2) << "setting parallelism knob to " << value;
+ cond_var_.notify_all();
+ };
+ // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
+ // use it here for the maximum.
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_ /* value */,
+ 1 /* min */, port::NumSchedulableCPUs() /* max */,
+ std::move(set_fn));
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
@@ -211,8 +230,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
std::move(done));
}
- int64 MaxInvocationResults() { return num_parallel_calls_; }
-
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@@ -235,13 +252,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
StartWork(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(num_parallel_calls_);
+ {
+ tf_shared_lock l(mu_);
+ new_calls.reserve(num_parallel_calls_);
+ }
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ &&
(num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
+ invocation_results_.size() >= num_parallel_calls_)) {
StopWork(ctx.get());
cond_var_.wait(l);
StartWork(ctx.get());
@@ -250,7 +270,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
while (num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
+ invocation_results_.size() < num_parallel_calls_) {
invocation_results_.emplace_back(new InvocationResult());
new_calls.push_back(invocation_results_.back());
num_calls_++;
@@ -305,7 +325,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
- const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
mutex mu_;
// Used for coordination between the main thread and the runner thread. In
@@ -314,6 +333,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
// parallelism and there are slots available in the `invocation_results_`
// buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;