diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.h')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.h | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 7684d8d4a2..6ae60ec352 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -23,6 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/core/framework/graph.pb.h" @@ -46,8 +47,8 @@ const int INT8MODE = 2; struct EngineConnection { EngineConnection(const string& outside, int out_id, int out_port, - const string& inside, int in_id, int in_port, - bool input_edge, int port) + const string& inside, int in_id, int in_port, + bool input_edge, int port) : outside_node_name(outside), outside_id(out_id), outside_port(out_port), @@ -104,6 +105,8 @@ struct EngineInfo { // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be // sorted in topological order. +// +// TODO(aaroey): add tests to validate these properties. tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, @@ -128,6 +131,30 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool* convert_successfully); +// Helper class for the segmenter to determine whether an input edge to the TRT +// segment is valid. +class InputEdgeValidator { + public: + InputEdgeValidator(const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} + + // Return true if the specified edge is eligible to be an input edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* in_edge) const; + + private: + const grappler::GraphProperties& graph_properties_; +}; + +// Helper class for the segmenter to determine whether an output edge from the +// TRT segment is valid. +class OutputEdgeValidator { + public: + // Return true if the specified edge is eligible to be an output edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* out_edge) const; +}; + } // namespace convert } // namespace tensorrt } // namespace tensorflow |