diff options
author | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-06-26 12:11:10 -0700 |
---|---|---|
committer | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-06-26 12:11:10 -0700 |
commit | e77098094d401435f4d038d93724dd841e0469c1 (patch) | |
tree | a0d19b4b39839c3ad0b5090dfc36894b814a0c54 /tensorflow/contrib/tensorrt/trt_conversion.i | |
parent | 6f251e4a8b6b5997dafbd178e7b8dac274008688 (diff) |
Use tf_optimizer.OptimizeGraph() to reimplement create_inference_graph() method, and fix a bug in GetDeviceAndAllocator() where it doesn't set the cuda_device_id even if the device is found.
Diffstat (limited to 'tensorflow/contrib/tensorrt/trt_conversion.i')
-rw-r--r-- | tensorflow/contrib/tensorrt/trt_conversion.i | 77 |
1 files changed, 2 insertions, 75 deletions
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index d6628cd1eb..d23cba5e3b 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -104,77 +104,12 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); %ignoreall %unignore tensorflow; -%unignore trt_convert; %unignore calib_convert; %unignore get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; %{ -std::pair<string, string> trt_convert( - string graph_def_string, // The serialized GraphDef string. - std::vector<string> output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, - int minimum_segment_size, - bool is_dyn_op, - int max_cached_engines, - std::vector<int> cached_engine_batches - // Unfortunately we can't use TF_Status here since it - // is in c/c_api and brings in a lot of other libraries - // which in turn declare ops. These ops are included - // statically in our library and cause an abort when - // module is loaded due to double registration - // until Tensorflow properly exposes these headers - // we have to work around this by returning a string - // and converting it to exception on python side. - //,TF_Status* out_status) { -) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - string out_status; - - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(graph_def_string)) { - out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; - return std::pair<string, string>{out_status, ""}; - } - - if(precision_mode < 0 || precision_mode > 2){ - out_status = "InvalidArgument;Invalid precision_mode"; - return std::pair<string, string>{out_status, ""}; - } - if (!output_names.size()) { - out_status = "InvalidArgument;Size of the output_names vector is 0"; - return std::pair<string, string>{out_status, ""}; - } - tensorflow::GraphDef out_graph; - tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( - graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &out_graph, precision_mode, minimum_segment_size, - is_dyn_op, max_cached_engines, cached_engine_batches); - if (!conversion_status.ok()) { - auto retCode = (int)conversion_status.code(); - char buff[2000]; - snprintf(buff, 2000, "%d;%s", retCode, - conversion_status.error_message().c_str()); - out_status = buff; - return std::pair<string, string>{out_status, ""}; - } - string result; - if (!out_graph.SerializeToString(&result)) { - out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; - return std::pair<string, string>{out_status, ""}; - } - out_status = "OK;All good!"; - return std::pair<string, string>{out_status, result}; -#else - // Returns FAILED_PRECONDITION. - return std::pair<string, string>{"9;TensorRT is not enabled!", ""}; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT -} - std::pair<string, string> calib_convert( string graph_def_string, bool is_dyn_op // unfortunately we can't use TF_Status here since it @@ -246,16 +181,8 @@ version_struct get_loaded_tensorrt_version(){ %} -std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op); - -std::pair<string, string> trt_convert(string graph_def_string, - std::vector<string> output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, int minimum_segment_size, - bool is_dyn_op, - int max_cached_engines, - std::vector<int> cached_engine_batches); +std::pair<string, string> calib_convert( + string graph_def_string, bool is_dyn_op); version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); |