diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.h')
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.h | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 7e8685f44a..1568dd9153 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -29,7 +29,9 @@ namespace tensorflow { namespace tensorrt { namespace segment { -using SegmentNodesVector = std::vector<std::set<string>>; +// vector of segments, each entry contains a device name and a set of nodes in +// segment +using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>; struct SegmentOptions { // Segment must contain at least this many nodes. @@ -51,6 +53,20 @@ tensorflow::Status SegmentGraph( const std::function<bool(const tensorflow::Node*)>& candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); +// Get the subgraphs of a graph that can be handled by TensorRT. +// +// @param graph tensorflow::Graph of the network +// @param candidate_fn A function that returns true for a Node* if +// that node can be handled by TensorRT. +// @param segments Returns the TensorRT segments/subgraphs. Each entry +// in the vector describes a subgraph by giving a set of the names of +// all the NodeDefs in that subgraph. +// @return the status. +tensorflow::Status SegmentGraph( + tensorflow::Graph* tf_graph, + const std::function<bool(const tensorflow::Node*)>& candidate_fn, + const SegmentOptions& options, SegmentNodesVector* segments); + } // namespace segment } // namespace tensorrt } // namespace tensorflow |