aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert/convert_nodes.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.h')
-rw-r--r-- tensorflow/contrib/tensorrt/convert/convert_nodes.h | 53
1 files changed, 44 insertions, 9 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 2e7fd19566..954a1e72f8 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -17,6 +17,8 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#include <set>
+#include <string>
+#include <unordered_map>
#include <utility>
#include <vector>
@@ -32,16 +34,49 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
-tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
- const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
- const std::vector<std::pair<int, int>>&
- input_inds, // {node_id, output_idx}
- const std::vector<std::pair<int, int>>&
- output_inds, // {node_id, output_idx}
- size_t max_batch_size, size_t max_workspace_size_bytes,
- const tensorflow::grappler::GraphProperties& graph_prop,
- tensorflow::NodeDef* trt_node);
+const int FP32MODE = 0;  // Default value of SubGraphParams' engine_precision_mode argument; presumably 32-bit float precision.
+const int FP16MODE = 1;  // Presumably selects 16-bit float precision -- confirm against the implementation.
+const int INT8MODE = 2;  // Presumably selects 8-bit integer precision -- confirm against the implementation.
+struct SubGraphParams {  // Argument bundle taken by ConvertSubGraphToTensorRTNodeDef and InjectCalibrationNode.
+ SubGraphParams(  // Only binds/copies each argument into the matching member; the body is empty.
+ tensorflow::Graph& inp_graph,
+ const std::set<int>& subgraph_node_id_numbers,
+ const std::vector<std::pair<int, int>>& input_indices,
+ const std::vector<std::pair<int, int>>& output_indices,
+ size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& current_graph_properties,
+ std::unordered_map<string, std::pair<int, string>>* output_edges,
+ tensorflow::NodeDef* constructed_trt_node,
+ int engine_precision_mode = FP32MODE)  // Plain int, not range-checked here; expected to be FP32MODE, FP16MODE or INT8MODE.
+ : graph(inp_graph),
+ subgraph_node_ids(subgraph_node_id_numbers),
+ input_inds(input_indices),
+ output_inds(output_indices),
+ max_batch_size(max_supported_batch_size),
+ max_workspace_size_bytes(max_consumed_workspace_size_bytes),
+ graph_properties(current_graph_properties),
+ output_edge_map(output_edges),
+ trt_node(constructed_trt_node),
+ precision_mode(engine_precision_mode) {}
+
+ tensorflow::Graph& graph;  // Non-const reference to the caller's graph; not copied, so it must outlive this struct.
+ const std::set<int>& subgraph_node_ids;  // Reference member: the caller's set must outlive this struct.
+ const std::vector<std::pair<int, int>>& input_inds; // {node_id, output_idx}
+ const std::vector<std::pair<int, int>>& output_inds; // {node_id, output_idx}
+ size_t max_batch_size;  // Stored by value.
+ size_t max_workspace_size_bytes;  // Stored by value.
+ const tensorflow::grappler::GraphProperties& graph_properties;  // Reference member: must outlive this struct.
+ std::unordered_map<string, std::pair<int, string>>* output_edge_map;  // Raw pointer stored as given; key/value meaning and ownership are not visible here -- TODO confirm.
+ tensorflow::NodeDef* trt_node;  // Raw pointer stored as given; presumably the NodeDef the conversion fills in -- confirm.
+ const int precision_mode;  // Fixed at construction; with the reference members, this makes the struct non-assignable.
+};
+
+// TODO(sami): Replace references with const reference or pointers
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params);  // Declaration only; all arguments travel in `params`, which is taken by non-const reference.
+tensorflow::Status InjectCalibrationNode(SubGraphParams& params);  // Declaration only; same bundle. Presumably used for INT8 calibration -- confirm in the .cc.
+tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
+ tensorflow::Node* c_node);  // Declaration only; `c_node` is presumably a calibration node within `graph` -- confirm.
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow