aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert/convert_nodes.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.h')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h31
1 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 7684d8d4a2..6ae60ec352 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -46,8 +47,8 @@ const int INT8MODE = 2;
struct EngineConnection {
EngineConnection(const string& outside, int out_id, int out_port,
- const string& inside, int in_id, int in_port,
- bool input_edge, int port)
+ const string& inside, int in_id, int in_port,
+ bool input_edge, int port)
: outside_node_name(outside),
outside_id(out_id),
outside_port(out_port),
@@ -104,6 +105,8 @@ struct EngineInfo {
// topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
// sorted in topological order.
+//
+// TODO(aaroey): add tests to validate these properties.
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
@@ -128,6 +131,30 @@ tensorflow::Status ConvertGraphDefToEngine(
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
bool* convert_successfully);
+// Helper class for the segmenter to determine whether an input edge to the TRT
+// segment is valid.
+class InputEdgeValidator {
+ public:
+ InputEdgeValidator(const grappler::GraphProperties& graph_properties)
+ : graph_properties_(graph_properties) {}
+
+ // Return true if the specified edge is eligible to be an input edge of the
+ // TRT segment.
+ bool operator()(const tensorflow::Edge* in_edge) const;
+
+ private:
+ const grappler::GraphProperties& graph_properties_;
+};
+
+// Helper class for the segmenter to determine whether an output edge from the
+// TRT segment is valid.
+class OutputEdgeValidator {
+ public:
+ // Return true if the specified edge is eligible to be an output edge of the
+ // TRT segment.
+ bool operator()(const tensorflow::Edge* out_edge) const;
+};
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow