aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert/convert_nodes.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_nodes.h')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h53
1 files changed, 9 insertions, 44 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 954a1e72f8..2e7fd19566 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -17,8 +17,6 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
#include <set>
-#include <string>
-#include <unordered_map>
#include <utility>
#include <vector>
@@ -34,49 +32,16 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
-const int FP32MODE = 0;
-const int FP16MODE = 1;
-const int INT8MODE = 2;
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+ const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+ const std::vector<std::pair<int, int>>&
+ input_inds, // {node_id, output_idx}
+ const std::vector<std::pair<int, int>>&
+ output_inds, // {node_id, output_idx}
+ size_t max_batch_size, size_t max_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& graph_prop,
+ tensorflow::NodeDef* trt_node);
-struct SubGraphParams {
- SubGraphParams(
- tensorflow::Graph& inp_graph,
- const std::set<int>& subgraph_node_id_numbers,
- const std::vector<std::pair<int, int>>& input_indices,
- const std::vector<std::pair<int, int>>& output_indices,
- size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
- const tensorflow::grappler::GraphProperties& current_graph_properties,
- std::unordered_map<string, std::pair<int, string>>* output_edges,
- tensorflow::NodeDef* constructed_trt_node,
- int engine_precision_mode = FP32MODE)
- : graph(inp_graph),
- subgraph_node_ids(subgraph_node_id_numbers),
- input_inds(input_indices),
- output_inds(output_indices),
- max_batch_size(max_supported_batch_size),
- max_workspace_size_bytes(max_consumed_workspace_size_bytes),
- graph_properties(current_graph_properties),
- output_edge_map(output_edges),
- trt_node(constructed_trt_node),
- precision_mode(engine_precision_mode) {}
-
- tensorflow::Graph& graph;
- const std::set<int>& subgraph_node_ids;
- const std::vector<std::pair<int, int>>& input_inds; // {node_id, output_idx}
- const std::vector<std::pair<int, int>>& output_inds; // {node_id, output_idx}
- size_t max_batch_size;
- size_t max_workspace_size_bytes;
- const tensorflow::grappler::GraphProperties& graph_properties;
- std::unordered_map<string, std::pair<int, string>>* output_edge_map;
- tensorflow::NodeDef* trt_node;
- const int precision_mode;
-};
-
-// TODO(sami): Replace references with const reference or pointers
-tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params);
-tensorflow::Status InjectCalibrationNode(SubGraphParams& params);
-tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
- tensorflow::Node* c_node);
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow