aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_partition.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/graph_partition.h')
-rw-r--r--tensorflow/core/graph/graph_partition.h77
1 files changed, 77 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
new file mode 100644
index 0000000000..eb88ff71b1
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/costmodel.h"
+
+namespace tensorflow {
+
+// Options controlling how Partition() splits a graph across locations.
+struct PartitionOptions {
+ // A function that returns a location for the execution of a given
+ // Node.
+ typedef std::function<string(const Node*)> NodeToLocFunc;
+ NodeToLocFunc node_to_loc = nullptr;
+
+ // Returns a unique graph node name with the given prefix.
+ typedef std::function<string(const string&)> NewNameFunc;
+ NewNameFunc new_name = nullptr;
+
+ // A function that returns the incarnation of a device given the
+ // device's fullname. If not found, GetIncarnationFunc should return
+ // kIllegalIncarnation.
+ static const uint64 kIllegalIncarnation = 0;
+ typedef std::function<uint64(const string&)> GetIncarnationFunc;
+ GetIncarnationFunc get_incarnation = nullptr;
+
+ // True if all the control flow "code" has already been added. That
+ // code must be added while we still have the entire graph, before any
+ // partitioning, so this flag should be false for the first
+ // partitioning and true for all subsequent ones. Defaults to false.
+ //
+ // TODO(yuanbyu): We could also make the addition of the control
+ // flow code incremental based on 'node_to_loc'. This makes the
+ // communication a broadcast tree, which could be more efficient when
+ // the number of participating devices is large.
+ bool control_flow_added = false;
+
+ // A function that returns the data type into which the tensor
+ // should be cast before sent over the wire.
+ typedef std::function<DataType(const Edge*)> ShouldCastFunc;
+ ShouldCastFunc should_cast = nullptr;
+
+ // Schedule the execution of the recvs based on their start times
+ // computed by some scheduling algorithm. The recvs are divided into
+ // epochs based on their start times. A recv is enabled only when
+ // execution reaches its epoch - N for some predefined N.
+ bool scheduling_for_recvs = false;
+ // The start time for each node in the graph computed by some scheduling
+ // algorithm. If 'need_to_record_start_times' is true, we record them
+ // in the graph as a node attribute.
+ bool need_to_record_start_times = false;
+ std::vector<Microseconds> start_times;
+};
+
+// Partitions the "input" graph into a set of graphs, one per location.
+// The location of a node n is obtained by calling opts.node_to_loc(n).
+// Nodes newly added by Partition get their names from
+// "opts.new_name(old_name)".
+//
+// Stores the partitions in *partitions (presumably keyed by location).
+Status Partition(const PartitionOptions& opts, Graph* input,
+ std::unordered_map<string, GraphDef>* partitions);
+
+// Adds control edges to the graphs in *partitions so that the ordering
+// and timing of the recv nodes follow the start times calculated by
+// some scheduling algorithm (presumably those in opts.start_times).
+Status AddControlEdges(const PartitionOptions& opts,
+ std::unordered_map<string, GraphDef>* partitions);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_