author    Jiri Simsa <jsimsa@google.com>  2018-09-20 16:48:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-20 16:53:24 -0700
commit    0e1efc3d9129c740a16081fdc53bdc482f8f0c11 (patch)
tree      fd153014209d962e8fccf90c31a372bc59ec38a9 /tensorflow/core/framework/model.cc
parent    863b3bd6bb2214065f95fa20f551a8fc8568e55d (diff)
[tf.data] Moving auto-tuning optimizations into a background thread, refactoring the API for exposing tunable parameters, and removing `model::Node` from the public API.
PiperOrigin-RevId: 213907565
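For context, this change replaces the handle-based interface (`Model::AddNode` and `Model::LookupNode` returning `std::shared_ptr<Node>`) with string-keyed methods, so iterator code no longer touches `model::Node` directly. The snippet below is a minimal sketch of how an iterator might drive the refactored interface; it is not part of this commit, the node names are made up, and the helper thread merely stands in for the background thread that this change moves the tuning pass onto.

#include <atomic>
#include <thread>

#include "tensorflow/core/framework/model.h"

namespace model = tensorflow::data::model;

void DriveModel(model::Model* m, std::atomic<tensorflow::int64>* parallelism,
                tensorflow::condition_variable* cond_var) {
  // Nodes are registered under their full iterator prefix; the output is
  // looked up by name (any trailing "[...]" index is stripped internally).
  m->AddNode("Prefetch", "");
  m->AddNode("Prefetch::Map", "Prefetch");
  // Parameters are now reported by node/parameter name instead of through
  // a Node handle.
  m->AddConstantParameter("Prefetch::Map", "batch_size", 32);
  m->AddTunableParameter("Prefetch::Map", "parallelism", parallelism,
                         /*min=*/1, /*max=*/8, cond_var);
  // Runtime accounting consumed by the cost model.
  m->RecordStart("Prefetch::Map", /*stop_output=*/true);
  m->AddProcessingTime("Prefetch::Map", /*delta=*/1000);
  m->RecordElement("Prefetch::Map");
  m->RecordStop("Prefetch::Map", /*start_output=*/true);
  // The tuning pass now runs off the critical path; here a helper thread
  // (standing in for the background thread this change introduces) runs it.
  std::thread optimizer([m] { m->Optimize(/*cpu_budget=*/8); });
  optimizer.join();
}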
Diffstat (limited to 'tensorflow/core/framework/model.cc')
-rw-r--r--  tensorflow/core/framework/model.cc  204
1 file changed, 129 insertions, 75 deletions
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 112298c344..b0330ec990 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,16 +17,14 @@ limitations under the License.
#include <memory>
-#include "tensorflow/core/lib/gtl/map_util.h"
-
namespace tensorflow {
namespace data {
namespace model {
// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
-void Node::CollectTunables(
+void Model::Node::CollectTunables(
std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (auto input : inputs_) {
input->CollectTunables(tunables);
}
@@ -45,14 +43,14 @@ void Node::CollectTunables(
}
}
-int64 Node::GetParameterValue(const string& name) {
+int64 Model::Node::GetParameterValue(const string& name) {
if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
return (*tunable_param)->value;
}
return constant_params_[name];
}
-int64 Node::ProcessingTimeLocked() {
+int64 Model::Node::ProcessingTimeLocked() {
switch (type_) {
case Type::BATCH:
case Type::MAP_AND_BATCH:
@@ -101,7 +99,7 @@ int64 Node::ProcessingTimeLocked() {
}
}
-int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
switch (type_) {
case Type::BATCH:
case Type::PADDED_BATCH: {
@@ -251,15 +249,34 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
}
}
-std::shared_ptr<Node> Model::AddNode(const string& name,
- const string& output_name) {
- mutex_lock l(mu_);
+void Model::AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, node_name);
+ if (node) {
+ (*node)->add_constant_param(parameter_name, value);
+ }
+}
+
+void Model::AddNode(const string& name, const string& output_name) {
+ // The name captures the sequence of iterators joined by `::`. We use the full
+ // sequence as the key in the lookup table, but only the last element of the
+ // sequence as the name node.
+ std::vector<string> tokens =
+ str_util::Split(name, ':', str_util::SkipEmpty());
+ // The output name might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_output_name = output_name;
+ if (str_util::EndsWith(output_name, "]")) {
+ sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+ }
std::shared_ptr<Node> output;
- auto it = lookup_table_.find(output_name);
+ mutex_lock l(mu_);
+ auto it = lookup_table_.find(sanitized_output_name);
if (it != lookup_table_.end()) {
output = it->second;
}
- std::shared_ptr<Node> node(new Node(id_counter_++, output));
+ std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
if (!output_) {
output_ = node;
}
@@ -267,88 +284,125 @@ std::shared_ptr<Node> Model::AddNode(const string& name,
output->add_input(node);
}
lookup_table_.insert(std::make_pair(name, node));
- return node;
}
-std::shared_ptr<Node> Model::LookupNode(const string& name) {
+void Model::AddProcessingTime(const string& name, int64 delta) {
tf_shared_lock l(mu_);
- std::shared_ptr<Node> result;
- auto it = lookup_table_.find(name);
- if (it != lookup_table_.end()) {
- result = it->second;
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->add_processing_time(delta);
}
- return result;
+}
+
+void Model::AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
+ tf_shared_lock l(mu_);
+ auto node = *gtl::FindOrNull(lookup_table_, node_name);
+ DCHECK(node);
+ node->add_tunable_param(parameter_name, value, min, max, cond_var);
}
// The optimization algorithm starts by setting all tunable parallelism
-// parameters to 1. It then repeatedly identifies the parameter that whose
-// increase in parallelism decreases the output time the most. This process is
-// repeated until all parameters reach their maximum values or the
-// projected output time is less than or equal to the processing time needed to
-// produce an element divided by CPU budget.
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by CPU budget.
void Model::Optimize(int64 cpu_budget) {
- mutex_lock l(optimization_mu_);
- std::vector<std::shared_ptr<Node::Tunable>> tunables;
- {
- mutex_lock l2(mu_);
- const int64 processing_time = ProcessingTime();
- tunables = CollectTunables();
- for (auto tunable : tunables) {
- tunable->value = 1;
- }
- while (true) {
- const int64 output_time = OutputTime();
- bool all_tunables = true;
- for (auto& tunable : tunables) {
- if (tunable->value < tunable->max) {
- all_tunables = false;
- break;
- }
- }
- if (output_time < processing_time / cpu_budget || all_tunables) {
+ tf_shared_lock lock(mu_);
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
break;
}
- int64 best_delta = -1;
- Node::Tunable* best_tunable = nullptr;
- for (auto& tunable : tunables) {
- if (tunable->value == tunable->max) {
- continue;
- }
- tunable->value++;
- int64 delta = output_time - OutputTime();
- if (delta > best_delta) {
- best_delta = delta;
- best_tunable = tunable.get();
- }
- tunable->value--;
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
+ break;
+ }
+ int64 best_delta = -1;
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
}
- if (!best_tunable) {
- // NOTE: This can happen because we are performing the optimization
- // while the model data is changing. If this becomes an issue, we should
- // look into performing the optimization using a model snapshot.
- break;
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
}
- best_tunable->value++;
+ tunable->value--;
+ }
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
}
+ best_tunable->value++;
}
- // The `set_fn` functions should be invoked without holding a lock to avoid a
- // potential deadlock.
+ VLOG(2) << "Number of knobs: " << tunables.size();
for (auto& tunable : tunables) {
- tunable->set_fn(tunable->value);
+ VLOG(2) << "Setting tunable parameter: " << tunable->value;
+ tunable->value_ptr->store(tunable->value);
+ if (tunable->cond_var) {
+ tunable->cond_var->notify_all();
+ }
+ }
+}
+
+void Model::RecordElement(const string& name) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_element();
}
}
-void Model::RemoveNode(const string& prefix) {
- // Nodes are not allowed to be removed when optimization is in progress to
- // prevent the optimization from trying to access an iterator that was
- // concurrently deleted.
- mutex_lock l(optimization_mu_);
- mutex_lock l2(mu_);
- lookup_table_.erase(prefix);
+void Model::RecordStart(const string& name, bool stop_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ if (stop_output && (*node)->output()) {
+ (*node)->output()->record_stop();
+ }
+ (*node)->record_start();
+ }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_stop();
+ if (start_output && (*node)->output()) {
+ (*node)->output()->record_start();
+ }
+ }
+}
+
+void Model::RemoveNode(const string& name) {
+ mutex_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node && (*node)->output()) {
+ (*node)->output()->remove_input(*node);
+ }
+ lookup_table_.erase(name);
}
-std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
- std::vector<std::shared_ptr<Node::Tunable>> tunables;
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
output_->CollectTunables(&tunables);
return tunables;
}
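The comment above `Model::Optimize` describes a greedy hill climb over the tunable parallelism knobs: start every knob at 1, repeatedly bump the knob whose increase shrinks the projected output time the most, and stop once every knob is at its maximum or the projected output time drops below the per-element processing time divided by the CPU budget. The following standalone sketch mirrors that loop with plain structs and a made-up cost function; it is an illustration of the algorithm, not the TensorFlow implementation.

#include <cstdint>
#include <vector>

struct Tunable {   // simplified stand-in for Model::Node::Tunable
  int64_t value = 1;
  int64_t max = 1;
};

// Hypothetical cost: each knob's contribution shrinks as its parallelism
// grows. The real model derives this from recorded per-node processing
// times; this placeholder just keeps the sketch self-contained.
int64_t OutputTime(const std::vector<Tunable>& tunables) {
  int64_t total = 0;
  for (const auto& t : tunables) total += 1000 / t.value;
  return total;
}

void Optimize(std::vector<Tunable>* tunables, int64_t processing_time,
              int64_t cpu_budget) {
  for (auto& t : *tunables) t.value = 1;  // start all knobs at 1
  while (true) {
    const int64_t output_time = OutputTime(*tunables);
    bool all_maxed = true;
    for (const auto& t : *tunables) {
      if (t.value < t.max) { all_maxed = false; break; }
    }
    // Stop when the pipeline is projected to keep up with the CPU budget,
    // or when no knob can be raised any further.
    if (output_time < processing_time / cpu_budget || all_maxed) break;
    int64_t best_delta = -1;
    Tunable* best = nullptr;
    for (auto& t : *tunables) {
      if (t.value == t.max) continue;
      t.value++;  // trial increase
      const int64_t delta = output_time - OutputTime(*tunables);
      if (delta > best_delta) { best_delta = delta; best = &t; }
      t.value--;  // roll the trial back
    }
    if (best == nullptr) break;  // no beneficial increase found
    best->value++;               // commit the best increase
  }
}

For example, with two knobs capped at 8, processing_time = 2000, and cpu_budget = 4, the loop keeps raising the knobs until the projected 1000/v1 + 1000/v2 drops below 500, ending around values 4 and 5.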