diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.h')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.h | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 1a4c0e755d..6ae60ec352 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -23,6 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/core/framework/graph.pb.h" @@ -104,6 +105,8 @@ struct EngineInfo { // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be // sorted in topological order. +// +// TODO(aaroey): add tests to validate these properties. tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, @@ -128,6 +131,30 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool* convert_successfully); +// Helper class for the segmenter to determine whether an input edge to the TRT +// segment is valid. +class InputEdgeValidator { + public: + InputEdgeValidator(const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} + + // Return true if the specified edge is eligible to be an input edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* in_edge) const; + + private: + const grappler::GraphProperties& graph_properties_; +}; + +// Helper class for the segmenter to determine whether an output edge from the +// TRT segment is valid. +class OutputEdgeValidator { + public: + // Return true if the specified edge is eligible to be an output edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* out_edge) const; +}; + } // namespace convert } // namespace tensorrt } // namespace tensorflow |