aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/trt_conversion.i
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/trt_conversion.i')
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i63
1 files changed, 59 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index d679945d56..46480e99a1 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -64,13 +64,17 @@ PyObject* pair_helper(std::pair<string, string>* in) {
%ignoreall
%unignore tensorflow;
%unignore trt_convert;
+%unignore calib_convert;
%{
+
std::pair<string, string> trt_convert(
string graph_def_string, // The serialized GraphDef string.
std::vector<string> output_names,
size_t max_batch_size,
- size_t max_workspace_size_bytes
+ size_t max_workspace_size_bytes,
+ int precision_mode,
+ int minimum_segment_size
// Unfortunately we can't use TF_Status here since it
// is in c/c_api and brings in a lot of other libraries
// which in turn declare ops. These ops are included
@@ -90,16 +94,64 @@ std::pair<string, string> trt_convert(
return std::pair<string, string>{out_status, ""};
}
+ if(precision_mode < 0 || precision_mode > 2){
+ out_status = "InvalidArgument;Invalid precision_mode";
+ return std::pair<string, string>{out_status, ""};
+ }
if (!output_names.size()) {
out_status = "InvalidArgument;Size of the output_names vector is 0";
return std::pair<string, string>{out_status, ""};
- // return "";
}
tensorflow::GraphDef outGraph;
tensorflow::Status conversion_status =
tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
graph_def, output_names, max_batch_size, max_workspace_size_bytes,
- &outGraph);
+ &outGraph, precision_mode, minimum_segment_size);
+ if (!conversion_status.ok()) {
+ auto retCode = (int)conversion_status.code();
+ char buff[2000];
+ snprintf(buff, 2000, "%d;%s", retCode,
+ conversion_status.error_message().c_str());
+ out_status = buff;
+ return std::pair<string, string>{out_status, ""};
+ }
+ string result;
+ if (!outGraph.SerializeToString(&result)) {
+ out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+ out_status = "OK;All good!";
+ return std::pair<string, string>{out_status, result};
+#else
+ // Returns FAILED_PRECONDITION.
+ return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+}
+
+std::pair<string, string> calib_convert(string graph_def_string // const tensorflow::GraphDef&
+ // unfortunately we can't use TF_Status here since it
+ // is in c/c_api and brings in a lot of other libraries
+ // which in turn declare ops. These ops are included
+ // statically in our library and cause an abort when
+ // module is loaded due to double registration
+ // until Tensorflow properly exposes these headers
+ // we have to work around this by returning a string
+ // and converting it to exception on python side.
+ //,TF_Status* out_status) {
+) {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ string out_status;
+
+ tensorflow::GraphDef graph_def;
+ if (!graph_def.ParseFromString(graph_def_string)) {
+ out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+
+ tensorflow::GraphDef outGraph;
+ tensorflow::Status conversion_status =
+ tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def,
+ &outGraph);
if (!conversion_status.ok()) {
auto retCode = (int)conversion_status.code();
char buff[2000];
@@ -122,10 +174,13 @@ std::pair<string, string> trt_convert(
}
%}
+std::pair<string, string> calib_convert(string graph_def_string);
+
std::pair<string, string> trt_convert(string graph_def_string,
std::vector<string> output_names,
size_t max_batch_size,
- size_t max_workspace_size_bytes);
+ size_t max_workspace_size_bytes,
+ int precision_mode, int minimum_segment_size);
%unignoreall