diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-11 17:50:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 17:55:22 -0700 |
commit | 683cf4eb603defd7b55a83bbe0e0f335d7ab6354 (patch) | |
tree | c6c7894c51861922d478fc7b57343535d714c778 /tensorflow/core/framework | |
parent | d77ec7f18fe9f4b03f7259a0003b966b6be28d03 (diff) |
[tf.data] Mechanism for collecting processing time information and modeling performance.
PiperOrigin-RevId: 212557406
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/dataset.h | 108 | ||||
-rw-r--r-- | tensorflow/core/framework/model.cc | 396 | ||||
-rw-r--r-- | tensorflow/core/framework/model.h | 396 | ||||
-rw-r--r-- | tensorflow/core/framework/model.proto | 30 |
4 files changed, 925 insertions, 5 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 4e51fba048..4ee6749eea 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -291,6 +292,9 @@ class IteratorContext { // The Allocator to be used to allocate the output of an iterator. std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr; + + // If non-null, identifies the object used for performance modeling. + std::shared_ptr<model::Model> model = nullptr; }; explicit IteratorContext(Params params) : params_(std::move(params)) {} @@ -342,6 +346,10 @@ class IteratorContext { return params_.stats_aggregator_getter; } + std::shared_ptr<model::Model> model() { return params_.model; } + + Params params() { return params_; } + private: Params params_; }; @@ -376,7 +384,11 @@ class SerializationContext { // defined below. class IteratorBase { public: - virtual ~IteratorBase() {} + virtual ~IteratorBase() { + for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) { + (*rit)(); + } + } // Gets the next output from the range that this iterator is traversing. // @@ -410,6 +422,10 @@ class IteratorBase { // in the outputs of this iterator. virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; + // Returns a string that identifies the sequence of iterators leading up to + // this iterator. + virtual const string& prefix() const = 0; + // Performs initialization that needs to happen outside of a constructor to // properly propagate errors. virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); } @@ -449,6 +465,18 @@ class IteratorBase { IteratorStateReader* reader) { return errors::Unimplemented("RestoreInternal"); } + + private: + friend class DatasetBase; // for access to `AddCleanupFunction` + + // Registers a cleanup function to be called upon object destruction. + // + // Registered functions are invoked in the reserve order of registration. + void AddCleanupFunction(std::function<void()>&& cleanup_fn) { + cleanup_fns_.push_back(std::move(cleanup_fn)); + } + + std::vector<std::function<void()>> cleanup_fns_; }; // Represents runtime information needed to construct a dataset. @@ -498,6 +526,27 @@ class DatasetBase : public core::RefCounted { Status MakeIterator(IteratorContext* ctx, const string& prefix, std::unique_ptr<IteratorBase>* iterator) const { *iterator = MakeIteratorInternal(prefix); + if (ctx->model()) { + // The prefix might contain an index. We need to strip it to make it + // possible for the model to successfully identify the output node. + string sanitized_prefix = prefix; + if (str_util::EndsWith(prefix, "]")) { + sanitized_prefix = prefix.substr(0, prefix.rfind('[')); + } + std::shared_ptr<model::Node> node = + ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix); + std::vector<string> tokens = + str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty()); + node->set_name(tokens[tokens.size() - 1]); + std::shared_ptr<model::Model> model = ctx->model(); + const string& prefix = (*iterator)->prefix(); + (*iterator)->AddCleanupFunction([model, node, prefix]() { + if (node->output()) { + node->output()->remove_input(node); + } + model->RemoveNode(prefix); + }); + } return (*iterator)->Initialize(ctx); } @@ -524,6 +573,8 @@ class DatasetBase : public core::RefCounted { IteratorStateWriter* writer) const; protected: + friend class DatasetToGraphOp; // For access to graph related members. + class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} @@ -541,8 +592,6 @@ class DatasetBase : public core::RefCounted { virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const = 0; - friend class DatasetToGraphOp; // For access to graph related members. - private: const string name_; }; @@ -565,7 +614,7 @@ class DatasetBaseIterator : public IteratorBase { ~DatasetBaseIterator() override { params_.dataset->Unref(); } // The sequence of iterators leading up to this iterator. - const string& prefix() const { return params_.prefix; } + const string& prefix() const override { return params_.prefix; } const DataTypeVector& output_dtypes() const override { return params_.dataset->output_dtypes(); @@ -578,7 +627,23 @@ class DatasetBaseIterator : public IteratorBase { Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) final { tracing::ScopedActivity activity(params_.prefix); - Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); + Status s; + if (ctx->model()) { + std::shared_ptr<model::Node> node = + ctx->model()->LookupNode(params_.prefix); + if (node->output()) { + node->output()->stop_work(); + } + node->start_work(); + s = GetNextInternal(ctx, out_tensors, end_of_sequence); + node->stop_work(); + node->add_element(); + if (node->output()) { + node->output()->start_work(); + } + } else { + s = GetNextInternal(ctx, out_tensors, end_of_sequence); + } if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) { s = errors::Internal( "Iterator \"", params_.prefix, @@ -605,6 +670,39 @@ class DatasetBaseIterator : public IteratorBase { return strings::StrCat(params_.prefix, ":", name); } + // When performance modeling is enabled, this method sets metadata entry for + // the model node corresponding to this iterator. + void SetMetadata(IteratorContext* ctx, const string& key, int64 value) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->set_metadata(key, value); + } + } + } + + // When performance modeling is enabled, this method records the fact that + // a thread of this iterator has started work. + void StartWork(IteratorContext* ctx) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->start_work(); + } + } + } + + // When performance modeling is enabled, this method records the fact that + // a thread of this iterator has stopped work. + void StopWork(IteratorContext* ctx) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->stop_work(); + } + } + } + private: BaseParams params_; }; diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc new file mode 100644 index 0000000000..250b006641 --- /dev/null +++ b/tensorflow/core/framework/model.cc @@ -0,0 +1,396 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/model.h" + +namespace tensorflow { +namespace data { +namespace model { + +// TODO(jsimsa): Use `Node` subclassing instead of types and node statements. +void Node::CollectKnobs(std::vector<Node::Knob>* knobs) { + mutex_lock l(mu_); + switch (type_) { + case Type::PARALLEL_INTERLEAVE_V2: { + for (auto input : inputs_) { + input->CollectKnobs(knobs); + } + int64 processing_time = static_cast<int64>( + static_cast<double>(ProcessingTimeLocked() - + inputs_.front()->ProcessingTime()) / + static_cast<double>(inputs_.size() - 1)); + knobs->emplace_back( + Node::Knob{this, processing_time, metadata_["parallelism"]}); + return; + } + case Type::MAP_AND_BATCH: + case Type::PARALLEL_MAP: { + for (auto input : inputs_) { + input->CollectKnobs(knobs); + } + knobs->emplace_back( + Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]}); + return; + } + case Type::BATCH: + case Type::CACHE: + case Type::CONCATENATE: + case Type::FILTER: + case Type::FLAT_MAP: + case Type::INTERLEAVE: + case Type::MAP: + case Type::PADDED_BATCH: + case Type::PARALLEL_INTERLEAVE: + case Type::PREFETCH: + case Type::REPEAT: + case Type::SHUFFLE: + case Type::SKIP: + case Type::TAKE: + case Type::ZIP: { + for (auto input : inputs_) { + input->CollectKnobs(knobs); + } + return; + } + default: + return; + } +} + +int64 Node::ProcessingTimeLocked() { + switch (type_) { + case Type::BATCH: + case Type::MAP_AND_BATCH: + case Type::PADDED_BATCH: { + int64 batch_size = metadata_["batch_size"]; + return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs(); + } + case Type::FILTER: { + std::shared_ptr<Node> input = inputs_.front(); + double ratio = static_cast<double>(input->num_elements()) / + static_cast<double>(num_elements_); + return NanosPerElementLocked() + + static_cast<int64>(ratio * + static_cast<double>(ProcessingTimeForInputs())); + } + case Type::FLAT_MAP: + case Type::INTERLEAVE: + case Type::PARALLEL_INTERLEAVE: + case Type::PARALLEL_INTERLEAVE_V2: { + // TODO(jsimsa): model the first input + // TODO(jsimsa): use processing time history as a prior for future inputs + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 processing_time = + ProcessingTimeForInputs() - inputs_.front()->ProcessingTime(); + return NanosPerElementLocked() + + static_cast<double>(processing_time) / + static_cast<double>(inputs_.size() - 1); + } + case Type::CACHE: + case Type::CONCATENATE: + case Type::MAP: + case Type::PARALLEL_MAP: + case Type::PREFETCH: + // TODO(jsimsa): use processing time history as a prior for future inputs + case Type::REPEAT: + case Type::SHUFFLE: + case Type::SKIP: + case Type::TAKE: + case Type::ZIP: { + return NanosPerElementLocked() + ProcessingTimeForInputs(); + } + default: + return NanosPerElementLocked(); + } +} + +int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { + switch (type_) { + case Type::BATCH: + case Type::PADDED_BATCH: { + double batch_size = metadata_["batch_size"]; + int64 old_value = (*input_times)[input_times->size() - 1]; + (*input_times)[input_times->size() - 1] = static_cast<int64>( + static_cast<double>(old_value + NanosPerElementLocked()) / + batch_size); + auto cleanup = gtl::MakeCleanup([input_times, old_value]() { + (*input_times)[input_times->size() - 1] = old_value; + }); + return NanosPerElementLocked() + + batch_size * OutputTimeForInputs(input_times); + } + case Type::FILTER: { + std::shared_ptr<Node> input = inputs_.front(); + int64 old_value = (*input_times)[input_times->size() - 1]; + double ratio = static_cast<double>(input->num_elements()) / + static_cast<double>(num_elements_); + (*input_times)[input_times->size() - 1] = static_cast<int64>( + static_cast<double>(old_value + NanosPerElementLocked()) / ratio); + auto cleanup = gtl::MakeCleanup([input_times, old_value]() { + (*input_times)[input_times->size() - 1] = old_value; + }); + return NanosPerElementLocked() + + static_cast<int64>( + static_cast<double>(OutputTimeForInputs(input_times)) * ratio); + } + case Type::FLAT_MAP: + case Type::INTERLEAVE: { + // TODO(jsimsa): model the first input + // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1` + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = + static_cast<int64>(static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1)); + (*input_times)[input_times->size() - 1] += delta; + auto cleanup = gtl::MakeCleanup([input_times, delta]() { + (*input_times)[input_times->size() - 1] -= delta; + }); + int64 output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + return NanosPerElementLocked() + + static_cast<double>(output_time) / + static_cast<double>(inputs_.size() - 1); + } + case Type::MAP_AND_BATCH: { + double batch_size = metadata_["batch_size"]; + double parallelism = metadata_["parallelism"]; + int64 delta = + static_cast<int64>(static_cast<double>(NanosPerElementLocked()) / + (batch_size * parallelism)); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 output_time = static_cast<int64>( + static_cast<double>(NanosPerElementLocked()) / parallelism + + batch_size * OutputTimeForInputs(input_times)); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_INTERLEAVE: + case Type::PARALLEL_INTERLEAVE_V2: { + // TODO(jsimsa): model the first input + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = + static_cast<int64>(static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1)); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 inputs_output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + double parallelism = std::min(port::NumSchedulableCPUs(), + static_cast<int>(metadata_["parallelism"])); + int64 output_time = + NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / + static_cast<double>(inputs_.size() - 1)) / + parallelism); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_MAP: { + double parallelism = std::min(port::NumSchedulableCPUs(), + static_cast<int>(metadata_["parallelism"])); + int64 delta = static_cast<int64>( + static_cast<double>(NanosPerElementLocked()) / parallelism); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 output_time = + static_cast<double>(NanosPerElementLocked()) / parallelism + + OutputTimeForInputs(input_times); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PREFETCH: { + int64 delta = NanosPerElementLocked(); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + return std::max(0LL, NanosPerElementLocked() + + OutputTimeForInputs(input_times) - + input_times->at(input_times->size() - 2)); + } + case Type::CACHE: + case Type::CONCATENATE: + case Type::MAP: + case Type::REPEAT: + case Type::SHUFFLE: + case Type::SKIP: + case Type::TAKE: + case Type::ZIP: { + int64 delta = NanosPerElementLocked(); + (*input_times)[input_times->size() - 1] += delta; + auto cleanup = gtl::MakeCleanup([input_times, delta]() { + (*input_times)[input_times->size() - 1] -= delta; + }); + return NanosPerElementLocked() + OutputTimeForInputs(input_times); + } + default: + return NanosPerElementLocked(); + } +} + +Model::Model(const proto::Model& model_proto) { + id_counter_ = model_proto.id_counter(); + std::map<int64, std::shared_ptr<Node>> lookup_table; + for (auto node_proto : model_proto.node()) { + std::shared_ptr<Node> node(new Node(node_proto)); + lookup_table[node_proto.id()] = node; + } + for (auto node_proto : model_proto.node()) { + std::shared_ptr<Node> node = lookup_table[node_proto.id()]; + for (int64 id : node_proto.input()) { + node->add_input(lookup_table[id]); + } + node->set_output(lookup_table[node_proto.output()]); + } + output_ = lookup_table[model_proto.output()]; +} + +std::shared_ptr<Node> Model::AddNode(const string& name, + const string& output_name) { + mutex_lock l(mu_); + std::shared_ptr<Node> output; + auto it = lookup_table_.find(output_name); + if (it != lookup_table_.end()) { + output = it->second; + } + std::shared_ptr<Node> node(new Node(id_counter_++, output)); + if (!output_) { + output_ = node; + } + if (output) { + output->add_input(node); + } + lookup_table_.insert(std::make_pair(name, node)); + return node; +} + +std::shared_ptr<Node> Model::LookupNode(const string& name) { + tf_shared_lock l(mu_); + std::shared_ptr<Node> result; + auto it = lookup_table_.find(name); + if (it != lookup_table_.end()) { + result = it->second; + } + return result; +} + +void Model::Optimize() { + mutex_lock l(mu_); + int64 processing_time = ProcessingTime(); + int64 num_cpus = port::NumSchedulableCPUs(); + std::vector<Node::Knob> knobs = CollectKnobs(); + // The optimization algorithm starts by setting all parallelism knobs to 1. It + // then repeatedly identifies the knob that, when turned up by 1, decreases + // the output time the most. This process is repeated until all knobs reach + // the number of schedulable CPUs or the projected output time is less than or + // equal to the processing time needed to produce an element divided by the + // number of schedulable CPUs. + for (auto& knob : knobs) { + LOG(INFO) << knob.node->name() << " " << knob.processing_time; + knob.value = 1; + knob.node->set_metadata("parallelism", knob.value); + } + while (true) { + int64 output_time = OutputTime(); + bool all_knobs = true; + for (auto knob : knobs) { + if (knob.value < num_cpus) { + all_knobs = false; + break; + } + } + if (output_time < processing_time / num_cpus || all_knobs) { + break; + } + int64 best_delta = -1; + int best_knob = -1; + for (int i = 0; i < knobs.size(); ++i) { + if (knobs[i].value == num_cpus) { + continue; + } + knobs[i].node->set_metadata("parallelism", knobs[i].value + 1); + int64 delta = output_time - OutputTime(); + if (delta > best_delta) { + best_delta = delta; + best_knob = i; + } + knobs[i].node->set_metadata("parallelism", knobs[i].value); + } + knobs[best_knob].value++; + knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value); + } + for (auto knob : knobs) { + LOG(INFO) << knob.node->name() << " " << knob.value; + } + LOG(INFO) << "output time: " << OutputTime(); + LOG(INFO) << "processing time: " << ProcessingTime(); +} + +void Model::OutputToFile() { + proto::Model model_proto; + ToProto(&model_proto); + string filename; + Env::Default()->LocalTempFilename(&filename); + TF_CHECK_OK(WriteStringToFile(Env::Default(), filename, + model_proto.SerializeAsString())); + LOG(INFO) << filename; +} + +void Model::RemoveNode(const string& prefix) { + mutex_lock l(mu_); + lookup_table_.erase(prefix); +} + +void Model::ToProto(proto::Model* model_proto) { + mutex_lock l(mu_); + model_proto->set_id_counter(id_counter_); + model_proto->set_output(output_->id()); + AddNodeToProto(output_, model_proto); +} + +// static +void Model::AddNodeToProto(const std::shared_ptr<Node>& node, + proto::Model* model_proto) { + proto::Node* node_proto = model_proto->add_node(); + node->ToProto(node_proto); + for (const std::shared_ptr<Node>& input : node->inputs()) { + AddNodeToProto(input, model_proto); + } +} + +std::vector<Node::Knob> Model::CollectKnobs() { + std::vector<Node::Knob> knobs; + output_->CollectKnobs(&knobs); + return knobs; +} + +int64 Model::OutputTime() { + std::vector<int64> input_times(1, 0); + return output_->OutputTime(&input_times); +} + +int64 Model::ProcessingTime() { return output_->ProcessingTime(); } + +} // namespace model +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h new file mode 100644 index 0000000000..98172909bf --- /dev/null +++ b/tensorflow/core/framework/model.h @@ -0,0 +1,396 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ + +#include <list> +#include <memory> +#include <string> +#include <thread> // (b/114492873): move this include into core/platform +#include <utility> +#include <vector> + +#include "tensorflow/core/framework/model.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +namespace data { +namespace model { + +class Model; +class Node; + +// Abstract representation of a TensorFlow input pipeline node. It collects +// information about inputs to this node, processing time spent executing the +// node logic, number of elements produced by the node, various other +// information (e.g. batch size or execution parallelism). +// +// Developers of tf.data transformations are not expected to interact with this +// class directly. Boiler plate code for creating the abstract representation of +// the input pipeline and collecting common information has been added to the +// implementation of `DatasetBase` and `DatasetBaseIterator` respectively. +// +// In addition, `DatasetBaseIterator` provides wrappers that can be used for +// transformation-specific information collection. The `SetMetadata` wrapper can +// be used to pass arbitrary metadata to the modeling framework, while the +// `StartWork` and `StopWork` wrappers should be used to correctly account for +// processing time of multi-threaded transformation that yield the CPU; such +// transformations should invoke `StartWork()` when a transformation thread +// starts executing (e.g. when created or woken up) and `StopWork()` when a +// transformation thread stops executing (e.g. when returning or waiting). +// +// TODO(jsimsa): Create an API to capture the abstract semantics of each +// tf.data transformation and replace switch-case blocks with inheritance. +class Node { + public: + Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {} + + explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) { + name_ = node_proto.name(); + type_ = TypeFromName(node_proto.name()); + processing_time_ = node_proto.processing_time(); + num_elements_ = node_proto.num_elements(); + metadata_.insert(node_proto.metadata().begin(), + node_proto.metadata().end()); + } + + // Records that the node produced an element. + void add_element() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + num_elements_++; + } + + // Adds an input. + void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.push_back(node); + } + + // Increments the aggregate processing time by the given delta. + void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + processing_time_ += delta; + } + + // Returns the unique node ID. + int64 id() LOCKS_EXCLUDED(mu_) { return id_; } + + // Returns the node inputs. + std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return inputs_; + } + + // Returns the node name. + const string& name() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return name_; + } + + // Returns the number of elements produced by the node. + int64 num_elements() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return num_elements_; + } + + // Returns the node output. + std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return output_; + } + + // Removes an input. + void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.remove(input); + } + + // Adds the given key-value pair to the node metadata. + void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + metadata_[key] = value; + } + + // Sets the node name. + void set_name(const string& name) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + name_ = name; + type_ = TypeFromName(name); + } + + // Set the node output. + void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + output_ = output; + } + + // Records that a node thread has started work. + void start_work() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos(); + } + + // Records that a node thread has stopped work. + void stop_work() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + auto iter = work_start_.find(std::this_thread::get_id()); + CHECK(work_start_.end() != iter) + << "Encountered a stop event that was not preceded by a start event."; + processing_time_ += Env::Default()->NowNanos() - iter->second; + work_start_.erase(iter); + } + + private: + // Represents a performance knob. + struct Knob { + Node* node; + int64 processing_time; + int64 value; + }; + + enum class Type { + BATCH = 0, + CACHE, + CONCATENATE, + FILTER, + FLAT_MAP, + INTERLEAVE, + MAP, + MAP_AND_BATCH, + PADDED_BATCH, + PARALLEL_INTERLEAVE, + PARALLEL_INTERLEAVE_V2, + PARALLEL_MAP, + PREFETCH, + REPEAT, + SHUFFLE, + SKIP, + TAKE, + ZIP, + UNKNOWN, + }; + + // Collects performance knobs in the subtree rooted in this node. + void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_); + + // Returns the per-element processing time spent in this node. + int64 NanosPerElement() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return NanosPerElementLocked(); + } + + int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (num_elements_ == 0) { + return 0; + } + return (int64)((double)processing_time_ / (double)num_elements_); + } + + // Returns the per-element output time for this node. + int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return OutputTimeLocked(input_times); + } + + int64 OutputTimeLocked(std::vector<int64>* input_times) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + int64 OutputTimeForInputs(std::vector<int64>* input_times) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto input : inputs_) { + sum += input->OutputTime(input_times); + } + return sum; + } + + // Returns the per-element processing time spent in the subtree rooted in this + // node. + int64 ProcessingTime() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return ProcessingTimeLocked(); + } + + int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the per-element processing time spent in the inputs of this node. + int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto input : inputs_) { + sum += input->ProcessingTimeLocked(); + } + return sum; + } + + // Serializes the node state into the given proto. + void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + node_proto->set_id(id_); + node_proto->set_name(name_); + node_proto->set_num_elements(num_elements_); + node_proto->set_processing_time(processing_time_); + for (const std::shared_ptr<Node>& input : inputs_) { + node_proto->add_input(input->id()); + } + if (output_) { + node_proto->set_output(output_->id()); + } + node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end()); + } + + Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (name_ == "Batch") { + return Type::BATCH; + } + if (str_util::EndsWith(name_, "Cache")) { + return Type::CACHE; + } + if (name_ == "Concatenate") { + return Type::CONCATENATE; + } + if (name_ == "Filter") { + return Type::FILTER; + } + if (name_ == "FlatMap") { + return Type::FLAT_MAP; + } + if (name_ == "Interleave") { + return Type::INTERLEAVE; + } + if (name_ == "Map") { + return Type::MAP; + } + if (name_ == "MapAndBatch") { + return Type::MAP_AND_BATCH; + } + if (name_ == "PaddedBatch") { + return Type::PADDED_BATCH; + } + if (name_ == "ParallelInterleave") { + return Type::PARALLEL_INTERLEAVE; + } + if (name_ == "ParallelInterleaveV2") { + return Type::PARALLEL_INTERLEAVE_V2; + } + if (name_ == "ParallelMap") { + return Type::PARALLEL_MAP; + } + if (name_ == "Prefetch") { + return Type::PREFETCH; + } + if (str_util::EndsWith(name_, "Repeat")) { + return Type::REPEAT; + } + if (name_ == "Shuffle") { + return Type::SHUFFLE; + } + if (str_util::EndsWith(name_, "Skip")) { + return Type::SKIP; + } + if (str_util::EndsWith(name_, "Take")) { + return Type::TAKE; + } + if (name_ == "Zip") { + return Type::ZIP; + } + return Type::UNKNOWN; + } + + mutex mu_; + const int64 id_; + Type type_ GUARDED_BY(mu_); + string name_ GUARDED_BY(mu_); + int64 processing_time_ GUARDED_BY(mu_) = 0; + int64 num_elements_ GUARDED_BY(mu_) = 0; + std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_); + std::map<string, int64> metadata_ GUARDED_BY(mu_); + std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_); + std::shared_ptr<Node> output_ GUARDED_BY(mu_); + + friend class Model; +}; + +// Abstract representation of a TensorFlow input pipeline that can be used +// for collecting runtime information and optimizing performance. It collects +// runtime information about execution of the input pipeline that is used to +// create a performance model, which is in turn used to identify optimal values +// of performance knobs. +// +// Developers of tf.data transformations are not expected to interact with this +// class directly. Boiler plate code for creating the abstract representation of +// the input pipeline and collecting runtime information has been added to the +// implementation of `DatasetBase` and `DatasetBaseIterator` respectively. +// +// TODO(jsimsa): Add a mechanism for feeding the result of the optimization +// into the input pipeline. +class Model { + public: + Model() = default; + explicit Model(const proto::Model& model_proto); + + ~Model() {} + + // Returns the model output node. + std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return output_; + } + + // Adds a node with the given name and given output (identified by name). + std::shared_ptr<Node> AddNode(const string& name, const string& output_name) + LOCKS_EXCLUDED(mu_); + + // Looks up the node using the given name. + std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_); + + // Runs optimization. + void Optimize() LOCKS_EXCLUDED(mu_); + + // Outputs the state of a model to a file. + // + // TODO(jsimsa): Remove this method once the optimization loop is closed. + void OutputToFile() LOCKS_EXCLUDED(mu_); + + // Removes the node identified by the given name. + void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_); + + // Serializes the model state to the given proto. + void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_); + + private: + static void AddNodeToProto(const std::shared_ptr<Node>& node, + proto::Model* model_proto); + + std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; + int64 id_counter_ GUARDED_BY(mu_) = 1; + std::shared_ptr<Node> output_ GUARDED_BY(mu_); + std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_); +}; + +} // namespace model +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto new file mode 100644 index 0000000000..26000007af --- /dev/null +++ b/tensorflow/core/framework/model.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package tensorflow.data.model.proto; +option cc_enable_arenas = true; + +message Model { + // Counter used for generating new node IDs. + int64 id_counter = 1; + // Nodes of this model. + repeated Node node = 2; + // The ID of the output node. + int64 output = 3; +}; + +message Node { + // The node ID. + int64 id = 1; + // The node name. + string name = 2; + // Input node IDs. + repeated int64 input = 3; + // Output node ID. + int64 output = 4; + // Number of elements produced by the node. + int64 num_elements = 5; + // The CPU time spent by running threads of this node. + int64 processing_time = 6; + // Key-value store for node metadata (e.g. batch size or parallelism). + map<string, int32> metadata = 7; +}; |