aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment/segment.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.h')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.h20
1 files changed, 3 insertions, 17 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 81b4bfe49f..8c44eb782a 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -42,22 +42,6 @@ struct SegmentOptions {
// Get the subgraphs of a graph that can be handled by TensorRT.
//
-// @param gdef The GraphDef describing the network
-// @param candidate_fn A function that returns true for a NodeDef if
-// that node can be handled by TensorRT.
-// @param segments Returns the TensorRT segments/subgraphs. Each entry
-// in the vector describes a subgraph by giving a set of the names of
-// all the NodeDefs in that subgraph.
-// @return the status.
-//
-// TODO(aaroey): remove this method.
-tensorflow::Status SegmentGraph(
- const tensorflow::GraphDef& gdef,
- const std::function<bool(const tensorflow::Node*)>& candidate_fn,
- const SegmentOptions& options, SegmentNodesVector* segments);
-
-// Get the subgraphs of a graph that can be handled by TensorRT.
-//
// @param graph tensorflow::Graph of the network
// @param candidate_fn A function that returns true for a Node* if
// that node can be handled by TensorRT.
@@ -66,8 +50,10 @@ tensorflow::Status SegmentGraph(
// all the NodeDefs in that subgraph.
// @return the status.
tensorflow::Status SegmentGraph(
- tensorflow::Graph* tf_graph,
+ const tensorflow::Graph* tf_graph,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& input_candidate_fn,
+ const std::function<bool(const tensorflow::Edge*)>& output_candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments);
} // namespace segment