diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/trt_conversion.i')
-rw-r--r-- | tensorflow/contrib/tensorrt/trt_conversion.i | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index d6628cd1eb..422740fdf6 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -100,6 +100,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" %} %ignoreall @@ -108,6 +109,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); %unignore calib_convert; %unignore get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; +%unignore is_tensorrt_enabled; %{ @@ -140,7 +142,7 @@ std::pair<string, string> trt_convert( return std::pair<string, string>{out_status, ""}; } - if(precision_mode < 0 || precision_mode > 2){ + if (precision_mode < 0 || precision_mode > 2) { out_status = "InvalidArgument;Invalid precision_mode"; return std::pair<string, string>{out_status, ""}; } @@ -232,7 +234,8 @@ version_struct get_linked_tensorrt_version() { #endif // GOOGLE_CUDA && GOOGLE_TENSORRT return s; } -version_struct get_loaded_tensorrt_version(){ + +version_struct get_loaded_tensorrt_version() { // Return the version from the loaded library. version_struct s; #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -244,6 +247,10 @@ version_struct get_loaded_tensorrt_version(){ return s; } +bool is_tensorrt_enabled() { + return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); +} + %} std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op); @@ -258,5 +265,6 @@ std::pair<string, string> trt_convert(string graph_def_string, std::vector<int> cached_engine_batches); version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); +bool is_tensorrt_enabled(); %unignoreall |