diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.h')
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.h | 53 |
1 files changed, 9 insertions, 44 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 954a1e72f8..2e7fd19566 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -17,8 +17,6 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ #include <set> -#include <string> -#include <unordered_map> #include <utility> #include <vector> @@ -34,49 +32,16 @@ namespace tensorflow { namespace tensorrt { namespace convert { -const int FP32MODE = 0; -const int FP16MODE = 1; -const int INT8MODE = 2; +tensorflow::Status ConvertSubGraphToTensorRTNodeDef( + const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids, + const std::vector<std::pair<int, int>>& + input_inds, // {node_id, output_idx} + const std::vector<std::pair<int, int>>& + output_inds, // {node_id, output_idx} + size_t max_batch_size, size_t max_workspace_size_bytes, + const tensorflow::grappler::GraphProperties& graph_prop, + tensorflow::NodeDef* trt_node); -struct SubGraphParams { - SubGraphParams( - tensorflow::Graph& inp_graph, - const std::set<int>& subgraph_node_id_numbers, - const std::vector<std::pair<int, int>>& input_indices, - const std::vector<std::pair<int, int>>& output_indices, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map<string, std::pair<int, string>>* output_edges, - tensorflow::NodeDef* constructed_trt_node, - int engine_precision_mode = FP32MODE) - : graph(inp_graph), - subgraph_node_ids(subgraph_node_id_numbers), - input_inds(input_indices), - output_inds(output_indices), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - trt_node(constructed_trt_node), - precision_mode(engine_precision_mode) {} - - tensorflow::Graph& graph; - const std::set<int>& subgraph_node_ids; - const std::vector<std::pair<int, int>>& input_inds; // {node_id, output_idx} - const std::vector<std::pair<int, int>>& output_inds; // {node_id, output_idx} - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map<string, std::pair<int, string>>* output_edge_map; - tensorflow::NodeDef* trt_node; - const int precision_mode; -}; - -// TODO(sami): Replace references with const reference or pointers -tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params); -tensorflow::Status InjectCalibrationNode(SubGraphParams& params); -tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph, - tensorflow::Node* c_node); } // namespace convert } // namespace tensorrt } // namespace tensorflow |