/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_

#include <list>
#include <map>
#include <memory>
#include <string>
#include <thread>  // (b/114492873): move this include into core/platform
#include <utility>
#include <vector>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace data {
namespace model {

// Represents thread-safe state that can be shared between an input pipeline and
// the performance model.
struct SharedState {
 public:
  explicit SharedState(int64 value, std::shared_ptr<mutex> mu,
                       std::shared_ptr<condition_variable> cond_var)
      : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}

  std::shared_ptr<mutex> mu;
  std::shared_ptr<condition_variable> cond_var;
  int64 value;
};
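
// For illustration only, a transformation might expose a tunable
// `parallelism` parameter through a `SharedState` roughly as follows (the
// concrete call sites live in the tf.data kernel implementations):
//
//   auto mu = std::make_shared<mutex>();
//   auto cond_var = std::make_shared<condition_variable>();
//   auto state = std::make_shared<SharedState>(/*value=*/1, mu, cond_var);
//
// The model can then adjust `state->value` while holding `*mu` and signal
// `*cond_var`; the transformation reads the value under the same mutex.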

// Abstract representation of a TensorFlow input pipeline that can be used for
// collecting runtime information and optimizing performance. The collected
// runtime information is used to create a performance model, which is in turn
// used to identify optimal values of tunable parameters.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boilerplate code for creating the abstract representation of
// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator`, respectively.
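//
// For illustration only, a rough sketch of the expected call sequence follows
// (the actual call sites live in `DatasetBase` and `DatasetBaseIterator`; the
// node names below are made up):
//
//   model::Model model;
//   model.AddNode("Prefetch", /*output_name=*/"");          // root node
//   model.AddNode("ParallelMap", /*output_name=*/"Prefetch");
//   model.RecordStart("ParallelMap", /*stop_output=*/true);
//   // ... the node produces an element ...
//   model.RecordElement("ParallelMap");
//   model.RecordStop("ParallelMap", /*start_output=*/true);
//   model.Optimize(/*cpu_budget=*/port::NumSchedulableCPUs());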
class Model {
 public:
  Model() = default;

  // Adds a constant parameter for the given node.
  void AddConstantParameter(const string& node_name,
                            const string& parameter_name, int64 value)
      LOCKS_EXCLUDED(mu_);

  // Adds a node with the given name and given output (identified by name).
  void AddNode(const string& name, const string& output_name)
      LOCKS_EXCLUDED(mu_);

  // Increments the processing time for the given node.
  void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);

  // Adds a tunable parameter for the given node.
  void AddTunableParameter(const string& node_name,
                           const string& parameter_name,
                           std::shared_ptr<SharedState> value, int64 min,
                           int64 max) LOCKS_EXCLUDED(mu_);

  // Runs optimization, searching for values of tunable parameters that
  // minimize the output time of the pipeline subject to the given CPU budget.
  void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);

  // Records that a node has produced an element.
  void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);

  // Records that the given node has started work. If `stop_output` is set, it
  // also records that the output of the given node has stopped work.
  void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);

  // Records that the given node has stopped work. If `start_output` is set, it
  // also records that the output of the given node has started work.
  void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);

  // Removes the given node.
  void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);

 private:
  // Abstract representation of a TensorFlow input pipeline node. It collects
  // information about inputs to this node, processing time spent executing the
  // node logic, the number of elements produced by the node, and various other
  // information (e.g. batch size or execution parallelism).
  //
  // Developers of tf.data transformations are not expected to interact with
  // this class directly. Boilerplate code for creating the abstract
  // representation of the input pipeline and collecting common information has
  // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`,
  // respectively.
  //
  // In addition, `DatasetBaseIterator` provides wrappers that can be used for
  // transformation-specific information collection. The `SetMetadata` wrapper
  // can be used to pass arbitrary metadata to the modeling framework, while the
  // `StartWork` and `StopWork` wrappers should be used to correctly account for
  // processing time of multi-threaded transformations that yield the CPU; such
  // transformations should invoke `StartWork()` when a transformation thread
  // starts executing (e.g. when created or woken up) and `StopWork()` when a
  // transformation thread stops executing (e.g. when returning or waiting).
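  //
  // For example, a worker thread of a multi-threaded transformation might
  // bracket its execution roughly as follows (a sketch; `StartWork` and
  // `StopWork` stand for the `DatasetBaseIterator` wrappers mentioned above):
  //
  //   StartWork();            // thread created or woken up
  //   while (!cancelled) {
  //     if (work_queue_empty) {
  //       StopWork();         // about to wait; yield the CPU
  //       cond_var.wait(l);
  //       StartWork();        // woken up; executing again
  //     }
  //     // ... process an element ...
  //   }
  //   StopWork();             // thread returns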
  //
  // TODO(jsimsa): Create an API to capture the abstract semantics of each
  // tf.data transformation and replace switch-case blocks with inheritance.
  class Node {
   public:
    // Represents a tunable parameter.
    struct Tunable {
      Tunable(std::shared_ptr<SharedState> state, int64 min, int64 max)
          : value(state->value), min(min), max(max), state(std::move(state)) {}

      // Identifies the model value of the parameter. This can be different from
      // the actual value (e.g. during optimization search).
      int64 value;

      // Identifies the minimum value of the parameter.
      int64 min;

      // Identifies the maximum value of the parameter.
      int64 max;

      // Shared state of the parameter.
      std::shared_ptr<SharedState> state;
    };
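
    // For illustration only, the model might publish a chosen value back to
    // the shared state roughly as follows (a sketch; the actual logic lives
    // in `Model::Optimize`):
    //
    //   tunable->value = candidate;                // model-local trial value
    //   {
    //     mutex_lock l(*tunable->state->mu);
    //     tunable->state->value = tunable->value;  // publish the final value
    //     tunable->state->cond_var->notify_all();  // wake any waiting threads
    //   }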

    Node(int64 id, const string& name, std::shared_ptr<Node> output)
        : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {}

    // Adds a constant parameter.
    void add_constant_param(const string& name, int64 value)
        LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      constant_params_[name] = value;
    }

    // Adds an input.
    void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      inputs_.push_back(node);
    }

    // Increments the aggregate processing time by the given delta.
    void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      processing_time_ += delta;
    }

    // Adds a tunable parameter.
    void add_tunable_param(const string& name,
                           std::shared_ptr<SharedState> state, int64 min,
                           int64 max) LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      tunable_params_[name] =
          std::make_shared<Tunable>(std::move(state), min, max);
    }

    // Returns the unique node ID.
    int64 id() LOCKS_EXCLUDED(mu_) { return id_; }

    // Returns the node inputs.
    std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
      tf_shared_lock l(mu_);
      return inputs_;
    }

    // Returns the node name.
    const string& name() LOCKS_EXCLUDED(mu_) {
      tf_shared_lock l(mu_);
      return name_;
    }

    // Returns the number of elements produced by the node.
    int64 num_elements() LOCKS_EXCLUDED(mu_) {
      tf_shared_lock l(mu_);
      return num_elements_;
    }

    // Returns the node output.
    std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
      tf_shared_lock l(mu_);
      return output_;
    }

    // Records that the node produced an element.
    void record_element() LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      num_elements_++;
    }

    // Records that a node thread has started executing.
    void record_start() LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
    }

    // Records that a node thread has stopped executing.
    void record_stop() LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      std::thread::id tid = std::this_thread::get_id();
      auto start_time = gtl::FindOrNull(work_start_, tid);
      DCHECK(start_time)
          << "Encountered a stop event that was not preceded by a start event.";
      if (start_time) {
        processing_time_ += Env::Default()->NowNanos() - *start_time;
        work_start_.erase(tid);
      }
    }

    // Removes an input.
    void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      inputs_.remove(input);
    }

    // Sets the node output.
    void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      output_ = output;
    }

    // Collects tunable parameters in the subtree rooted in this node.
    void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables)
        LOCKS_EXCLUDED(mu_);

    // Returns the per-element output time for this node.
    int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
      tf_shared_lock l(mu_);
      return OutputTimeLocked(input_times);
    }

    // Returns the per-element processing time spent in the subtree rooted in
    // this node.
    int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
      tf_shared_lock l(mu_);
      return ProcessingTimeLocked();
    }

   private:
    enum class Type {
      BATCH = 0,
      CACHE,
      CONCATENATE,
      FILTER,
      FLAT_MAP,
      INTERLEAVE,
      MAP,
      MAP_AND_BATCH,
      PADDED_BATCH,
      PARALLEL_INTERLEAVE,
      PARALLEL_INTERLEAVE_V2,
      PARALLEL_MAP,
      PREFETCH,
      REPEAT,
      SHUFFLE,
      SKIP,
      TAKE,
      ZIP,
      UNKNOWN,
    };

    // Gets the value of the given parameter (tunable or constant).
    int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_);

    // Returns the per-element processing time spent in this node.
    int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
      tf_shared_lock l(mu_);
      return NanosPerElementLocked();
    }

    // Returns the per-element processing time, or 0 if no elements have been
    // produced.
    int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) {
      if (num_elements_ == 0) {
        return 0;
      }
      return static_cast<int64>(static_cast<double>(processing_time_) /
                                static_cast<double>(num_elements_));
    }

    int64 OutputTimeLocked(std::vector<int64>* input_times)
        SHARED_LOCKS_REQUIRED(mu_);

    // Returns the sum of the per-element output times of this node's inputs.
    int64 OutputTimeForInputs(std::vector<int64>* input_times)
        SHARED_LOCKS_REQUIRED(mu_) {
      int64 sum = 0;
      for (auto input : inputs_) {
        sum += input->OutputTime(input_times);
      }
      return sum;
    }

    int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_);

    // Returns the per-element processing time spent in the inputs of this node.
    int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
      int64 sum = 0;
      for (auto input : inputs_) {
        sum += input->ProcessingTime();
      }
      return sum;
    }

    Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) {
      if (name_ == "Batch") {
        return Type::BATCH;
      }
      if (str_util::EndsWith(name_, "Cache")) {
        return Type::CACHE;
      }
      if (name_ == "Concatenate") {
        return Type::CONCATENATE;
      }
      if (name_ == "Filter") {
        return Type::FILTER;
      }
      if (name_ == "FlatMap") {
        return Type::FLAT_MAP;
      }
      if (name_ == "Interleave") {
        return Type::INTERLEAVE;
      }
      if (name_ == "Map") {
        return Type::MAP;
      }
      if (name_ == "MapAndBatch" || name_ == "NumaMapAndBatch") {
        return Type::MAP_AND_BATCH;
      }
      if (name_ == "PaddedBatch") {
        return Type::PADDED_BATCH;
      }
      if (name_ == "ParallelInterleave") {
        return Type::PARALLEL_INTERLEAVE;
      }
      if (name_ == "ParallelInterleaveV2") {
        return Type::PARALLEL_INTERLEAVE_V2;
      }
      if (name_ == "ParallelMap") {
        return Type::PARALLEL_MAP;
      }
      if (name_ == "Prefetch") {
        return Type::PREFETCH;
      }
      if (str_util::EndsWith(name_, "Repeat")) {
        return Type::REPEAT;
      }
      if (name_ == "Shuffle") {
        return Type::SHUFFLE;
      }
      if (str_util::EndsWith(name_, "Skip")) {
        return Type::SKIP;
      }
      if (str_util::EndsWith(name_, "Take")) {
        return Type::TAKE;
      }
      if (name_ == "Zip") {
        return Type::ZIP;
      }
      return Type::UNKNOWN;
    }

    mutex mu_;
    const int64 id_;
    const string name_;
    const Type type_;
    int64 processing_time_ GUARDED_BY(mu_) = 0;
    int64 num_elements_ GUARDED_BY(mu_) = 0;
    std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
    std::map<string, int64> constant_params_ GUARDED_BY(mu_);
    // Tunables are shared with the model during optimization.
    std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
    std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
    std::shared_ptr<Node> output_ GUARDED_BY(mu_);
  };

  std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
      SHARED_LOCKS_REQUIRED(mu_);

  int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_);

  int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_);

  // Used for coordination between different input pipeline threads. Exclusive
  // access is required only when adding or removing nodes. Concurrent access to
  // existing nodes is protected by a node mutex.
  mutex mu_;
  int64 id_counter_ GUARDED_BY(mu_) = 1;
  std::shared_ptr<Node> output_ GUARDED_BY(mu_);
  std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
};

}  // namespace model
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_