author    gracehoney <31743510+aaroey@users.noreply.github.com> 2018-06-26 12:11:10 -0700
committer gracehoney <31743510+aaroey@users.noreply.github.com> 2018-06-26 12:11:10 -0700
commit    e77098094d401435f4d038d93724dd841e0469c1 (patch)
tree      a0d19b4b39839c3ad0b5090dfc36894b814a0c54 /tensorflow/contrib/tensorrt/trt_conversion.i
parent    6f251e4a8b6b5997dafbd178e7b8dac274008688 (diff)
Use tf_optimizer.OptimizeGraph() to reimplement the create_inference_graph() method, and fix a bug in GetDeviceAndAllocator() where it didn't set the cuda_device_id even when the device was found.
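The Python-side reimplementation is not part of this .i file diff. Below is a minimal sketch of what create_inference_graph() looks like when driven through Grappler, assuming the "TensorRTOptimizer" custom-optimizer name and its parameter_map keys (these follow TF 1.x contrib conventions and are not taken from this patch):

# Hedged sketch: create_inference_graph() on top of tf_optimizer.OptimizeGraph().
# The custom-optimizer name and parameter_map keys below are assumptions.
from tensorflow.core.protobuf import meta_graph_pb2, rewriter_config_pb2
from tensorflow.python.framework import importer, ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training import saver

def create_inference_graph(graph_def, outputs, max_batch_size=1,
                           precision_mode="FP32", minimum_segment_size=3):
  # Run the TensorRT segmenter as a Grappler custom optimizer.
  rewriter_cfg = rewriter_config_pb2.RewriterConfig()
  optimizer = rewriter_cfg.custom_optimizers.add()
  optimizer.name = "TensorRTOptimizer"
  optimizer.parameter_map["max_batch_size"].i = max_batch_size
  optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
  optimizer.parameter_map["precision_mode"].s = precision_mode.encode()

  # Grappler consumes a MetaGraphDef; record the fetch nodes so the
  # optimizer knows which nodes it must preserve.
  graph = ops.Graph()
  with graph.as_default():
    importer.import_graph_def(graph_def, name="")
  meta_graph = saver.export_meta_graph(graph_def=graph.as_graph_def(),
                                       graph=graph)
  fetch = meta_graph_pb2.CollectionDef()
  for name in outputs:
    fetch.node_list.value.append(name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch)

  # In this TF version, OptimizeGraph() took the RewriterConfig directly.
  return tf_optimizer.OptimizeGraph(rewriter_cfg, meta_graph)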
Diffstat (limited to 'tensorflow/contrib/tensorrt/trt_conversion.i')
-rw-r--r--  tensorflow/contrib/tensorrt/trt_conversion.i | 77
1 file changed, 2 insertions(+), 75 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index d6628cd1eb..d23cba5e3b 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -104,77 +104,12 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
%ignoreall
%unignore tensorflow;
-%unignore trt_convert;
%unignore calib_convert;
%unignore get_linked_tensorrt_version;
%unignore get_loaded_tensorrt_version;
%{
-std::pair<string, string> trt_convert(
- string graph_def_string, // The serialized GraphDef string.
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size_bytes,
- int precision_mode,
- int minimum_segment_size,
- bool is_dyn_op,
- int max_cached_engines,
- std::vector<int> cached_engine_batches
- // Unfortunately we can't use TF_Status here since it
- // is in c/c_api and brings in a lot of other libraries
- // which in turn declare ops. These ops are included
- // statically in our library and cause an abort when
- // the module is loaded due to double registration.
- // Until TensorFlow properly exposes these headers,
- // we have to work around this by returning a string
- // and converting it to an exception on the Python side.
- //,TF_Status* out_status) {
-) {
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
- string out_status;
-
- tensorflow::GraphDef graph_def;
- if (!graph_def.ParseFromString(graph_def_string)) {
- out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
- return std::pair<string, string>{out_status, ""};
- }
-
- if(precision_mode < 0 || precision_mode > 2){
- out_status = "InvalidArgument;Invalid precision_mode";
- return std::pair<string, string>{out_status, ""};
- }
- if (!output_names.size()) {
- out_status = "InvalidArgument;Size of the output_names vector is 0";
- return std::pair<string, string>{out_status, ""};
- }
- tensorflow::GraphDef out_graph;
- tensorflow::Status conversion_status =
- tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
- graph_def, output_names, max_batch_size, max_workspace_size_bytes,
- &out_graph, precision_mode, minimum_segment_size,
- is_dyn_op, max_cached_engines, cached_engine_batches);
- if (!conversion_status.ok()) {
- auto retCode = (int)conversion_status.code();
- char buff[2000];
- snprintf(buff, 2000, "%d;%s", retCode,
- conversion_status.error_message().c_str());
- out_status = buff;
- return std::pair<string, string>{out_status, ""};
- }
- string result;
- if (!out_graph.SerializeToString(&result)) {
- out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
- return std::pair<string, string>{out_status, ""};
- }
- out_status = "OK;All good!";
- return std::pair<string, string>{out_status, result};
-#else
- // Returns FAILED_PRECONDITION.
- return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
-#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
-}
-
std::pair<string, string> calib_convert(
string graph_def_string, bool is_dyn_op
// Unfortunately we can't use TF_Status here since it
@@ -246,16 +181,8 @@ version_struct get_loaded_tensorrt_version(){
%}
-std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
-
-std::pair<string, string> trt_convert(string graph_def_string,
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size_bytes,
- int precision_mode, int minimum_segment_size,
- bool is_dyn_op,
- int max_cached_engines,
- std::vector<int> cached_engine_batches);
+std::pair<string, string> calib_convert(
+ string graph_def_string, bool is_dyn_op);
version_struct get_linked_tensorrt_version();
version_struct get_loaded_tensorrt_version();
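
The comment kept in calib_convert() above describes the workaround both wrappers rely on: the C++ side returns a "<code>;<message>" string in place of a TF_Status, and the Python caller converts it back into an exception. A minimal sketch of that decoding step follows; the helper name is hypothetical, and _make_specific_exception is TensorFlow-internal API:

# Hedged sketch of the Python-side half of the workaround described in the
# comments above. The helper name is hypothetical.
from tensorflow.python.framework import errors_impl

def _raise_if_error(status):
  if status.startswith("OK"):
    return  # e.g. "OK;All good!"
  code_str, _, message = status.partition(";")
  try:
    code = int(code_str)  # numeric tensorflow::error::Code, e.g. 9 == FAILED_PRECONDITION
  except ValueError:
    raise RuntimeError("Malformed status: {}".format(status))
  # Map the numeric code to the matching exception class
  # (e.g. InvalidArgumentError) and raise it.
  raise errors_impl._make_specific_exception(None, None, message, code)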