aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/model.h
blob: 98172909bf299a90fb54021b99a0bbd89e8f59d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_

#include <list>
#include <map>
#include <memory>
#include <string>
#include <thread>  // (b/114492873): move this include into core/platform
#include <utility>
#include <vector>

#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace data {
namespace model {

// Forward declarations of the two classes defined in this header. `Node`
// declares `Model` as a friend, and `Model` stores and returns
// `std::shared_ptr<Node>` handles.
class Model;
class Node;

// Abstract representation of a TensorFlow input pipeline node. It collects
// information about inputs to this node, processing time spent executing the
// node logic, number of elements produced by the node, and various other
// information (e.g. batch size or execution parallelism).
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boilerplate code for creating the abstract representation of
// the input pipeline and collecting common information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
//
// In addition, `DatasetBaseIterator` provides wrappers that can be used for
// transformation-specific information collection. The `SetMetadata` wrapper can
// be used to pass arbitrary metadata to the modeling framework, while the
// `StartWork` and `StopWork` wrappers should be used to correctly account for
// processing time of multi-threaded transformations that yield the CPU; such
// transformations should invoke `StartWork()` when a transformation thread
// starts executing (e.g. when created or woken up) and `StopWork()` when a
// transformation thread stops executing (e.g. when returning or waiting).
//
// TODO(jsimsa): Create an API to capture the abstract semantics of each
// tf.data transformation and replace switch-case blocks with inheritance.
class Node {
 public:
  // Creates a node with the given unique `id` whose consumer is `output`.
  // The name is empty and the type is `Type::UNKNOWN` until `set_name()` is
  // called.
  Node(int64 id, std::shared_ptr<Node> output)
      : id_(id), output_(std::move(output)) {}

  // Restores a node from its serialized form. Only the ID, name, processing
  // time, element count and metadata are read from `node_proto`; the input and
  // output links are left empty.
  explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
    name_ = node_proto.name();
    type_ = TypeFromName(name_);
    processing_time_ = node_proto.processing_time();
    num_elements_ = node_proto.num_elements();
    metadata_.insert(node_proto.metadata().begin(),
                     node_proto.metadata().end());
  }

  // Records that the node produced an element.
  void add_element() LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    num_elements_++;
  }

  // Appends `node` to the list of inputs of this node.
  void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    inputs_.push_back(std::move(node));
  }

  // Increments the aggregate processing time by the given delta.
  void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    processing_time_ += delta;
  }

  // Returns the unique node ID. The ID is immutable after construction, so no
  // locking is needed.
  int64 id() const { return id_; }

  // Returns a snapshot (copy) of the node inputs.
  std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return inputs_;
  }

  // Returns the node name.
  //
  // NOTE(review): the returned reference refers to state guarded by `mu_`, but
  // the lock is released when this method returns; callers must not race with
  // `set_name()`.
  const string& name() LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return name_;
  }

  // Returns the number of elements produced by the node.
  int64 num_elements() LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return num_elements_;
  }

  // Returns the node output (may be null if no output has been set).
  std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return output_;
  }

  // Removes every occurrence of `input` from the list of inputs.
  void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    inputs_.remove(input);
  }

  // Adds the given key-value pair to the node metadata, overwriting any value
  // previously stored under `key`.
  void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    metadata_[key] = value;
  }

  // Sets the node name and re-derives the node type from it.
  void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    name_ = name;
    type_ = TypeFromName(name);
  }

  // Sets the node output.
  void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    output_ = std::move(output);
  }

  // Records that a node thread has started work. The start timestamp is keyed
  // by the calling thread's ID, so a repeated call from the same thread
  // overwrites the previous timestamp.
  void start_work() LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
  }

  // Records that a node thread has stopped work and adds the time elapsed
  // since the matching `start_work()` call to the aggregate processing time.
  // CHECK-fails if the calling thread has no recorded start event.
  void stop_work() LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    auto iter = work_start_.find(std::this_thread::get_id());
    CHECK(work_start_.end() != iter)
        << "Encountered a stop event that was not preceded by a start event.";
    processing_time_ += Env::Default()->NowNanos() - iter->second;
    work_start_.erase(iter);
  }

 private:
  // Represents a performance knob.
  // NOTE(review): the exact meaning of `processing_time` and `value` is
  // defined by `CollectKnobs()`, which is implemented outside this header.
  struct Knob {
    Node* node;
    int64 processing_time;
    int64 value;
  };

  // Kind of transformation a node represents; derived from the node name by
  // `TypeFromName()`.
  enum class Type {
    BATCH = 0,
    CACHE,
    CONCATENATE,
    FILTER,
    FLAT_MAP,
    INTERLEAVE,
    MAP,
    MAP_AND_BATCH,
    PADDED_BATCH,
    PARALLEL_INTERLEAVE,
    PARALLEL_INTERLEAVE_V2,
    PARALLEL_MAP,
    PREFETCH,
    REPEAT,
    SHUFFLE,
    SKIP,
    TAKE,
    ZIP,
    UNKNOWN,
  };

  // Collects performance knobs in the subtree rooted in this node.
  void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);

  // Returns the per-element processing time spent in this node.
  int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    return NanosPerElementLocked();
  }

  // Same as `NanosPerElement()` but expects the caller to hold `mu_`. Returns
  // 0 if the node has not produced any element yet (avoids division by zero).
  int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (num_elements_ == 0) {
      return 0;
    }
    return static_cast<int64>(static_cast<double>(processing_time_) /
                              static_cast<double>(num_elements_));
  }

  // Returns the per-element output time for this node.
  int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    return OutputTimeLocked(input_times);
  }

  // Same as `OutputTime()` but expects the caller to hold `mu_`.
  int64 OutputTimeLocked(std::vector<int64>* input_times)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the sum of the output times of the inputs of this node. Each
  // input's own lock is acquired via `OutputTime()`.
  int64 OutputTimeForInputs(std::vector<int64>* input_times)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    int64 sum = 0;
    for (const auto& input : inputs_) {
      sum += input->OutputTime(input_times);
    }
    return sum;
  }

  // Returns the per-element processing time spent in the subtree rooted in this
  // node.
  int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    return ProcessingTimeLocked();
  }

  // Same as `ProcessingTime()` but expects the caller to hold `mu_`.
  int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the per-element processing time spent in the inputs of this node.
  int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    int64 sum = 0;
    for (const auto& input : inputs_) {
      // Holding this node's `mu_` says nothing about the input's `mu_`, so go
      // through the locking wrapper (mirrors `OutputTimeForInputs()`).
      sum += input->ProcessingTime();
    }
    return sum;
  }

  // Serializes the node state into the given proto. Inputs and the output are
  // recorded by their IDs.
  void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    node_proto->set_id(id_);
    node_proto->set_name(name_);
    node_proto->set_num_elements(num_elements_);
    node_proto->set_processing_time(processing_time_);
    for (const std::shared_ptr<Node>& input : inputs_) {
      node_proto->add_input(input->id());
    }
    if (output_) {
      node_proto->set_output(output_->id());
    }
    node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
  }

  // Maps a node name to the transformation type it denotes. Most names must
  // match exactly; "Cache", "Repeat", "Skip" and "Take" are matched as
  // suffixes. Unrecognized names map to `Type::UNKNOWN`.
  //
  // The result depends only on `name` (no member state is read), hence the
  // function is static and needs no lock.
  static Type TypeFromName(const string& name) {
    if (name == "Batch") {
      return Type::BATCH;
    }
    if (str_util::EndsWith(name, "Cache")) {
      return Type::CACHE;
    }
    if (name == "Concatenate") {
      return Type::CONCATENATE;
    }
    if (name == "Filter") {
      return Type::FILTER;
    }
    if (name == "FlatMap") {
      return Type::FLAT_MAP;
    }
    if (name == "Interleave") {
      return Type::INTERLEAVE;
    }
    if (name == "Map") {
      return Type::MAP;
    }
    if (name == "MapAndBatch") {
      return Type::MAP_AND_BATCH;
    }
    if (name == "PaddedBatch") {
      return Type::PADDED_BATCH;
    }
    if (name == "ParallelInterleave") {
      return Type::PARALLEL_INTERLEAVE;
    }
    if (name == "ParallelInterleaveV2") {
      return Type::PARALLEL_INTERLEAVE_V2;
    }
    if (name == "ParallelMap") {
      return Type::PARALLEL_MAP;
    }
    if (name == "Prefetch") {
      return Type::PREFETCH;
    }
    if (str_util::EndsWith(name, "Repeat")) {
      return Type::REPEAT;
    }
    if (name == "Shuffle") {
      return Type::SHUFFLE;
    }
    if (str_util::EndsWith(name, "Skip")) {
      return Type::SKIP;
    }
    if (str_util::EndsWith(name, "Take")) {
      return Type::TAKE;
    }
    if (name == "Zip") {
      return Type::ZIP;
    }
    return Type::UNKNOWN;
  }

  mutex mu_;
  const int64 id_;
  // Defaults to UNKNOWN so that a node whose name has not been set yet never
  // exposes an indeterminate type.
  Type type_ GUARDED_BY(mu_) = Type::UNKNOWN;
  string name_ GUARDED_BY(mu_);
  // Aggregate processing time, accumulated from `add_processing_time()` and
  // `start_work()`/`stop_work()` pairs.
  int64 processing_time_ GUARDED_BY(mu_) = 0;
  int64 num_elements_ GUARDED_BY(mu_) = 0;
  // Per-thread timestamps of pending `start_work()` calls.
  std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
  std::map<string, int64> metadata_ GUARDED_BY(mu_);
  std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
  std::shared_ptr<Node> output_ GUARDED_BY(mu_);

  friend class Model;
};

// Abstract representation of a TensorFlow input pipeline that can be used
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
// create a performance model, which is in turn used to identify optimal values
// of performance knobs.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boilerplate code for creating the abstract representation of
// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
//
// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
// into the input pipeline.
class Model {
 public:
  // Creates an empty model (no output node; the ID counter starts at 1).
  Model() = default;
  // Creates a model from its serialized representation. The definition lives
  // outside this header.
  explicit Model(const proto::Model& model_proto);

  ~Model() {}

  // Returns the model output node (a copy of the shared pointer, taken under a
  // shared lock on `mu_`).
  std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return output_;
  }

  // Adds a node with the given name and given output (identified by name).
  std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
      LOCKS_EXCLUDED(mu_);

  // Looks up the node using the given name.
  std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);

  // Runs optimization.
  void Optimize() LOCKS_EXCLUDED(mu_);

  // Outputs the state of a model to a file.
  //
  // TODO(jsimsa): Remove this method once the optimization loop is closed.
  void OutputToFile() LOCKS_EXCLUDED(mu_);

  // Removes the node identified by the given name.
  // NOTE(review): the parameter is called `prefix` while the other methods use
  // `name`; presumably it is the same key used by `AddNode`/`LookupNode` --
  // confirm against the definition.
  void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);

  // Serializes the model state to the given proto.
  void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);

 private:
  // Helper for serializing `node` into `model_proto`; presumably used by
  // `ToProto()` -- the definition lives outside this header.
  static void AddNodeToProto(const std::shared_ptr<Node>& node,
                             proto::Model* model_proto);

  // The following helpers require the caller to hold `mu_`. Judging by their
  // names they aggregate the corresponding per-node quantities of `Node`;
  // their definitions live outside this header.
  std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);

  int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);

  int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Guards all mutable state below.
  mutex mu_;
  // Counter initialized to 1; presumably the source of unique node IDs --
  // confirm against `AddNode()`.
  int64 id_counter_ GUARDED_BY(mu_) = 1;
  // Node returned by `output()`.
  std::shared_ptr<Node> output_ GUARDED_BY(mu_);
  // Maps a string key (presumably the node name) to the corresponding node.
  std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
};

}  // namespace model
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_