author     gracehoney <31743510+aaroey@users.noreply.github.com>  2018-01-30 12:00:23 -0800
committer  gracehoney <31743510+aaroey@users.noreply.github.com>  2018-01-30 12:00:23 -0800
commit     d7b4fe4d4322a3fdab8a1dedb93d37a1f800a559 (patch)
tree       4f90245cbf13ee5de46080752f70e761f9545c4e
parent     864d477a9923b1514f3cedb9bcebe45e65227663 (diff)
Fix the build dependencies and formatting of the code, and make sure
they follow the style conventions.
-rw-r--r--  tensorflow/contrib/BUILD                                    2
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                          67
-rw-r--r--  tensorflow/contrib/tensorrt/README.md                      12
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc       41
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.h         8
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc       52
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h        10
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc       23
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.h        14
-rw-r--r--  tensorflow/contrib/tensorrt/log/trt_logger.cc               6
-rw-r--r--  tensorflow/contrib/tensorrt/log/trt_logger.h                3
-rw-r--r--  tensorflow/contrib/tensorrt/ops/trt_engine_op.cc            6
-rw-r--r--  tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py     2
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert.py          31
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.cc              1
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc           51
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h             7
-rw-r--r--  tensorflow/contrib/tensorrt/trt_conversion.i              122
18 files changed, 236 insertions(+), 222 deletions(-)
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index f745c175b1..738b74c929 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -105,7 +105,7 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
- ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"])
+ ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"])
+ if_tensorrt(["//tensorflow/contrib/tensorrt:init_py"]),
)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 2d4bccd267..3a214d2e86 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -17,7 +17,6 @@ load(
"tf_gen_op_wrapper_py",
"tf_py_wrap_cc",
"tf_cc_test",
- "tf_kernel_library",
"tf_cuda_cc_test",
"tf_cuda_library",
"tf_custom_op_py_library",
@@ -47,13 +46,9 @@ tf_cuda_cc_test(
tf_custom_op_library(
name = "python/ops/_trt_engine_op.so",
- srcs = [
- "kernels/trt_engine_op.cc",
- "kernels/trt_engine_op.h",
- "ops/trt_engine_op.cc",
- ],
- gpu_srcs = [],
+ srcs = ["ops/trt_engine_op.cc"],
deps = [
+ ":trt_engine_op_kernel",
":trt_shape_function",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core/kernels:bounds_check_lib",
@@ -64,10 +59,9 @@ tf_custom_op_library(
tf_cuda_library(
name = "trt_shape_function",
- srcs = [
- "shape_fn/trt_shfn.cc",
- ],
+ srcs = ["shape_fn/trt_shfn.cc"],
hdrs = ["shape_fn/trt_shfn.h"],
+ visibility = ["//visibility:public"],
deps = [
":trt_logging",
"//tensorflow/core:framework_headers_lib",
@@ -76,36 +70,26 @@ tf_cuda_library(
"@nsync//:nsync_headers",
"@protobuf_archive//:protobuf",
],
- visibility = ["//visibility:public"], #for c/c++ linking
)
-tf_kernel_library(
+cc_library(
name = "trt_engine_op_kernel",
- srcs = [
- "kernels/trt_engine_op.cc",
- ],
- hdrs = [
- "kernels/trt_engine_op.h",
- ],
- gpu_srcs = [
- ],
+ srcs = ["kernels/trt_engine_op.cc"],
+ hdrs = ["kernels/trt_engine_op.h"],
deps = [
":trt_logging",
- ":trt_shape_function",
- "//tensorflow/core:framework",
+ "//tensorflow/core:framework_headers_lib",
"//tensorflow/core:gpu_headers_lib",
- "//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"//third_party/eigen3",
"@local_config_tensorrt//:nv_infer",
+ "@nsync//:nsync_headers",
],
alwayslink = 1,
)
tf_gen_op_libs(
- op_lib_names = [
- "trt_engine_op",
- ],
+ op_lib_names = ["trt_engine_op"],
deps = [
"@local_config_tensorrt//:nv_infer",
],
@@ -113,12 +97,8 @@ tf_gen_op_libs(
tf_cuda_library(
name = "trt_logging",
- srcs = [
- "log/trt_logger.cc",
- ],
- hdrs = [
- "log/trt_logger.h",
- ],
+ srcs = ["log/trt_logger.cc"],
+ hdrs = ["log/trt_logger.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib_proto_parsing",
@@ -190,6 +170,7 @@ tf_py_wrap_cc(
],
)
+# Library for the node-level conversion portion of TensorRT operation creation
tf_cuda_library(
name = "trt_conversion",
srcs = [
@@ -201,29 +182,27 @@ tf_cuda_library(
"convert/convert_nodes.h",
],
deps = [
- "@local_config_tensorrt//:nv_infer",
- "@protobuf_archive//:protobuf_headers",
- "@nsync//:nsync_headers",
":segment",
":trt_logging",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework_lite",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:framework_headers_lib",
- "//tensorflow/core:core_cpu_base",
- "//tensorflow/core/grappler/optimizers:constant_folding",
- "//tensorflow/core/grappler/optimizers:layout_optimizer",
- "//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler:devices",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/optimizers:constant_folding",
+ "//tensorflow/core/grappler/optimizers:layout_optimizer",
+ "@local_config_tensorrt//:nv_infer",
+ "@nsync//:nsync_headers",
+ "@protobuf_archive//:protobuf_headers",
],
)
# Library for the segmenting portion of TensorRT operation creation
cc_library(
name = "segment",
- srcs = [
- "segment/segment.cc",
- ],
+ srcs = ["segment/segment.cc"],
hdrs = [
"segment/segment.h",
"segment/union_find.h",
@@ -249,8 +228,6 @@ tf_cc_test(
],
)
-# Library for the node-level conversion portion of TensorRT operation creation
-
filegroup(
name = "cppfiles",
srcs = glob(["**/*.cc"]),
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
index b362050983..b3c604f5f8 100644
--- a/tensorflow/contrib/tensorrt/README.md
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -28,12 +28,12 @@ will be available. An example use is shown below.
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
#... create and train or load model
-gdef=sess.graph.as_graph_def()
-trt_gdef=trt.CreateInferenceGraph(gdef, #original graph_def
- ["output"], #name of output node(s)
- max_batch_size, #maximum batch size to run the inference
- max_workspace_size # max memory for TensorRT to use
- )
+gdef = sess.graph.as_graph_def()
+trt_gdef = trt.CreateInferenceGraph(
+ gdef, #original graph_def
+ ["output"], #name of output node(s)
+ max_batch_size, #maximum batch size to run the inference
+ max_workspace_size) # max memory for TensorRT to use
tf.reset_default_graph()
tf.import_graph_def(graph_def=trt_gdef)
#...... run inference
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index b8dbc7b7c8..1507981ca8 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+
#include <list>
#include <map>
#include <set>
@@ -28,10 +30,6 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/logging.h"
-
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
@@ -39,11 +37,13 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
-#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorrt/include/NvInfer.h"
@@ -119,14 +119,15 @@ std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
}
tensorflow::Status ConvertSubGraphToTensorRT(
- tensorflow::Graph& graph, const std::vector<std::string>& output_names,
+ const std::vector<std::string>& output_names,
const std::set<int>& subgraph_node_ids,
size_t max_batch_size, // max batch size that engine will be created for
// max amount of memory that engine will be allowed to consume, in bytes
size_t max_workspace_size,
- const tensorflow::grappler::GraphProperties& graph_properties) {
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ tensorflow::Graph* graph) {
tensorflow::EdgeSet subgraph_incoming_edges;
- GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges);
+ GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
std::vector<std::pair<int, int>> subgraph_inputs;
@@ -138,7 +139,7 @@ tensorflow::Status ConvertSubGraphToTensorRT(
// Collect outputs referenced from output_names
auto output_name_to_index_map = BuildTensorNameMap(output_names);
for (int node_id : subgraph_node_ids) {
- tensorflow::Node* node = graph.FindNodeId(node_id);
+ tensorflow::Node* node = graph->FindNodeId(node_id);
if (output_name_to_index_map.count(node->name())) {
for (int index : output_name_to_index_map.at(node->name())) {
subgraph_outputs_set.insert({node_id, index});
@@ -147,7 +148,7 @@ tensorflow::Status ConvertSubGraphToTensorRT(
}
// Collect outputs referenced from outgoing edges
tensorflow::EdgeSet subgraph_outgoing_edges;
- GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges);
+ GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
}
@@ -157,10 +158,10 @@ tensorflow::Status ConvertSubGraphToTensorRT(
// Build TensorRT node and add it to the graph
tensorflow::NodeDef trt_node_def;
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
- graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
+ *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
max_batch_size, max_workspace_size, graph_properties, &trt_node_def));
tensorflow::Status status;
- tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status);
+ tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
TF_RETURN_IF_ERROR(status);
@@ -173,16 +174,16 @@ tensorflow::Status ConvertSubGraphToTensorRT(
for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
int new_src_output = subgraph_edge_to_output_map.at(old_src);
- graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
+ graph->UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
}
// Remove the original subgraph
for (int node_id : subgraph_node_ids) {
- tensorflow::Node* node = graph.FindNodeId(node_id);
+ tensorflow::Node* node = graph->FindNodeId(node_id);
// Don't remove the input placeholders
if (node->type_string() == "Placeholder") {
continue;
}
- graph.RemoveNode(node);
+ graph->RemoveNode(node);
}
return tensorflow::Status::OK();
}
@@ -213,16 +214,16 @@ tensorflow::Status ConvertGraphDefToTensorRT(
// layout optimization
item.graph = graph_def;
tensorflow::grappler::LayoutOptimizer optimizer;
- tensorflow::grappler::Cluster* gCluster;
+ tensorflow::grappler::Cluster* cluster;
// virtual cluster
tensorflow::DeviceProperties device_properties;
device_properties.set_type("GPU");
device_properties.mutable_environment()->insert({"architecture", "6"});
- gCluster =
+ cluster =
new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
- tensorflow::Status status = optimizer.Optimize(gCluster, item, &gdef);
+ tensorflow::Status status = optimizer.Optimize(cluster, item, &gdef);
if (status != tensorflow::Status::OK()) return status;
@@ -267,8 +268,8 @@ tensorflow::Status ConvertGraphDefToTensorRT(
subgraph_node_ids.insert(node_map.at(node_name)->id());
}
TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
- graph, output_names, subgraph_node_ids, max_batch_size,
- max_workspace_size, static_graph_properties));
+ output_names, subgraph_node_ids, max_batch_size, max_workspace_size,
+ static_graph_properties, &graph));
}
graph.ToGraphDef(new_graph_def);
return tensorflow::Status::OK();
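A pattern worth noting in the convert_graph.cc hunks above: ConvertSubGraphToTensorRT previously took the graph as a mutable reference (tensorflow::Graph&) and now takes it as a pointer (tensorflow::Graph*) at the end of the parameter list. This follows the Google C++ style convention that read-only inputs are passed by const reference and mutated outputs by pointer, so mutation is visible at the call site. A minimal sketch of that convention, using stand-in types rather than TensorFlow's real Graph:

#include <set>
#include <string>
#include <vector>

// Stand-in for tensorflow::Graph, for illustration only.
struct Graph {
  std::vector<std::string> nodes;
};

// Read-only inputs by const reference, the mutated graph by pointer.
void ReplaceSubgraph(const std::set<int>& subgraph_node_ids, Graph* graph) {
  (void)subgraph_node_ids;                // node-selection logic elided
  graph->nodes.push_back("TRTEngineOp");  // rewrite *graph in place
}

int main() {
  Graph g;
  ReplaceSubgraph({1, 2, 3}, &g);  // &g makes the mutation explicit
}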
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 621d428ace..e0fa02ecd4 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -21,6 +21,9 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
namespace tensorflow {
namespace tensorrt {
namespace convert {
@@ -32,7 +35,12 @@ tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<std::string>& output_names, size_t max_batch_size,
size_t max_workspace_size, tensorflow::GraphDef* new_graph_def);
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 60e6a1ab96..a42f559651 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+
#include <algorithm>
-#include <fstream>
#include <list>
#include <map>
#include <memory>
#include <set>
-#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
@@ -36,16 +36,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
-// Check if the types are equal. Cast to int first so that failure log message
-// would work!
-
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
-
-#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorrt/include/NvInfer.h"
+// Check if the types are equal. Cast to int first so that failure log message
+// would work!
#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
namespace tensorflow {
@@ -107,8 +104,8 @@ static std::vector<std::pair<int, int>> createSamePadding(
int right = p - left;
VLOG(-1) << "PADDING_" << i << " pre: " << left << ", post: " << right
- << "paras: " << inputDims[i] << ", " << stride.d[i] << ", "
- << "kernel: " << kernel.d[i];
+ << "paras: " << inputDims[i] << ", " << stride.d[i] << ", "
+ << "kernel: " << kernel.d[i];
padding[i] = {left, right};
}
return padding;
@@ -664,7 +661,7 @@ tensorflow::Status ConstantFoldBinary(
nvinfer1::Dims output_shape;
output_shape.nbDims = nbDims;
VLOG(-1) << "nbDims: " << nbDims
- << "the other: " << weights_input_r.shape_.nbDims;
+ << "the other: " << weights_input_r.shape_.nbDims;
for (int i = 0; i < nbDims; i++) {
if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
output_shape.d[i] = weights_input_l.shape_.d[i];
@@ -677,8 +674,8 @@ tensorflow::Status ConstantFoldBinary(
"Binary op with incompatible shape at, " + node_def.op());
}
VLOG(-1) << "left: " << weights_input_l.shape_.d[i]
- << "right: " << weights_input_r.shape_.d[i]
- << "output: " << output_shape.d[i];
+ << "right: " << weights_input_r.shape_.d[i]
+ << "output: " << output_shape.d[i];
}
// FIXME assume type matches input weights
@@ -822,9 +819,9 @@ tensorflow::Status BinaryTensorOpTensor(
CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end())
- return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
- " not supported at: " +
- node_def.name());
+ return tensorflow::errors::Unimplemented(
+ "binary op: " + node_def.op() +
+ " not supported at: " + node_def.name());
nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
*const_cast<nvinfer1::ITensor*>(tensor_l),
@@ -909,11 +906,11 @@ tensorflow::Status ConvertConv2D(Converter& ctx,
padding[1].first != padding[1].second) {
// TODO(jie): handle asymmetric padding
VLOG(-1) << "padding!!!: " << padding[0].first << padding[0].second
- << padding[1].first << padding[1].second;
+ << padding[1].first << padding[1].second;
auto dim_before = tensor->getDimensions();
- VLOG(-1) << "TENSOR before: " << dim_before.d[0] << ", "
- << dim_before.d[1] << dim_before.d[2] << ", " << dim_before.d[3];
+ VLOG(-1) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
+ << dim_before.d[2] << ", " << dim_before.d[3];
auto padLayer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
@@ -922,7 +919,7 @@ tensorflow::Status ConvertConv2D(Converter& ctx,
tensor = padLayer->getOutput(0);
auto dim_after = tensor->getDimensions();
VLOG(-1) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
- << dim_after.d[2] << ", " << dim_after.d[3];
+ << dim_after.d[2] << ", " << dim_after.d[3];
}
nvinfer1::IConvolutionLayer* layer =
@@ -936,7 +933,7 @@ tensorflow::Status ConvertConv2D(Converter& ctx,
auto dim_after = output_tensor->getDimensions();
VLOG(-1) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1]
- << dim_after.d[2] << ", " << dim_after.d[3];
+ << dim_after.d[2] << ", " << dim_after.d[3];
if (data_format == "NHWC") {
// TODO(jie): transpose it back!
@@ -992,8 +989,7 @@ tensorflow::Status ConvertPool(Converter& ctx,
{static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
} else if (attrs.get<std::string>("padding") == "VALID") {
// No padding for valid padding here
- VLOG(-1) << "no padding added for VALID padding in pool"
- << node_def.name();
+ VLOG(-1) << "no padding added for VALID padding in pool" << node_def.name();
padding = {{0, 0}, {0, 0}};
} else {
return tensorflow::errors::Unimplemented(
@@ -1004,7 +1000,7 @@ tensorflow::Status ConvertPool(Converter& ctx,
padding[1].first != padding[1].second) {
// TODO(jie): handle asymmetric padding
VLOG(-1) << "padding!!!: " << padding[0].first << padding[0].second
- << padding[1].first << padding[1].second;
+ << padding[1].first << padding[1].second;
auto padLayer = ctx.network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::DimsHW(padding[0].first, padding[1].first),
@@ -1480,9 +1476,9 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
TF_CHECK_OK(convert_dtype(tf_dtype, &dtype));
VLOG(-1) << "accessing output index of: " << std::to_string(output_idx)
- << ", at node: " << node_name
- << "with output entry from shape_map: "
- << std::to_string(op_info_vec.size());
+ << ", at node: " << node_name
+ << "with output entry from shape_map: "
+ << std::to_string(op_info_vec.size());
// TODO(ben,jie): update TRT input format/dimension
nvinfer1::DimsCHW input_dim_psuedo_chw;
@@ -1490,7 +1486,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
for (int i = 1; i < op_info.shape().dim_size(); i++) {
VLOG(-1) << "dimension: " << i
- << " , size: " << op_info.shape().dim(i).size();
+ << " , size: " << op_info.shape().dim(i).size();
input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size();
}
@@ -1517,7 +1513,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
for (const tensorflow::Node* node : order) {
tensorflow::NodeDef const& node_def = node->def();
VLOG(-1) << "converting node: " << node_def.name() << " , "
- << node_def.op();
+ << node_def.op();
TF_RETURN_IF_ERROR(converter.convert_node(node_def));
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index a1f9c3f4a1..69657e0cb9 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -25,6 +25,9 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
namespace tensorflow {
namespace tensorrt {
namespace convert {
@@ -35,12 +38,15 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
input_inds, // {node_id, output_idx}
const std::vector<std::pair<int, int>>&
output_inds, // {node_id, output_idx}
- size_t max_batch_size,
- size_t max_workspace_size,
+ size_t max_batch_size, size_t max_workspace_size,
const tensorflow::grappler::GraphProperties& graph_prop,
tensorflow::NodeDef* trt_node);
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index dc8b625731..6e4fbe20e7 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -12,21 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <sstream>
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
-
-#include <cuda_runtime_api.h>
-#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
+#include "cuda/include/cuda_runtime_api.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
-static ::tensorflow::tensorrt::Logger gLogger;
-
namespace tensorrt {
+static ::tensorflow::tensorrt::Logger logger;
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// read serialized_engine
@@ -39,14 +35,14 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
// TODO(samikama) runtime should be taken from a resourcemanager as well.
- // Only engine should be in the op and context and runtime should be taken
- // from resourcemanager
- nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
+ // Only engine should be in the op and context and runtime should be taken
+ // from resourcemanager
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr));
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
- // runtime is safe to delete after engine creation
+ // Runtime is safe to delete after engine creation
infer->destroy();
}
@@ -89,7 +85,7 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
// This is bad that we have to reallocate output buffer every run.
// Create an output tensor
binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
- Tensor* output_tensor = NULL;
+ Tensor* output_tensor = nullptr;
TensorShape output_shape;
if (binding_index != -1) {
@@ -131,6 +127,7 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
+
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 0e3ff45ede..0964b4b18a 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -20,18 +20,17 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
-#include <cuda_runtime_api.h>
+#include "cuda/include/cuda_runtime_api.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
-
namespace tensorrt {
class Logger;
+
class TRTEngineOp : public OpKernel {
public:
explicit TRTEngineOp(OpKernelConstruction* context);
@@ -43,17 +42,18 @@ class TRTEngineOp : public OpKernel {
struct Destroyer {
void operator()(T* d) { d->destroy(); }
};
+
template <typename T>
using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
- // TODO(samikama) context should go to a resource manager!
+ // TODO(samikama): context should go to a resource manager!
destroyed_ptr<nvinfer1::IExecutionContext> trt_execution_context_ptr_;
+
std::vector<string> input_nodes_;
std::vector<string> output_nodes_;
};
} // namespace tensorrt
-
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
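The Destroyer/destroyed_ptr members touched in this header implement a useful idiom: TensorRT objects such as nvinfer1::ICudaEngine are released through destroy() rather than delete, so a custom deleter lets std::unique_ptr manage their lifetime. A minimal sketch with a stand-in engine type, not the real TensorRT interface:

#include <memory>

// Stand-in for a TensorRT object: released via destroy(), not delete.
struct FakeEngine {
  void destroy() { delete this; }
 private:
  ~FakeEngine() = default;  // not directly deletable, like TRT interfaces
};

template <typename T>
struct Destroyer {
  void operator()(T* d) { d->destroy(); }
};

template <typename T>
using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;

int main() {
  destroyed_ptr<FakeEngine> engine(new FakeEngine);
  // engine->destroy() runs automatically when engine goes out of scope.
}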
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc
index 2473b8effc..5131c80794 100644
--- a/tensorflow/contrib/tensorrt/log/trt_logger.cc
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc
@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
-
-#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/core/platform/logging.h"
-// Use TF logging for TensorRT information
namespace tensorflow {
namespace tensorrt {
+// Use TF logging for TensorRT information
void Logger::log(Severity severity, const char* msg) {
// Suppress info-level messages
switch (severity) {
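For context, the body of Logger::log (truncated in the hunk above) forwards TensorRT messages to TF logging and, per the comment, suppresses info-level messages. A rough sketch of that dispatch, using a stand-in severity enum instead of the real nvinfer1::ILogger::Severity:

#include <cstdio>

// Stand-in for nvinfer1::ILogger::Severity, for illustration only.
enum class Severity { kINTERNAL_ERROR, kERROR, kWARNING, kINFO };

void Log(Severity severity, const char* msg) {
  switch (severity) {
    case Severity::kINFO:
      break;  // suppress info-level messages
    case Severity::kWARNING:
      std::fprintf(stderr, "WARNING: %s\n", msg);
      break;
    default:  // kERROR and kINTERNAL_ERROR
      std::fprintf(stderr, "ERROR: %s\n", msg);
  }
}

int main() {
  Log(Severity::kINFO, "dropped");
  Log(Severity::kWARNING, "kept");
}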
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h
index c07a3e6b2d..0dc2b1708b 100644
--- a/tensorflow/contrib/tensorrt/log/trt_logger.h
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.h
@@ -1,4 +1,3 @@
-// -*- c++ -*-
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,9 +27,9 @@ namespace tensorrt {
// Logger for GIE info/warning/errors
class Logger : public nvinfer1::ILogger {
+ private:
void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
- private:
std::string name_;
};
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
index 7139ff9618..fa72bce039 100644
--- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -34,4 +37,7 @@ REGISTER_OP("TRTEngineOp")
.Output("out_tensor: OutT")
.SetShapeFn(shape_inference::TRTEngineOpShapeInference);
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
index c4ab9b89ea..97db23797f 100644
--- a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
+++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
@@ -31,5 +31,3 @@ if platform.system() != "Windows":
resource_loader.get_path_to_datafile("_trt_engine_op.so"))
else:
raise RuntimeError("Windows platforms are not supported")
-
-
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index f6d2dbede6..6bdc20ed04 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -29,9 +29,13 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+
# TODO(skama): get outputs from session when implemented as c++
# optimization pass
-def CreateInferenceGraph(input_graph_def, outputs,max_batch_size=1,max_workspace_size=2<<20):
+def CreateInferenceGraph(input_graph_def,
+ outputs,
+ max_batch_size=1,
+ max_workspace_size=2 << 20):
"""Python wrapper for the TRT transormation.
@@ -45,35 +49,34 @@ def CreateInferenceGraph(input_graph_def, outputs,max_batch_size=1,max_workspace
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
"""
- out_names=[]
+ out_names = []
for i in outputs:
- if isinstance(i,ops.Tensor):
+ if isinstance(i, ops.Tensor):
out_names.append(i.name)
else:
out_names.append(i)
-
- input_graph_def_str= \
- input_graph_def.SerializeToString()
+
+ input_graph_def_str = input_graph_def.SerializeToString()
# TODO(sami): Fix this when we can return status from C++ library
# There is a problem with the TF internal library setup that doesn't
# allow us to return a status object from C++. Thus we return a
# pair or strings where first one is encoded status and the second
# one is the transformed graphs protobuf string.
- out = trt_convert(
- input_graph_def_str ,outputs,
- max_batch_size,max_workspace_size)
+ out = trt_convert(input_graph_def_str, outputs, max_batch_size,
+ max_workspace_size)
status = out[0]
output_graph_def_string = out[1]
- del input_graph_def_str #save some memory
+ del input_graph_def_str #save some memory
if len(status) < 2:
- raise _impl.UnknownError(None,None,status)
+ raise _impl.UnknownError(None, None, status)
if status[:2] != "OK":
- msg=status.split(";")
+ msg = status.split(";")
if len(msg) == 1:
raise RuntimeError("Status message is malformed {}".format(status))
- raise _impl._make_specific_exception(None,None,";".join(msg[1:]), int(msg[0]))
+ raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+ int(msg[0]))
output_graph_def = graph_pb2.GraphDef()
output_graph_def.ParseFromString(output_graph_def_string)
- del output_graph_def_string #save some memory
+ del output_graph_def_string #save some memory
return output_graph_def
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 89457b71e8..c9d3840606 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -68,7 +68,6 @@ bool CanContractEdge(const tensorflow::Edge* edge,
return !is_cycle;
}
-//------------------------------------------------------------------------------
void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
std::vector<const tensorflow::Edge*>* remove_edges) {
// Transfer all inputs and outputs of 'dst' to 'src' except edges
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index fef63c64d8..ebaf996a29 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -13,71 +13,74 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
+
#include <string>
#include <vector>
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
-#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace shape_inference {
-tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
+tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
tensorflow::tensorrt::Logger logger;
string serialized_engine;
- c->GetAttr("serialized_engine", &serialized_engine);
+ context->GetAttr("serialized_engine", &serialized_engine);
nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr);
- int nbBatch = -1;
- // debug print out input arrays
+ int num_batch = -1;
std::vector<::tensorflow::DataType> input_type;
- c->GetAttr("InT", &input_type);
- for (size_t i = 0; i < c->num_inputs(); i++) {
- // check if input shape is legit
- auto input_shape = c->input(i);
- for (int j = 0; j < c->Rank(input_shape); j++) {
- auto dimHandler = c->Dim(input_shape, j);
+ context->GetAttr("InT", &input_type);
+ for (size_t i = 0; i < context->num_inputs(); i++) {
+ // Check if input shape is legit
+ auto input_shape = context->input(i);
+ for (int j = 0; j < context->Rank(input_shape); j++) {
+ auto dim_handler = context->Dim(input_shape, j);
if (j == 0) {
- if (i == 0)
- nbBatch = c->Value(dimHandler);
- else if (nbBatch != c->Value(dimHandler))
+ if (i == 0) {
+ num_batch = context->Value(dim_handler);
+ } else if (num_batch != context->Value(dim_handler)) {
// TODO(jie): TensorRT engine requires consistent batch between inputs
// tensors. Segmenter should be aware of this.
LOG(FATAL) << "TensorRT engine requires consistent batch size";
+ }
}
}
}
- // arrange input here
+ // Arrange input here
std::vector<string> input_nodes;
- c->GetAttr("input_nodes", &input_nodes);
+ context->GetAttr("input_nodes", &input_nodes);
- // arrange output here
+ // Arrange output here
std::vector<string> output_nodes;
- c->GetAttr("output_nodes", &output_nodes);
+ context->GetAttr("output_nodes", &output_nodes);
for (size_t i = 0; i < output_nodes.size(); i++) {
int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str());
ShapeHandle output_shape;
- std::vector<DimensionHandle> vecDim;
- vecDim.emplace_back(c->MakeDim(nbBatch));
+ std::vector<DimensionHandle> dim_vec;
+ dim_vec.emplace_back(context->MakeDim(num_batch));
if (binding_index != -1) {
auto dims = trt_engine->getBindingDimensions(binding_index);
- for (int j = 0; j < dims.nbDims; j++)
- vecDim.emplace_back(c->MakeDim(dims.d[j]));
+ for (int j = 0; j < dims.nbDims; j++) {
+ dim_vec.emplace_back(context->MakeDim(dims.d[j]));
+ }
} else {
LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i];
}
- output_shape = c->MakeShape(vecDim);
- c->set_output(i, output_shape);
+ output_shape = context->MakeShape(dim_vec);
+ context->set_output(i, output_shape);
}
return Status::OK();
}
+
} // namespace shape_inference
} // namespace tensorflow
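The shape function above builds each output shape from two sources: the batch dimension comes from the op's input tensors (num_batch), and the remaining dimensions come from the engine's binding for that output. A minimal sketch of that assembly with stand-in types, not the real InferenceContext or ICudaEngine:

#include <vector>

// Stand-in for nvinfer1::Dims, for illustration only.
struct Dims {
  int nbDims;
  int d[8];
};

std::vector<int> BuildOutputShape(int num_batch, const Dims& binding_dims) {
  std::vector<int> shape;
  shape.push_back(num_batch);  // batch dim taken from the input tensors
  for (int j = 0; j < binding_dims.nbDims; j++) {
    shape.push_back(binding_dims.d[j]);  // remaining dims from the engine
  }
  return shape;
}

int main() {
  Dims chw{3, {64, 14, 14}};
  std::vector<int> shape = BuildOutputShape(/*num_batch=*/8, chw);
  // shape is now {8, 64, 14, 14}
}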
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
index f09b261139..9ca4ad0d55 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
@@ -16,8 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
-#include "tensorflow/core/framework/shape_inference.h"
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace shape_inference {
@@ -25,4 +27,7 @@ Status TRTEngineOpShapeInference(InferenceContext* c);
} // namespace shape_inference
} // namespace tensorflow
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 38cdabdff0..a7c7e5bc9f 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -1,8 +1,19 @@
-/*
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- wrap trt_conversion
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- */
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* wrap trt_conversion */
%{
#define SWIG_FILE_WITH_INIT
%}
@@ -25,60 +36,65 @@
%unignore trt_convert;
%{
- std::pair<string,string> trt_convert(string graph_def_string,//const tensorflow::GraphDef&
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size
- // unfortunately we can't use TF_Status here since it
- // is in c/c_api and brings in a lot of other libraries
- // which in turn declare ops. These ops are included
- // statically in our library and cause an abort when
- // module is loaded due to double registration
- // until Tensorflow properly exposes these headers
- // we have to work around this by returning a string
- // and converting it to exception on python side.
- //,TF_Status* out_status) {
- ) {
- string out_status;
+std::pair<string, string> trt_convert(
+ string graph_def_string, // The serialized GraphDef string.
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size
+ // Unfortunately we can't use TF_Status here since it
+ // is in c/c_api and brings in a lot of other libraries
+ // which in turn declare ops. These ops are included
+ // statically in our library and cause an abort when
+ // module is loaded due to double registration
+ // until Tensorflow properly exposes these headers
+ // we have to work around this by returning a string
+ // and converting it to exception on python side.
+ //,TF_Status* out_status) {
+) {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ string out_status;
- tensorflow::GraphDef graph_def;
- if (!graph_def.ParseFromString(graph_def_string)) {
- out_status="InvalidArgument;Couldn't interpret input as a GraphDef";
- return std::pair<string,string>{out_status,""};
- }
+ tensorflow::GraphDef graph_def;
+ if (!graph_def.ParseFromString(graph_def_string)) {
+ out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
- if (!output_names.size()) {
- out_status="InvalidArgument;Size of the output_names vector is 0";
- return std::pair<string,string>{out_status,""};
- //return "";
- }
- tensorflow::GraphDef outGraph;
- tensorflow::Status conversion_status =
- tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(graph_def,
- output_names,
- max_batch_size,
- max_workspace_size,
- &outGraph);
- if (!conversion_status.ok()) {
- auto retCode=(int)conversion_status.code();
- char buff[2000];
- snprintf(buff,2000,"%d;%s",retCode,conversion_status.error_message().c_str());
- out_status=buff;
- return std::pair<string,string>{out_status,""};
- }
- string result;
- if (!outGraph.SerializeToString(&result)) {
- out_status="InvalidArgument;Couldn't serialize output as a GraphDef";
- return std::pair<string,string>{out_status,""};
- }
- out_status="OK;All good!";
- return std::pair<string,string>{out_status,result};
+ if (!output_names.size()) {
+ out_status = "InvalidArgument;Size of the output_names vector is 0";
+ return std::pair<string, string>{out_status, ""};
+ // return "";
+ }
+ tensorflow::GraphDef outGraph;
+ tensorflow::Status conversion_status =
+ tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
+ graph_def, output_names, max_batch_size, max_workspace_size,
+ &outGraph);
+ if (!conversion_status.ok()) {
+ auto retCode = (int)conversion_status.code();
+ char buff[2000];
+ snprintf(buff, 2000, "%d;%s", retCode,
+ conversion_status.error_message().c_str());
+ out_status = buff;
+ return std::pair<string, string>{out_status, ""};
+ }
+ string result;
+ if (!outGraph.SerializeToString(&result)) {
+ out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
}
+ out_status = "OK;All good!";
+ return std::pair<string, string>{out_status, result};
+#else
+ // Returns FAILED_PRECONDITION.
+ return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+}
%}
-std::pair<string,string> trt_convert(string graph_def_string,
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size);
+std::pair<string, string> trt_convert(string graph_def_string,
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size);
%unignoreall
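Finally, the SWIG wrapper above encodes status as a plain "<code>;<message>" string in the first element of the returned pair because, as the comment in the hunk explains, TF_Status cannot cross this boundary without pulling in op registrations. A minimal sketch of that convention with a hypothetical make_result() helper; the Python side (trt_convert.py) splits on ';' and checks for the "OK" prefix:

#include <cstdio>
#include <string>
#include <utility>

// Hypothetical helper illustrating the "<code>;<message>" status encoding.
std::pair<std::string, std::string> make_result(int code,
                                                const std::string& message,
                                                const std::string& payload) {
  if (code == 0) return {"OK;" + message, payload};  // success
  char buff[2000];
  snprintf(buff, sizeof(buff), "%d;%s", code, message.c_str());
  return {buff, ""};  // errors carry no graph payload
}

int main() {
  auto ok = make_result(0, "All good!", "serialized-graphdef-bytes");
  auto err = make_result(3, "Couldn't interpret input as a GraphDef", "");
  std::printf("%s\n%s\n", ok.first.c_str(), err.first.c_str());
}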