diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-20 16:48:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 16:53:24 -0700 |
commit | 0e1efc3d9129c740a16081fdc53bdc482f8f0c11 (patch) | |
tree | fd153014209d962e8fccf90c31a372bc59ec38a9 /tensorflow/core/framework/model.cc | |
parent | 863b3bd6bb2214065f95fa20f551a8fc8568e55d (diff) |
[tf.data] Moving auto-tuning optimizations into a background thread, refactoring the API for exposing tunable parameters, and removing `model::Node` from the public API.
PiperOrigin-RevId: 213907565
Diffstat (limited to 'tensorflow/core/framework/model.cc')
-rw-r--r-- | tensorflow/core/framework/model.cc | 204 |
1 files changed, 129 insertions, 75 deletions
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 112298c344..b0330ec990 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,16 +17,14 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/core/lib/gtl/map_util.h"
-
 namespace tensorflow {
 namespace data {
 namespace model {
 
 // TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
-void Node::CollectTunables(
+void Model::Node::CollectTunables(
     std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
-  mutex_lock l(mu_);
+  tf_shared_lock l(mu_);
   for (auto input : inputs_) {
     input->CollectTunables(tunables);
   }
@@ -45,14 +43,14 @@ void Node::CollectTunables(
   }
 }
 
-int64 Node::GetParameterValue(const string& name) {
+int64 Model::Node::GetParameterValue(const string& name) {
   if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
     return (*tunable_param)->value;
   }
   return constant_params_[name];
 }
 
-int64 Node::ProcessingTimeLocked() {
+int64 Model::Node::ProcessingTimeLocked() {
   switch (type_) {
     case Type::BATCH:
     case Type::MAP_AND_BATCH:
@@ -101,7 +99,7 @@ int64 Node::ProcessingTimeLocked() {
   }
 }
 
-int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
   switch (type_) {
     case Type::BATCH:
     case Type::PADDED_BATCH: {
@@ -251,15 +249,34 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
   }
 }
 
-std::shared_ptr<Node> Model::AddNode(const string& name,
-                                     const string& output_name) {
-  mutex_lock l(mu_);
+void Model::AddConstantParameter(const string& node_name,
+                                 const string& parameter_name, int64 value) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, node_name);
+  if (node) {
+    (*node)->add_constant_param(parameter_name, value);
+  }
+}
+
+void Model::AddNode(const string& name, const string& output_name) {
+  // The name captures the sequence of iterators joined by `::`. We use the full
+  // sequence as the key in the lookup table, but only the last element of the
+  // sequence as the name node.
+  std::vector<string> tokens =
+      str_util::Split(name, ':', str_util::SkipEmpty());
+  // The output name might contain an index. We need to strip it to make it
+  // possible for the model to successfully identify the output node.
+  string sanitized_output_name = output_name;
+  if (str_util::EndsWith(output_name, "]")) {
+    sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+  }
   std::shared_ptr<Node> output;
-  auto it = lookup_table_.find(output_name);
+  mutex_lock l(mu_);
+  auto it = lookup_table_.find(sanitized_output_name);
   if (it != lookup_table_.end()) {
     output = it->second;
   }
-  std::shared_ptr<Node> node(new Node(id_counter_++, output));
+  std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
   if (!output_) {
     output_ = node;
   }
@@ -267,88 +284,125 @@ std::shared_ptr<Node> Model::AddNode(const string& name,
     output->add_input(node);
   }
   lookup_table_.insert(std::make_pair(name, node));
-  return node;
 }
 
-std::shared_ptr<Node> Model::LookupNode(const string& name) {
+void Model::AddProcessingTime(const string& name, int64 delta) {
   tf_shared_lock l(mu_);
-  std::shared_ptr<Node> result;
-  auto it = lookup_table_.find(name);
-  if (it != lookup_table_.end()) {
-    result = it->second;
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    (*node)->add_processing_time(delta);
   }
-  return result;
+}
+
+void Model::AddTunableParameter(const string& node_name,
+                                const string& parameter_name,
+                                std::atomic<int64>* value, int64 min, int64 max,
+                                condition_variable* cond_var) {
+  tf_shared_lock l(mu_);
+  auto node = *gtl::FindOrNull(lookup_table_, node_name);
+  DCHECK(node);
+  node->add_tunable_param(parameter_name, value, min, max, cond_var);
 }
 
 // The optimization algorithm starts by setting all tunable parallelism
-// parameters to 1. It then repeatedly identifies the parameter that whose
-// increase in parallelism decreases the output time the most. This process is
-// repeated until all parameters reach their maximum values or the
-// projected output time is less than or equal to the processing time needed to
-// produce an element divided by CPU budget.
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by CPU budget.
 void Model::Optimize(int64 cpu_budget) {
-  mutex_lock l(optimization_mu_);
-  std::vector<std::shared_ptr<Node::Tunable>> tunables;
-  {
-    mutex_lock l2(mu_);
-    const int64 processing_time = ProcessingTime();
-    tunables = CollectTunables();
-    for (auto tunable : tunables) {
-      tunable->value = 1;
-    }
-    while (true) {
-      const int64 output_time = OutputTime();
-      bool all_tunables = true;
-      for (auto& tunable : tunables) {
-        if (tunable->value < tunable->max) {
-          all_tunables = false;
-          break;
-        }
-      }
-      if (output_time < processing_time / cpu_budget || all_tunables) {
+  tf_shared_lock lock(mu_);
+  std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+  const int64 processing_time = ProcessingTime();
+  tunables = CollectTunables();
+  for (auto tunable : tunables) {
+    tunable->value = 1;
+  }
+  while (true) {
+    const int64 output_time = OutputTime();
+    bool all_tunables = true;
+    for (auto& tunable : tunables) {
+      if (tunable->value < tunable->max) {
+        all_tunables = false;
         break;
       }
-      int64 best_delta = -1;
-      Node::Tunable* best_tunable = nullptr;
-      for (auto& tunable : tunables) {
-        if (tunable->value == tunable->max) {
-          continue;
-        }
-        tunable->value++;
-        int64 delta = output_time - OutputTime();
-        if (delta > best_delta) {
-          best_delta = delta;
-          best_tunable = tunable.get();
-        }
-        tunable->value--;
+    }
+    if (output_time < processing_time / cpu_budget || all_tunables) {
+      break;
+    }
+    int64 best_delta = -1;
+    Model::Node::Tunable* best_tunable = nullptr;
+    for (auto& tunable : tunables) {
+      if (tunable->value == tunable->max) {
+        continue;
       }
-      if (!best_tunable) {
-        // NOTE: This can happen because we are performing the optimization
-        // while the model data is changing. If this becomes an issue, we should
-        // look into performing the optimization using a model snapshot.
-        break;
+      tunable->value++;
+      int64 delta = output_time - OutputTime();
+      if (delta > best_delta) {
+        best_delta = delta;
+        best_tunable = tunable.get();
       }
-      best_tunable->value++;
+      tunable->value--;
+    }
+    if (!best_tunable) {
+      // NOTE: This can happen because we are performing the optimization
+      // while the model data is changing. If this becomes an issue, we should
+      // look into performing the optimization using a model snapshot.
+      break;
     }
+    best_tunable->value++;
   }
-  // The `set_fn` functions should be invoked without holding a lock to avoid a
-  // potential deadlock.
+  VLOG(2) << "Number of knobs: " << tunables.size();
   for (auto& tunable : tunables) {
-    tunable->set_fn(tunable->value);
+    VLOG(2) << "Setting tunable parameter: " << tunable->value;
+    tunable->value_ptr->store(tunable->value);
+    if (tunable->cond_var) {
+      tunable->cond_var->notify_all();
+    }
+  }
+}
+
+void Model::RecordElement(const string& name) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    (*node)->record_element();
   }
 }
 
-void Model::RemoveNode(const string& prefix) {
-  // Nodes are not allowed to be removed when optimization is in progress to
-  // prevent the optimization from trying to access an iterator that was
-  // concurrently deleted.
-  mutex_lock l(optimization_mu_);
-  mutex_lock l2(mu_);
-  lookup_table_.erase(prefix);
+void Model::RecordStart(const string& name, bool stop_output) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    if (stop_output && (*node)->output()) {
+      (*node)->output()->record_stop();
+    }
+    (*node)->record_start();
+  }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    (*node)->record_stop();
+    if (start_output && (*node)->output()) {
+      (*node)->output()->record_start();
+    }
+  }
+}
+
+void Model::RemoveNode(const string& name) {
+  mutex_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node && (*node)->output()) {
+    (*node)->output()->remove_input(*node);
+  }
+  lookup_table_.erase(name);
 }
 
-std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
-  std::vector<std::shared_ptr<Node::Tunable>> tunables;
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+  std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
   output_->CollectTunables(&tunables);
   return tunables;
 }