diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/trt_conversion.i')
-rw-r--r-- | tensorflow/contrib/tensorrt/trt_conversion.i | 63 |
1 files changed, 59 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index d679945d56..46480e99a1 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -64,13 +64,17 @@ PyObject* pair_helper(std::pair<string, string>* in) { %ignoreall %unignore tensorflow; %unignore trt_convert; +%unignore calib_convert; %{ + std::pair<string, string> trt_convert( string graph_def_string, // The serialized GraphDef string. std::vector<string> output_names, size_t max_batch_size, - size_t max_workspace_size_bytes + size_t max_workspace_size_bytes, + int precision_mode, + int minimum_segment_size // Unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -90,16 +94,64 @@ std::pair<string, string> trt_convert( return std::pair<string, string>{out_status, ""}; } + if(precision_mode < 0 || precision_mode > 2){ + out_status = "InvalidArgument;Invalid precision_mode"; + return std::pair<string, string>{out_status, ""}; + } if (!output_names.size()) { out_status = "InvalidArgument;Size of the output_names vector is 0"; return std::pair<string, string>{out_status, ""}; - // return ""; } tensorflow::GraphDef outGraph; tensorflow::Status conversion_status = tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &outGraph); + &outGraph, precision_mode, minimum_segment_size); + if (!conversion_status.ok()) { + auto retCode = (int)conversion_status.code(); + char buff[2000]; + snprintf(buff, 2000, "%d;%s", retCode, + conversion_status.error_message().c_str()); + out_status = buff; + return std::pair<string, string>{out_status, ""}; + } + string result; + if (!outGraph.SerializeToString(&result)) { + out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; + return std::pair<string, string>{out_status, ""}; + } + out_status = "OK;All good!"; + return std::pair<string, string>{out_status, result}; +#else + // Returns FAILED_PRECONDITION. + return std::pair<string, string>{"9;TensorRT is not enabled!", ""}; +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +} + +std::pair<string, string> calib_convert(string graph_def_string // const tensorflow::GraphDef& + // unfortunately we can't use TF_Status here since it + // is in c/c_api and brings in a lot of other libraries + // which in turn declare ops. These ops are included + // statically in our library and cause an abort when + // module is loaded due to double registration + // until Tensorflow properly exposes these headers + // we have to work around this by returning a string + // and converting it to exception on python side. + //,TF_Status* out_status) { +) { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + string out_status; + + tensorflow::GraphDef graph_def; + if (!graph_def.ParseFromString(graph_def_string)) { + out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; + return std::pair<string, string>{out_status, ""}; + } + + tensorflow::GraphDef outGraph; + tensorflow::Status conversion_status = + tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def, + &outGraph); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -122,10 +174,13 @@ std::pair<string, string> trt_convert( } %} +std::pair<string, string> calib_convert(string graph_def_string); + std::pair<string, string> trt_convert(string graph_def_string, std::vector<string> output_names, size_t max_batch_size, - size_t max_workspace_size_bytes); + size_t max_workspace_size_bytes, + int precision_mode, int minimum_segment_size); %unignoreall |