Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 39
1 file changed, 11 insertions, 28 deletions
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index b32371b642..8efdf63ebe 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -24,12 +24,8 @@ limitations under the License.
 #include "cuda/include/cuda_runtime_api.h"
 
 namespace tensorflow {
-static ::tensorflow::tensorrt::Logger logger;
-namespace gpu = ::perftools::gputools;
-using IRuntime = nvinfer1::IRuntime;
-using Dims = nvinfer1::Dims;
-
 namespace tensorrt {
+static ::tensorflow::tensorrt::Logger logger;
 
 TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
   // read serialized_engine
@@ -44,21 +40,10 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
   // TODO(samikama) runtime should be taken from a resourcemanager as well.
   // Only engine should be in the op and context and runtime should be taken
   // from resourcemanager
-  // TODO(jie): cudaSetDevice make sure trt engine is allocated on the same
-  // gpu where the input/output is also located.
-  int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
-  cudaSetDevice(gpu_id);
-  int device;
-  cudaGetDevice(&device);
-  if (gpu_id != device) LOG(FATAL) << "set device failed!";
-
-  // TODO(samikama) runtime should be taken from a resourcemanager as well.
-  // Only engine should be in the op and context and runtime should be taken
-  // from resourcemanager
-
-  IRuntime* infer = nvinfer1::createInferRuntime(logger);
+  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
   trt_engine_ptr_.reset(infer->deserializeCudaEngine(
       serialized_engine.c_str(), serialized_engine.size(), nullptr));
+  trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
   // Runtime is safe to delete after engine creation
   infer->destroy();
 
@@ -70,6 +55,7 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
 
   size_t binding_index;
   int num_batch = 0;
+  bool valid = true;
   for (int i = 0; i < context->num_inputs(); i++) {
     // Grab the input tensor
     binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
@@ -78,12 +64,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
     const TensorShape& input_shape = input_tensor.shape();
     if (i == 0) {
       num_batch = input_shape.dim_size(0);
-      if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
-        LOG(FATAL) << "input tensor batch larger than max_batch_size: "
-                   << trt_engine_ptr_->getMaxBatchSize();
-      }
     } else if (num_batch != input_shape.dim_size(0)) {
-      LOG(FATAL) << "input data inconsistent batch size";
+      valid = false;
       break;
     }
     switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
@@ -99,6 +81,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
     }
   }
 
+  // Might want a different way to inform the user of batch size inconsistency
+  if (!valid) LOG(WARNING) << "input data inconsistent batch size";
+
   for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
     // This is bad that we have to reallocate output buffer every run.
     // Create an output tensor
@@ -141,11 +126,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
                                                 ->implementation()
                                                 ->CudaStreamMemberHack()));
 
-  // TODO(jie): trt enqueue does not return error
-  auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
-                                                 *stream, nullptr);
-  VLOG(2) << "enqueue returns: " << ret;
-  // sync should be done by TF.
+  // execution handled by TF since we are getting stream from TF.
+  // it is safe for CPU pointer array (buffers) to go out of scope after enqueue
+  trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
 }
 
 REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
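For context, the pattern this patch settles on is plain TensorRT usage: deserialize the engine once at construction, keep a single IExecutionContext alive, and enqueue work on an existing CUDA stream on every call. The sketch below shows that lifecycle standalone, outside TensorFlow, using the TensorRT 3.x-era API this file targets. The logger class, the "engine.plan" path, and the all-FP32 buffer sizing are illustrative assumptions, not part of the commit.

// Minimal standalone sketch of the deserialize-once / enqueue-per-call
// lifecycle adopted by this patch (TensorRT 3.x-era API). Names not in the
// diff (SketchLogger, "engine.plan", FP32-only sizing) are assumptions.
#include <cstddef>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include <cuda_runtime_api.h>
#include "NvInfer.h"

// TensorRT requires an ILogger implementation; this one discards messages.
class SketchLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) override {}
};
static SketchLogger logger;

int main() {
  // Read a serialized engine from disk (path is hypothetical).
  std::ifstream in("engine.plan", std::ios::binary);
  std::string serialized((std::istreambuf_iterator<char>(in)),
                         std::istreambuf_iterator<char>());

  // Construction-time work, as in TRTEngineOp's constructor: the runtime is
  // only needed for deserialization and is destroyed right after.
  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
  nvinfer1::ICudaEngine* engine = infer->deserializeCudaEngine(
      serialized.c_str(), serialized.size(), nullptr);
  nvinfer1::IExecutionContext* context = engine->createExecutionContext();
  infer->destroy();

  // Per-call work, as in Compute(): one device buffer per binding, ordered
  // by binding index. Sizing here assumes every binding is FP32.
  const int num_batch = 1;
  std::vector<void*> buffers(engine->getNbBindings(), nullptr);
  for (int i = 0; i < engine->getNbBindings(); ++i) {
    nvinfer1::Dims dims = engine->getBindingDimensions(i);
    size_t count = 1;
    for (int d = 0; d < dims.nbDims; ++d) count *= dims.d[d];
    cudaMalloc(&buffers[i], num_batch * count * sizeof(float));
  }

  // enqueue() only reads the host-side pointer array while it runs, so
  // `buffers` may go out of scope once it returns (the point made in the
  // comment the patch adds). Inside TF the stream comes from the op context
  // and TF owns synchronization; standalone we create and sync one ourselves.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  context->enqueue(num_batch, buffers.data(), stream, nullptr);
  cudaStreamSynchronize(stream);

  for (void* p : buffers) cudaFree(p);
  cudaStreamDestroy(stream);
  context->destroy();
  engine->destroy();
  return 0;
}

The design choice the diff makes, moving createExecutionContext() into the constructor, mirrors the split above: everything before the per-call section runs once, and Compute() only touches the bindings and the stream.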