aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment/segment.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment.h')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.h7
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 1568dd9153..81b4bfe49f 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -29,8 +29,9 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
-// vector of segments, each entry contains a device name and a set of nodes in
-// segment
+// Vector of segments, each entry contains a set of node names and a device name
+// in the segment.
+// TODO(aaroey): use node pointer instead of node name.
using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
struct SegmentOptions {
@@ -48,6 +49,8 @@ struct SegmentOptions {
// in the vector describes a subgraph by giving a set of the names of
// all the NodeDefs in that subgraph.
// @return the status.
+//
+// TODO(aaroey): remove this method.
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,