blob: eb88ff71b1bf88da8b7b01ec218667dddb4eff38 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
struct PartitionOptions {
// A function that returns a location for the execution of a given
// Node.
typedef std::function<string(const Node*)> NodeToLocFunc;
NodeToLocFunc node_to_loc = nullptr;
// A function that returns a unique graph node name with the given
// prefix.
typedef std::function<string(const string&)> NewNameFunc;
NewNameFunc new_name = nullptr;
// A function that returns the incarnation of a device given the
// device's fullname. If not found, GetIncarnationFunc should return
// kIlledgalIncarnation.
static const uint64 kIllegalIncarnation = 0;
typedef std::function<uint64(const string&)> GetIncarnationFunc;
GetIncarnationFunc get_incarnation = nullptr;
// True if all the control flow "code" has already been added. The
// control flow code needs to be added when we still have the entire
// graph before any partitioning. So this flag should be false for
// the first partitioning but true for all subsequent partitioning.
//
// TODO(yuanbyu): We could also make the addition of the control
// flow code incremental based on 'node_to_loc'. This makes the
// communication a broadcast tree, which could be more efficient when
// the number of participating devices is large.
bool control_flow_added;
// A function that returns the data type into which the tensor
// should be cast before sent over the wire.
typedef std::function<DataType(const Edge*)> ShouldCastFunc;
ShouldCastFunc should_cast = nullptr;
// Schedule the execution of the recvs based on their start times
// computed by some scheduling algorithm. The recvs are divided into
// epochs based on their start times. A recv is enabled only when
// execution reaches its epoch - N for some predefined N.
bool scheduling_for_recvs = false;
// The start time for each node in the graph computed by some scheduling
// algorithm. If 'need_to_record_start_times' is true, we record them
// in the graph as a node attribute.
bool need_to_record_start_times = false;
std::vector<Microseconds> start_times;
};
// Partition "input" graph into a set of graphs, one per location.
// The location for node n is derived by calling opts.node_to_loc(n).
// New nodes added by Partition use "opts.new_name(old_name)" to
// generate node names.
//
// Stores the partitions in *partitions.
//
// NOTE(review): opts.node_to_loc and opts.new_name are required for
// this call (both default to nullptr) — confirm against the .cc.
// "input" is non-const; presumably Partition may mutate it (e.g. to
// add control flow code) — verify before reusing the graph.
Status Partition(const PartitionOptions& opts, Graph* input,
                 std::unordered_map<string, GraphDef>* partitions);
// Add control edges to the partitions to control the ordering
// and timing of the recv nodes based on the start times calculated
// using some scheduling algorithm.
//
// NOTE(review): presumably expects opts.scheduling_for_recvs to be set
// and opts.start_times to be populated by the scheduler — confirm
// against the implementation. Modifies *partitions in place.
Status AddControlEdges(const PartitionOptions& opts,
                       std::unordered_map<string, GraphDef>* partitions);
} // namespace tensorflow
#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
|