diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/trt_conversion.i')
-rw-r--r-- | tensorflow/contrib/tensorrt/trt_conversion.i | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 3b1a18f8ac..6ea15fb8ef 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -101,6 +101,7 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" %} %ignoreall @@ -109,6 +110,10 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); %unignore get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; %unignore is_tensorrt_enabled; +%unignore enable_test_value; +%unignore clear_test_values; +%unignore add_test_value; +%unignore get_test_value; %{ @@ -186,6 +191,34 @@ bool is_tensorrt_enabled() { return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); } +void enable_test_value() { + tensorflow::tensorrt::test::EnableTestValue(); +} + +#if PY_MAJOR_VERSION < 3 +#define TRT_PY_TO_CPP_STRING PyString_AsString +#define TRT_CPP_TO_PY_STRING PyString_FromString +#else +#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8 +#define TRT_CPP_TO_PY_STRING PyUnicode_FromString +#endif + +void clear_test_values(PyObject* pattern) { + tensorflow::tensorrt::test::ClearTestValues( + string(TRT_PY_TO_CPP_STRING(pattern))); +} + +void add_test_value(PyObject* label, PyObject* value) { + tensorflow::tensorrt::test::AddTestValue( + string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value))); +} + +PyObject* get_test_value(PyObject* label) { + string value = tensorflow::tensorrt::test::GetTestValue( + string(TRT_PY_TO_CPP_STRING(label))); + return TRT_CPP_TO_PY_STRING(value.c_str()); +} + %} std::pair<string, string> calib_convert( @@ -193,5 +226,9 @@ std::pair<string, string> calib_convert( version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); bool is_tensorrt_enabled(); +void enable_test_value(); +void clear_test_values(PyObject* pattern); +void add_test_value(PyObject* label, PyObject* value); +PyObject* get_test_value(PyObject* label); %unignoreall |