diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.h')
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.h | 20 |
1 files changed, 3 insertions, 17 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 81b4bfe49f..8c44eb782a 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -42,22 +42,6 @@ struct SegmentOptions { // Get the subgraphs of a graph that can be handled by TensorRT. // -// @param gdef The GraphDef describing the network -// @param candidate_fn A function that returns true for a NodeDef if -// that node can be handled by TensorRT. -// @param segments Returns the TensorRT segments/subgraphs. Each entry -// in the vector describes a subgraph by giving a set of the names of -// all the NodeDefs in that subgraph. -// @return the status. -// -// TODO(aaroey): remove this method. -tensorflow::Status SegmentGraph( - const tensorflow::GraphDef& gdef, - const std::function<bool(const tensorflow::Node*)>& candidate_fn, - const SegmentOptions& options, SegmentNodesVector* segments); - -// Get the subgraphs of a graph that can be handled by TensorRT. -// // @param graph tensorflow::Graph of the network // @param candidate_fn A function that returns true for a Node* if // that node can be handled by TensorRT. @@ -66,8 +50,10 @@ tensorflow::Status SegmentGraph( // all the NodeDefs in that subgraph. // @return the status. tensorflow::Status SegmentGraph( - tensorflow::Graph* tf_graph, + const tensorflow::Graph* tf_graph, const std::function<bool(const tensorflow::Node*)>& candidate_fn, + const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn, + const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); } // namespace segment |