aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_partition.h
blob: eb88ff71b1bf88da8b7b01ec218667dddb4eff38 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Options consumed by Partition() and AddControlEdges(). Every member has a
// default member initializer, so a default-constructed PartitionOptions is
// fully defined; callers fill in the callbacks they need.
struct PartitionOptions {
  // A function that returns a location for the execution of a given
  // Node.
  typedef std::function<string(const Node*)> NodeToLocFunc;
  NodeToLocFunc node_to_loc = nullptr;

  // A function that returns a unique graph node name with the given
  // prefix.
  typedef std::function<string(const string&)> NewNameFunc;
  NewNameFunc new_name = nullptr;

  // A function that returns the incarnation of a device given the
  // device's fullname. If not found, GetIncarnationFunc should return
  // kIllegalIncarnation.
  static const uint64 kIllegalIncarnation = 0;
  typedef std::function<uint64(const string&)> GetIncarnationFunc;
  GetIncarnationFunc get_incarnation = nullptr;

  // True if all the control flow "code" has already been added. The
  // control flow code needs to be added when we still have the entire
  // graph before any partitioning. So this flag should be false for
  // the first partitioning but true for all subsequent partitioning.
  //
  // Initialized to false (the first-partitioning case) so that a
  // default-constructed PartitionOptions never carries an indeterminate
  // value here; reading an uninitialized bool is undefined behavior.
  //
  // TODO(yuanbyu): We could also make the addition of the control
  // flow code incremental based on 'node_to_loc'. This makes the
  // communication a broadcast tree, which could be more efficient when
  // the number of participating devices is large.
  bool control_flow_added = false;

  // A function that returns the data type into which the tensor
  // should be cast before being sent over the wire.
  typedef std::function<DataType(const Edge*)> ShouldCastFunc;
  ShouldCastFunc should_cast = nullptr;

  // Schedule the execution of the recvs based on their start times
  // computed by some scheduling algorithm. The recvs are divided into
  // epochs based on their start times. A recv is enabled only when
  // execution reaches its epoch - N for some predefined N.
  bool scheduling_for_recvs = false;

  // If true, the start times held in 'start_times' are recorded in the
  // graph as a node attribute.
  bool need_to_record_start_times = false;

  // The start time for each node in the graph, computed by some scheduling
  // algorithm.
  // NOTE(review): presumably indexed by node id — confirm against
  // graph_partition.cc before relying on the layout.
  std::vector<Microseconds> start_times;
};

// Partitions the "input" graph into a set of graphs, one per location.
// The location for node n is derived by calling opts.node_to_loc(n).
// New nodes added by Partition use "opts.new_name(old_name)" to
// generate node names.
//
// `input` is taken as a non-const Graph*, so callers should not assume it
// is left unmodified by this call. NOTE(review): confirm which mutations
// (if any) happen against graph_partition.cc.
//
// Stores the partitions in *partitions. Each entry is presumably keyed by
// the location string produced by opts.node_to_loc, i.e. one GraphDef per
// location — confirm against the implementation.
//
// Returns a Status that reports success or why partitioning failed.
Status Partition(const PartitionOptions& opts, Graph* input,
                 std::unordered_map<string, GraphDef>* partitions);

// Adds control edges to the partitions in *partitions (modified in place)
// to control the ordering and timing of the recv nodes, based on the start
// times calculated using some scheduling algorithm.
//
// NOTE(review): the start times are presumably taken from opts.start_times
// and the call is presumably only meaningful when opts.scheduling_for_recvs
// is set — confirm against graph_partition.cc.
//
// Returns a Status that reports success or why the edges could not be added.
Status AddControlEdges(const PartitionOptions& opts,
                       std::unordered_map<string, GraphDef>* partitions);

}  // namespace tensorflow

#endif  // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_