aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/tensorrt_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/tensorrt_test.cc')
-rw-r--r--tensorflow/contrib/tensorrt/tensorrt_test.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc
index 3712a9a6fe..769982c645 100644
--- a/tensorflow/contrib/tensorrt/tensorrt_test.cc
+++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
@@ -130,6 +132,13 @@ void Execute(nvinfer1::IExecutionContext* context, const float* input,
}
TEST(TensorrtTest, BasicFunctions) {
+ // Handle the case where the test is run on machine with no gpu available.
+ if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) {
+ LOG(WARNING) << "No gpu device available, probably not being run on a gpu "
+ "machine. Skipping...";
+ return;
+ }
+
// Create the network model.
nvinfer1::IHostMemory* model = CreateNetwork();
// Use the model to create an engine and then an execution context.