aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar Guangda Lai <laigd@google.com>2018-07-25 11:02:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 11:08:10 -0700
commit0bc512505957e3685305b6a850f222c6eed88c7d (patch)
tree4b968fe5ce802554eaf53d8be01758e775ec3839 /tensorflow/contrib/tensorrt
parentebe8a6fba27a357117d0ba154197b02d6a8b4ffb (diff)
Enable TensorRT build.
PiperOrigin-RevId: 206020981
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/BUILD8
-rw-r--r--tensorflow/contrib/tensorrt/convert/utils.cc2
-rw-r--r--tensorflow/contrib/tensorrt/tensorrt_test.cc9
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py1
4 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 2fe1f2c242..5889fd5aaf 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -37,7 +37,9 @@ tf_cuda_cc_test(
"nomac",
],
deps = [
+ "//tensorflow/core:gpu_init",
"//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
] + if_tensorrt([
@@ -384,12 +386,12 @@ cuda_py_tests(
"test/base_test.py",
# "test/batch_matmul_test.py",
# "test/biasadd_matmul_test.py",
- "test/binary_tensor_weight_broadcast_test.py",
- "test/concatenation_test.py",
+ # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation
+ # "test/concatenation_test.py", # Blocked by trt4 installation
"test/const_broadcast_test.py",
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
- "test/unary_test.py",
+ # "test/unary_test.py", # Blocked by trt4 installation
# "test/vgg_block_nchw_test.py",
# "test/vgg_block_test.py",
],
diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc
index 24591cf84b..17857cf4d0 100644
--- a/tensorflow/contrib/tensorrt/convert/utils.cc
+++ b/tensorflow/contrib/tensorrt/convert/utils.cc
@@ -24,7 +24,7 @@ bool IsGoogleTensorRTEnabled() {
// safely write code that uses tensorrt conditionally. E.g. if it does not
// check for for tensorrt, and user mistakenly uses tensorrt, they will just
// crash and burn.
-#ifdef GOOGLE_TENSORRT
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
return true;
#else
return false;
diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc
index 3712a9a6fe..769982c645 100644
--- a/tensorflow/contrib/tensorrt/tensorrt_test.cc
+++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
@@ -130,6 +132,13 @@ void Execute(nvinfer1::IExecutionContext* context, const float* input,
}
TEST(TensorrtTest, BasicFunctions) {
+ // Handle the case where the test is run on machine with no gpu available.
+ if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) {
+ LOG(WARNING) << "No gpu device available, probably not being run on a gpu "
+ "machine. Skipping...";
+ return;
+ }
+
// Create the network model.
nvinfer1::IHostMemory* model = CreateNetwork();
// Use the model to create an engine and then an execution context.
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 60b8eb6e81..bb7f5a77f0 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -107,6 +107,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
graph_options = config_pb2.GraphOptions()
gpu_options = config_pb2.GPUOptions()
+ gpu_options.allow_growth = True
if trt_convert.get_linked_tensorrt_version()[0] == 3:
gpu_options.per_process_gpu_memory_fraction = 0.50