author    Jiri Simsa <jsimsa@google.com>  2018-09-11 17:50:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-11 17:55:22 -0700
commit    683cf4eb603defd7b55a83bbe0e0f335d7ab6354 (patch)
tree      c6c7894c51861922d478fc7b57343535d714c778 /tensorflow/core/framework
parent    d77ec7f18fe9f4b03f7259a0003b966b6be28d03 (diff)
[tf.data] Mechanism for collecting processing time information and modeling performance.
PiperOrigin-RevId: 212557406
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--  tensorflow/core/framework/dataset.h    108
-rw-r--r--  tensorflow/core/framework/model.cc     396
-rw-r--r--  tensorflow/core/framework/model.h      396
-rw-r--r--  tensorflow/core/framework/model.proto   30
4 files changed, 925 insertions, 5 deletions
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 4e51fba048..4ee6749eea 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -291,6 +292,9 @@ class IteratorContext {
// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
+
+ // If non-null, identifies the object used for performance modeling.
+ std::shared_ptr<model::Model> model = nullptr;
};
explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -342,6 +346,10 @@ class IteratorContext {
return params_.stats_aggregator_getter;
}
+ std::shared_ptr<model::Model> model() { return params_.model; }
+
+ Params params() { return params_; }
+
private:
Params params_;
};
@@ -376,7 +384,11 @@ class SerializationContext {
// defined below.
class IteratorBase {
public:
- virtual ~IteratorBase() {}
+ virtual ~IteratorBase() {
+ for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
+ (*rit)();
+ }
+ }
// Gets the next output from the range that this iterator is traversing.
//
@@ -410,6 +422,10 @@ class IteratorBase {
// in the outputs of this iterator.
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+ // Returns a string that identifies the sequence of iterators leading up to
+ // this iterator.
+ virtual const string& prefix() const = 0;
+
// Performs initialization that needs to happen outside of a constructor to
// properly propagate errors.
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
@@ -449,6 +465,18 @@ class IteratorBase {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
+
+ private:
+ friend class DatasetBase; // for access to `AddCleanupFunction`
+
+ // Registers a cleanup function to be called upon object destruction.
+ //
+ // Registered functions are invoked in the reverse order of registration.
+ void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
+ cleanup_fns_.push_back(std::move(cleanup_fn));
+ }
+
+ std::vector<std::function<void()>> cleanup_fns_;
};
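
The reverse iteration in the destructor gives cleanup functions LIFO semantics, mirroring C++ destruction order: with two hypothetical registrations, "detach from model" first and "release buffers" second, destruction runs "release buffers" and then "detach from model", so a later registration can safely depend on state that an earlier one tears down.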
// Represents runtime information needed to construct a dataset.
@@ -498,6 +526,27 @@ class DatasetBase : public core::RefCounted {
Status MakeIterator(IteratorContext* ctx, const string& prefix,
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
+ if (ctx->model()) {
+ // The prefix might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_prefix = prefix;
+ if (str_util::EndsWith(prefix, "]")) {
+ sanitized_prefix = prefix.substr(0, prefix.rfind('['));
+ }
+ std::shared_ptr<model::Node> node =
+ ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
+ std::vector<string> tokens =
+ str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
+ node->set_name(tokens[tokens.size() - 1]);
+ std::shared_ptr<model::Model> model = ctx->model();
+ const string& prefix = (*iterator)->prefix();
+ (*iterator)->AddCleanupFunction([model, node, prefix]() {
+ if (node->output()) {
+ node->output()->remove_input(node);
+ }
+ model->RemoveNode(prefix);
+ });
+ }
return (*iterator)->Initialize(ctx);
}
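
To make the sanitization above concrete, here is a minimal standalone sketch of the same logic; the prefixes are hypothetical and `SanitizePrefix` is an illustrative helper, not part of this change:

// Strips a trailing "[index]" so the model can resolve the output node.
string SanitizePrefix(const string& prefix) {
  if (str_util::EndsWith(prefix, "]")) {
    return prefix.substr(0, prefix.rfind('['));
  }
  return prefix;
}

// SanitizePrefix("Iterator::ParallelInterleaveV2::Map[1]")
//     == "Iterator::ParallelInterleaveV2::Map"
// SanitizePrefix("Iterator::Batch") == "Iterator::Batch"

The node name recorded in the model is the last ':'-separated token of the iterator's own prefix, e.g. "Map" or "Batch" in the examples above.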
@@ -524,6 +573,8 @@ class DatasetBase : public core::RefCounted {
IteratorStateWriter* writer) const;
protected:
+ friend class DatasetToGraphOp; // For access to graph related members.
+
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -541,8 +592,6 @@ class DatasetBase : public core::RefCounted {
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
- friend class DatasetToGraphOp; // For access to graph related members.
-
private:
const string name_;
};
@@ -565,7 +614,7 @@ class DatasetBaseIterator : public IteratorBase {
~DatasetBaseIterator() override { params_.dataset->Unref(); }
// The sequence of iterators leading up to this iterator.
- const string& prefix() const { return params_.prefix; }
+ const string& prefix() const override { return params_.prefix; }
const DataTypeVector& output_dtypes() const override {
return params_.dataset->output_dtypes();
@@ -578,7 +627,23 @@ class DatasetBaseIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
- Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ Status s;
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node =
+ ctx->model()->LookupNode(params_.prefix);
+ if (node->output()) {
+ node->output()->stop_work();
+ }
+ node->start_work();
+ s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ node->stop_work();
+ node->add_element();
+ if (node->output()) {
+ node->output()->start_work();
+ }
+ } else {
+ s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ }
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -605,6 +670,39 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
+ // When performance modeling is enabled, this method sets a metadata entry
+ // for the model node corresponding to this iterator.
+ void SetMetadata(IteratorContext* ctx, const string& key, int64 value) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->set_metadata(key, value);
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has started work.
+ void StartWork(IteratorContext* ctx) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->start_work();
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has stopped work.
+ void StopWork(IteratorContext* ctx) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->stop_work();
+ }
+ }
+ }
+
private:
BaseParams params_;
};
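
The `StartWork`/`StopWork` wrappers exist so that background threads of asynchronous transformations are charged only for time actually spent executing; `GetNext` above applies the same idea between nodes by pausing the consumer's clock while the producer runs. A minimal sketch of the intended call pattern, assuming a hypothetical prefetch-style worker thread of a `DatasetBaseIterator` subclass (the members `mu_`, `cond_var_`, `buffer_full_`, `cancelled_`, and `input_impl_` are made up for illustration):

void PrefetchThread(IteratorContext* ctx) {
  StartWork(ctx);  // thread created; start charging this node
  while (!cancelled_) {
    // ... call input_impl_->GetNext(...) and enqueue the result ...
    mutex_lock l(mu_);
    while (buffer_full_ && !cancelled_) {
      StopWork(ctx);   // about to block; stop charging
      cond_var_.wait(l);
      StartWork(ctx);  // woken up; resume charging
    }
  }
  StopWork(ctx);  // thread exiting; stop charging
}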
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
new file mode 100644
index 0000000000..250b006641
--- /dev/null
+++ b/tensorflow/core/framework/model.cc
@@ -0,0 +1,396 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/model.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// TODO(jsimsa): Use `Node` subclassing instead of types and switch statements.
+void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
+ mutex_lock l(mu_);
+ switch (type_) {
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ for (auto input : inputs_) {
+ input->CollectKnobs(knobs);
+ }
+ int64 processing_time = static_cast<int64>(
+ static_cast<double>(ProcessingTimeLocked() -
+ inputs_.front()->ProcessingTime()) /
+ static_cast<double>(inputs_.size() - 1));
+ knobs->emplace_back(
+ Node::Knob{this, processing_time, metadata_["parallelism"]});
+ return;
+ }
+ case Type::MAP_AND_BATCH:
+ case Type::PARALLEL_MAP: {
+ for (auto input : inputs_) {
+ input->CollectKnobs(knobs);
+ }
+ knobs->emplace_back(
+ Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]});
+ return;
+ }
+ case Type::BATCH:
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::FILTER:
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE:
+ case Type::MAP:
+ case Type::PADDED_BATCH:
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PREFETCH:
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ for (auto input : inputs_) {
+ input->CollectKnobs(knobs);
+ }
+ return;
+ }
+ default:
+ return;
+ }
+}
+
+int64 Node::ProcessingTimeLocked() {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::MAP_AND_BATCH:
+ case Type::PADDED_BATCH: {
+ int64 batch_size = metadata_["batch_size"];
+ return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ return NanosPerElementLocked() +
+ static_cast<int64>(ratio *
+ static_cast<double>(ProcessingTimeForInputs()));
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 processing_time =
+ ProcessingTimeForInputs() - inputs_.front()->ProcessingTime();
+ return NanosPerElementLocked() +
+ static_cast<double>(processing_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::PARALLEL_MAP:
+ case Type::PREFETCH:
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ return NanosPerElementLocked() + ProcessingTimeForInputs();
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::PADDED_BATCH: {
+ double batch_size = metadata_["batch_size"];
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) /
+ batch_size);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ batch_size * OutputTimeForInputs(input_times);
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) / ratio);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ static_cast<int64>(
+ static_cast<double>(OutputTimeForInputs(input_times)) * ratio);
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1`
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1));
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ int64 output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ return NanosPerElementLocked() +
+ static_cast<double>(output_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::MAP_AND_BATCH: {
+ double batch_size = metadata_["batch_size"];
+ double parallelism = metadata_["parallelism"];
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
+ (batch_size * parallelism));
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ batch_size * OutputTimeForInputs(input_times));
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1));
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism = std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(metadata_["parallelism"]));
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_MAP: {
+ double parallelism = std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(metadata_["parallelism"]));
+ int64 delta = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time =
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ OutputTimeForInputs(input_times);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PREFETCH: {
+ int64 delta = NanosPerElementLocked();
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ return std::max(0LL, NanosPerElementLocked() +
+ OutputTimeForInputs(input_times) -
+ input_times->at(input_times->size() - 2));
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ int64 delta = NanosPerElementLocked();
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ return NanosPerElementLocked() + OutputTimeForInputs(input_times);
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+Model::Model(const proto::Model& model_proto) {
+ id_counter_ = model_proto.id_counter();
+ std::map<int64, std::shared_ptr<Node>> lookup_table;
+ for (auto node_proto : model_proto.node()) {
+ std::shared_ptr<Node> node(new Node(node_proto));
+ lookup_table[node_proto.id()] = node;
+ }
+ for (auto node_proto : model_proto.node()) {
+ std::shared_ptr<Node> node = lookup_table[node_proto.id()];
+ for (int64 id : node_proto.input()) {
+ node->add_input(lookup_table[id]);
+ }
+ node->set_output(lookup_table[node_proto.output()]);
+ }
+ output_ = lookup_table[model_proto.output()];
+}
+
+std::shared_ptr<Node> Model::AddNode(const string& name,
+ const string& output_name) {
+ mutex_lock l(mu_);
+ std::shared_ptr<Node> output;
+ auto it = lookup_table_.find(output_name);
+ if (it != lookup_table_.end()) {
+ output = it->second;
+ }
+ std::shared_ptr<Node> node(new Node(id_counter_++, output));
+ if (!output_) {
+ output_ = node;
+ }
+ if (output) {
+ output->add_input(node);
+ }
+ lookup_table_.insert(std::make_pair(name, node));
+ return node;
+}
+
+std::shared_ptr<Node> Model::LookupNode(const string& name) {
+ tf_shared_lock l(mu_);
+ std::shared_ptr<Node> result;
+ auto it = lookup_table_.find(name);
+ if (it != lookup_table_.end()) {
+ result = it->second;
+ }
+ return result;
+}
+
+void Model::Optimize() {
+ mutex_lock l(mu_);
+ int64 processing_time = ProcessingTime();
+ int64 num_cpus = port::NumSchedulableCPUs();
+ std::vector<Node::Knob> knobs = CollectKnobs();
+ // The optimization algorithm starts by setting all parallelism knobs to 1. It
+ // then repeatedly identifies the knob that, when turned up by 1, decreases
+ // the output time the most. This process is repeated until all knobs reach
+ // the number of schedulable CPUs or the projected output time is less than
+ // the processing time needed to produce an element divided by the
+ // number of schedulable CPUs.
+ for (auto& knob : knobs) {
+ LOG(INFO) << knob.node->name() << " " << knob.processing_time;
+ knob.value = 1;
+ knob.node->set_metadata("parallelism", knob.value);
+ }
+ while (true) {
+ int64 output_time = OutputTime();
+ bool all_knobs = true;
+ for (auto knob : knobs) {
+ if (knob.value < num_cpus) {
+ all_knobs = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / num_cpus || all_knobs) {
+ break;
+ }
+ int64 best_delta = -1;
+ int best_knob = -1;
+ for (int i = 0; i < knobs.size(); ++i) {
+ if (knobs[i].value == num_cpus) {
+ continue;
+ }
+ knobs[i].node->set_metadata("parallelism", knobs[i].value + 1);
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_knob = i;
+ }
+ knobs[i].node->set_metadata("parallelism", knobs[i].value);
+ }
+ knobs[best_knob].value++;
+ knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value);
+ }
+ for (auto knob : knobs) {
+ LOG(INFO) << knob.node->name() << " " << knob.value;
+ }
+ LOG(INFO) << "output time: " << OutputTime();
+ LOG(INFO) << "processing time: " << ProcessingTime();
+}
+
+void Model::OutputToFile() {
+ proto::Model model_proto;
+ ToProto(&model_proto);
+ string filename;
+ Env::Default()->LocalTempFilename(&filename);
+ TF_CHECK_OK(WriteStringToFile(Env::Default(), filename,
+ model_proto.SerializeAsString()));
+ LOG(INFO) << filename;
+}
+
+void Model::RemoveNode(const string& prefix) {
+ mutex_lock l(mu_);
+ lookup_table_.erase(prefix);
+}
+
+void Model::ToProto(proto::Model* model_proto) {
+ mutex_lock l(mu_);
+ model_proto->set_id_counter(id_counter_);
+ model_proto->set_output(output_->id());
+ AddNodeToProto(output_, model_proto);
+}
+
+// static
+void Model::AddNodeToProto(const std::shared_ptr<Node>& node,
+ proto::Model* model_proto) {
+ proto::Node* node_proto = model_proto->add_node();
+ node->ToProto(node_proto);
+ for (const std::shared_ptr<Node>& input : node->inputs()) {
+ AddNodeToProto(input, model_proto);
+ }
+}
+
+std::vector<Node::Knob> Model::CollectKnobs() {
+ std::vector<Node::Knob> knobs;
+ output_->CollectKnobs(&knobs);
+ return knobs;
+}
+
+int64 Model::OutputTime() {
+ std::vector<int64> input_times(1, 0);
+ return output_->OutputTime(&input_times);
+}
+
+int64 Model::ProcessingTime() { return output_->ProcessingTime(); }
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
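
Two worked examples of the estimates above, with hypothetical measurements. For the FILTER branch of `ProcessingTimeLocked()`: if the filter itself averages 200ns per output element, its input has produced 1,000 elements while the filter has emitted 100 (ratio 10), and the input subtree costs 50ns per element, the estimate is 200 + 10 * 50 = 700ns per output element; the input is charged once per element consumed, not once per element that survives the predicate. For the PARALLEL_MAP branch of `OutputTimeLocked()`: with 400ns per element, parallelism 4, and a single input whose output time evaluates to 300ns, the node pushes delta = 400 / 4 = 100ns onto `input_times`, computes 400 / 4 + 300 = 400ns, and returns max(0, 400 - t), where t is the time budget its consumer recorded one slot earlier; a consumer that leaves at least 400ns between requests gives the parallel workers enough slack to produce in the background, so the estimate clamps to zero.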
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
new file mode 100644
index 0000000000..98172909bf
--- /dev/null
+++ b/tensorflow/core/framework/model.h
@@ -0,0 +1,396 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <thread> // (b/114492873): move this include into core/platform
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/model.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+class Model;
+class Node;
+
+// Abstract representation of a TensorFlow input pipeline node. It collects
+// information about inputs to this node, processing time spent executing the
+// node logic, the number of elements produced by the node, and various other
+// information (e.g. batch size or execution parallelism).
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boilerplate code for creating the abstract representation of
+// the input pipeline and collecting common information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+//
+// In addition, `DatasetBaseIterator` provides wrappers that can be used for
+// transformation-specific information collection. The `SetMetadata` wrapper can
+// be used to pass arbitrary metadata to the modeling framework, while the
+// `StartWork` and `StopWork` wrappers should be used to correctly account for
+// processing time of multi-threaded transformations that yield the CPU; such
+// transformations should invoke `StartWork()` when a transformation thread
+// starts executing (e.g. when created or woken up) and `StopWork()` when a
+// transformation thread stops executing (e.g. when returning or waiting).
+//
+// TODO(jsimsa): Create an API to capture the abstract semantics of each
+// tf.data transformation and replace switch-case blocks with inheritance.
+class Node {
+ public:
+ Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
+
+ explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
+ name_ = node_proto.name();
+ type_ = TypeFromName(node_proto.name());
+ processing_time_ = node_proto.processing_time();
+ num_elements_ = node_proto.num_elements();
+ metadata_.insert(node_proto.metadata().begin(),
+ node_proto.metadata().end());
+ }
+
+ // Records that the node produced an element.
+ void add_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
+ }
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
+ }
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
+ }
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
+ }
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
+ }
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
+ }
+
+ // Adds the given key-value pair to the node metadata.
+ void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ metadata_[key] = value;
+ }
+
+ // Sets the node name.
+ void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ name_ = name;
+ type_ = TypeFromName(name);
+ }
+
+ // Sets the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
+ }
+
+ // Records that a node thread has started work.
+ void start_work() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+ }
+
+ // Records that a node thread has stopped work.
+ void stop_work() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ auto iter = work_start_.find(std::this_thread::get_id());
+ CHECK(work_start_.end() != iter)
+ << "Encountered a stop event that was not preceded by a start event.";
+ processing_time_ += Env::Default()->NowNanos() - iter->second;
+ work_start_.erase(iter);
+ }
+
+ private:
+ // Represents a performance knob.
+ struct Knob {
+ Node* node;
+ int64 processing_time;
+ int64 value;
+ };
+
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Collects performance knobs in the subtree rooted in this node.
+ void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return NanosPerElementLocked();
+ }
+
+ int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return static_cast<int64>(static_cast<double>(processing_time_) /
+ static_cast<double>(num_elements_));
+ }
+
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return OutputTimeLocked(input_times);
+ }
+
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
+
+ // Returns the per-element processing time spent in the subtree rooted in this
+ // node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return ProcessingTimeLocked();
+ }
+
+ int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTimeLocked();
+ }
+ return sum;
+ }
+
+ // Serializes the node state into the given proto.
+ void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ node_proto->set_id(id_);
+ node_proto->set_name(name_);
+ node_proto->set_num_elements(num_elements_);
+ node_proto->set_processing_time(processing_time_);
+ for (const std::shared_ptr<Node>& input : inputs_) {
+ node_proto->add_input(input->id());
+ }
+ if (output_) {
+ node_proto->set_output(output_->id());
+ }
+ node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
+ }
+
+ Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (name == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name == "Filter") {
+ return Type::FILTER;
+ }
+ if (name == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name == "Map") {
+ return Type::MAP;
+ }
+ if (name == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name, "Take")) {
+ return Type::TAKE;
+ }
+ if (name == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
+
+ mutex mu_;
+ const int64 id_;
+ Type type_ GUARDED_BY(mu_);
+ string name_ GUARDED_BY(mu_);
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> metadata_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+
+ friend class Model;
+};
+
+// Abstract representation of a TensorFlow input pipeline that can be used
+// for collecting runtime information and optimizing performance. It collects
+// runtime information about execution of the input pipeline that is used to
+// create a performance model, which is in turn used to identify optimal values
+// of performance knobs.
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boilerplate code for creating the abstract representation of
+// the input pipeline and collecting runtime information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+//
+// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
+// into the input pipeline.
+class Model {
+ public:
+ Model() = default;
+ explicit Model(const proto::Model& model_proto);
+
+ ~Model() {}
+
+ // Returns the model output node.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+ // Adds a node with the given name and given output (identified by name).
+ std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
+ LOCKS_EXCLUDED(mu_);
+
+ // Looks up the node using the given name.
+ std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Runs optimization.
+ void Optimize() LOCKS_EXCLUDED(mu_);
+
+ // Outputs the state of a model to a file.
+ //
+ // TODO(jsimsa): Remove this method once the optimization loop is closed.
+ void OutputToFile() LOCKS_EXCLUDED(mu_);
+
+ // Removes the node identified by the given name.
+ void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+
+ // Serializes the model state to the given proto.
+ void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);
+
+ private:
+ static void AddNodeToProto(const std::shared_ptr<Node>& node,
+ proto::Model* model_proto);
+
+ std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ mutex mu_;
+ int64 id_counter_ GUARDED_BY(mu_) = 1;
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
+};
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
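
For orientation, here is a minimal sketch that drives the `Model` API by hand, outside the `DatasetBase` hooks; the two-node pipeline and its prefixes are hypothetical:

#include "tensorflow/core/framework/model.h"

using tensorflow::data::model::Model;

void Sketch() {
  Model model;
  // The first node's output name resolves to nothing, so this node becomes
  // the model output.
  auto map = model.AddNode("Iterator::ParallelMap", "Iterator");
  map->set_name("ParallelMap");  // also derives the node type
  map->set_metadata("parallelism", 2);
  // The second node names the first as its output and is wired as its input.
  auto range =
      model.AddNode("Iterator::ParallelMap::Range", "Iterator::ParallelMap");
  range->set_name("Range");
  // At runtime, iterators call start_work()/stop_work()/add_element() on the
  // node returned by model.LookupNode(prefix).
  model.Optimize();      // greedy knob search implemented in model.cc
  model.OutputToFile();  // write the serialized model to a local temp file
}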
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
new file mode 100644
index 0000000000..26000007af
--- /dev/null
+++ b/tensorflow/core/framework/model.proto
@@ -0,0 +1,30 @@
+syntax = "proto3";
+
+package tensorflow.data.model.proto;
+option cc_enable_arenas = true;
+
+message Model {
+ // Counter used for generating new node IDs.
+ int64 id_counter = 1;
+ // Nodes of this model.
+ repeated Node node = 2;
+ // The ID of the output node.
+ int64 output = 3;
+};
+
+message Node {
+ // The node ID.
+ int64 id = 1;
+ // The node name.
+ string name = 2;
+ // Input node IDs.
+ repeated int64 input = 3;
+ // Output node ID.
+ int64 output = 4;
+ // Number of elements produced by the node.
+ int64 num_elements = 5;
+ // The CPU time spent by running threads of this node.
+ int64 processing_time = 6;
+ // Key-value store for node metadata (e.g. batch size or parallelism).
+ map<string, int64> metadata = 7;
+};
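
The proto closes the loop between a live model and its serialized form: `Model::ToProto` fills it in and the `Model(const proto::Model&)` constructor rebuilds the node graph from it. A minimal round-trip sketch, continuing the hypothetical `model` from the previous example:

tensorflow::data::model::proto::Model model_proto;
model.ToProto(&model_proto);                     // snapshot the live model
string bytes = model_proto.SerializeAsString();  // wire format

tensorflow::data::model::proto::Model parsed;
CHECK(parsed.ParseFromString(bytes));
tensorflow::data::model::Model restored(parsed);  // rebuild nodes and edges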