aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment/segment.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.h')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.h18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 7e8685f44a..1568dd9153 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -29,7 +29,9 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
-using SegmentNodesVector = std::vector<std::set<string>>;
+// vector of segments, each entry contains a device name and a set of nodes in
+// segment
+using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
struct SegmentOptions {
// Segment must contain at least this many nodes.
@@ -51,6 +53,20 @@ tensorflow::Status SegmentGraph(
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments);
+// Get the subgraphs of a graph that can be handled by TensorRT.
+//
+// @param graph tensorflow::Graph of the network
+// @param candidate_fn A function that returns true for a Node* if
+// that node can be handled by TensorRT.
+// @param segments Returns the TensorRT segments/subgraphs. Each entry
+// in the vector describes a subgraph by giving a set of the names of
+// all the NodeDefs in that subgraph.
+// @return the status.
+tensorflow::Status SegmentGraph(
+ tensorflow::Graph* tf_graph,
+ const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments);
+
} // namespace segment
} // namespace tensorrt
} // namespace tensorflow