Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc')
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc  39
1 file changed, 11 insertions(+), 28 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index b32371b642..8efdf63ebe 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -24,12 +24,8 @@ limitations under the License.
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
-static ::tensorflow::tensorrt::Logger logger;
-namespace gpu = ::perftools::gputools;
-using IRuntime = nvinfer1::IRuntime;
-using Dims = nvinfer1::Dims;
-
namespace tensorrt {
+static ::tensorflow::tensorrt::Logger logger;
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// read serialized_engine
@@ -44,21 +40,10 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// TODO(samikama) runtime should be taken from a resourcemanager as well.
// Only engine should be in the op and context and runtime should be taken
// from resourcemanager
- // TODO(jie): cudaSetDevice make sure trt engine is allocated on the same
- // gpu where the input/output is also located.
- int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
- cudaSetDevice(gpu_id);
- int device;
- cudaGetDevice(&device);
- if (gpu_id != device) LOG(FATAL) << "set device failed!";
-
- // TODO(samikama) runtime should be taken from a resourcemanager as well.
- // Only engine should be in the op and context and runtime should be taken
- // from resourcemanager
-
- IRuntime* infer = nvinfer1::createInferRuntime(logger);
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr));
+
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
// Runtime is safe to delete after engine creation
infer->destroy();
@@ -70,6 +55,7 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
size_t binding_index;
int num_batch = 0;
+ bool valid = true;
for (int i = 0; i < context->num_inputs(); i++) {
// Grab the input tensor
binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
@@ -78,12 +64,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
const TensorShape& input_shape = input_tensor.shape();
if (i == 0) {
num_batch = input_shape.dim_size(0);
- if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
- LOG(FATAL) << "input tensor batch larger than max_batch_size: "
- << trt_engine_ptr_->getMaxBatchSize();
- }
} else if (num_batch != input_shape.dim_size(0)) {
- LOG(FATAL) << "input data inconsistent batch size";
+ valid = false;
break;
}
switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
@@ -99,6 +81,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
}
}
+ // Might want a different way to inform the user of batch size inconsistency
+ if (!valid) LOG(WARNING) << "input data inconsistent batch size";
+
for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
// This is bad that we have to reallocate output buffer every run.
// Create an output tensor
@@ -141,11 +126,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
->implementation()
->CudaStreamMemberHack()));
- // TODO(jie): trt enqueue does not return error
- auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
- *stream, nullptr);
- VLOG(2) << "enqueue returns: " << ret;
- // sync should be done by TF.
+ // Execution is handled by TF since we are getting the stream from TF.
+ // It is safe for the CPU pointer array (buffers) to go out of scope after enqueue.
+ trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
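
For reference, below is a minimal standalone sketch of the TensorRT calls this op relies on: deserializing a serialized engine and launching it on a caller-provided CUDA stream, mirroring the constructor and Compute() flow in the diff above. The helper name RunSerializedEngine, its parameters, and the plain "NvInfer.h"/"cuda_runtime_api.h" include paths are illustrative assumptions, not part of the op; unlike the op, which leaves synchronization to TF, the sketch synchronizes the stream itself.

// Illustrative sketch only, not the op itself.
#include <string>
#include <vector>

#include "NvInfer.h"
#include "cuda_runtime_api.h"

bool RunSerializedEngine(nvinfer1::ILogger& logger,
                         const std::string& serialized_engine,
                         std::vector<void*>* device_buffers,  // one device pointer per binding
                         int num_batch, cudaStream_t stream) {
  // Deserialize the engine; the runtime is safe to destroy once the engine
  // has been created (same reasoning as in the constructor above).
  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
  if (infer == nullptr) return false;
  nvinfer1::ICudaEngine* engine = infer->deserializeCudaEngine(
      serialized_engine.c_str(), serialized_engine.size(), nullptr);
  infer->destroy();
  if (engine == nullptr) return false;

  // The engine was built for a fixed maximum batch size; larger batches are
  // rejected rather than silently truncated.
  if (num_batch > engine->getMaxBatchSize()) {
    engine->destroy();
    return false;
  }

  nvinfer1::IExecutionContext* context = engine->createExecutionContext();
  // enqueue() launches asynchronously on the given stream. The host-side
  // bindings array may go out of scope afterwards; the device memory it
  // points to must stay alive until the stream has been synchronized.
  bool ok = context->enqueue(num_batch, device_buffers->data(), stream, nullptr);

  // The op leaves synchronization to TF; a standalone caller has to do it.
  cudaStreamSynchronize(stream);
  context->destroy();
  engine->destroy();
  return ok;
}

The buffers handed to enqueue() must be ordered by the engine's binding indices, which is why Compute() looks up each input and output with getBindingIndex() before filling the array.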