aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/trt_conversion.i
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/trt_conversion.i')
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i12
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index d6628cd1eb..422740fdf6 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -100,6 +100,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
%}
%ignoreall
@@ -108,6 +109,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
%unignore calib_convert;
%unignore get_linked_tensorrt_version;
%unignore get_loaded_tensorrt_version;
+%unignore is_tensorrt_enabled;
%{
@@ -140,7 +142,7 @@ std::pair<string, string> trt_convert(
return std::pair<string, string>{out_status, ""};
}
- if(precision_mode < 0 || precision_mode > 2){
+ if (precision_mode < 0 || precision_mode > 2) {
out_status = "InvalidArgument;Invalid precision_mode";
return std::pair<string, string>{out_status, ""};
}
@@ -232,7 +234,8 @@ version_struct get_linked_tensorrt_version() {
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
return s;
}
-version_struct get_loaded_tensorrt_version(){
+
+version_struct get_loaded_tensorrt_version() {
// Return the version from the loaded library.
version_struct s;
#if GOOGLE_CUDA && GOOGLE_TENSORRT
@@ -244,6 +247,10 @@ version_struct get_loaded_tensorrt_version(){
return s;
}
+bool is_tensorrt_enabled() {
+ return tensorflow::tensorrt::IsGoogleTensorRTEnabled();
+}
+
%}
std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
@@ -258,5 +265,6 @@ std::pair<string, string> trt_convert(string graph_def_string,
std::vector<int> cached_engine_batches);
version_struct get_linked_tensorrt_version();
version_struct get_loaded_tensorrt_version();
+bool is_tensorrt_enabled();
%unignoreall