diff options
author:    Guangda Lai <laigd@google.com>  2018-07-25 11:02:43 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>  2018-07-25 11:08:10 -0700
commit:    0bc512505957e3685305b6a850f222c6eed88c7d (patch)
tree:      4b968fe5ce802554eaf53d8be01758e775ec3839 /tensorflow/contrib/tensorrt
parent:    ebe8a6fba27a357117d0ba154197b02d6a8b4ffb (diff)
Enable TensorRT build.
PiperOrigin-RevId: 206020981
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                                 | 8
-rw-r--r--  tensorflow/contrib/tensorrt/convert/utils.cc                      | 2
-rw-r--r--  tensorflow/contrib/tensorrt/tensorrt_test.cc                      | 9
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py | 1
4 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 2fe1f2c242..5889fd5aaf 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -37,7 +37,9 @@ tf_cuda_cc_test( "nomac", ], deps = [ + "//tensorflow/core:gpu_init", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + if_tensorrt([ @@ -384,12 +386,12 @@ cuda_py_tests( "test/base_test.py", # "test/batch_matmul_test.py", # "test/biasadd_matmul_test.py", - "test/binary_tensor_weight_broadcast_test.py", - "test/concatenation_test.py", + # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation + # "test/concatenation_test.py", # Blocked by trt4 installation "test/const_broadcast_test.py", "test/multi_connection_neighbor_engine_test.py", "test/neighboring_engine_test.py", - "test/unary_test.py", + # "test/unary_test.py", # Blocked by trt4 installation # "test/vgg_block_nchw_test.py", # "test/vgg_block_test.py", ], diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc index 24591cf84b..17857cf4d0 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.cc +++ b/tensorflow/contrib/tensorrt/convert/utils.cc @@ -24,7 +24,7 @@ bool IsGoogleTensorRTEnabled() { // safely write code that uses tensorrt conditionally. E.g. if it does not // check for for tensorrt, and user mistakenly uses tensorrt, they will just // crash and burn. -#ifdef GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT return true; #else return false; diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc index 3712a9a6fe..769982c645 100644 --- a/tensorflow/contrib/tensorrt/tensorrt_test.cc +++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" #if GOOGLE_CUDA @@ -130,6 +132,13 @@ void Execute(nvinfer1::IExecutionContext* context, const float* input, } TEST(TensorrtTest, BasicFunctions) { + // Handle the case where the test is run on machine with no gpu available. + if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) { + LOG(WARNING) << "No gpu device available, probably not being run on a gpu " + "machine. Skipping..."; + return; + } + // Create the network model. nvinfer1::IHostMemory* model = CreateNetwork(); // Use the model to create an engine and then an execution context. diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 60b8eb6e81..bb7f5a77f0 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -107,6 +107,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): graph_options = config_pb2.GraphOptions() gpu_options = config_pb2.GPUOptions() + gpu_options.allow_growth = True if trt_convert.get_linked_tensorrt_version()[0] == 3: gpu_options.per_process_gpu_memory_fraction = 0.50 |