Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc')
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 39
1 file changed, 28 insertions(+), 11 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 8efdf63ebe..b32371b642 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -24,8 +24,12 @@ limitations under the License.
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
-namespace tensorrt {
static ::tensorflow::tensorrt::Logger logger;
+namespace gpu = ::perftools::gputools;
+using IRuntime = nvinfer1::IRuntime;
+using Dims = nvinfer1::Dims;
+
+namespace tensorrt {
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// read serialized_engine
@@ -40,10 +44,21 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
// TODO(samikama) runtime should be taken from a resourcemanager as well.
// Only engine should be in the op and context and runtime should be taken
// from resourcemanager
- nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
+ // TODO(jie): cudaSetDevice makes sure the TRT engine is allocated on the
+ // same GPU where the input/output tensors are located.
+ int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
+ cudaSetDevice(gpu_id);
+ int device;
+ cudaGetDevice(&device);
+ if (gpu_id != device) LOG(FATAL) << "set device failed!";
+
+ // TODO(samikama) runtime should be taken from a resourcemanager as well.
+ // Only engine should be in the op and context and runtime should be taken
+ // from resourcemanager
+
+ IRuntime* infer = nvinfer1::createInferRuntime(logger);
trt_engine_ptr_.reset(infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr));
-
trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
// Runtime is safe to delete after engine creation
infer->destroy();
@@ -55,7 +70,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
size_t binding_index;
int num_batch = 0;
- bool valid = true;
for (int i = 0; i < context->num_inputs(); i++) {
// Grab the input tensor
binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
@@ -64,8 +78,12 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
const TensorShape& input_shape = input_tensor.shape();
if (i == 0) {
num_batch = input_shape.dim_size(0);
+ if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
+ LOG(FATAL) << "input batch size " << num_batch << " exceeds max_batch_size: "
+ << trt_engine_ptr_->getMaxBatchSize();
+ }
} else if (num_batch != input_shape.dim_size(0)) {
- valid = false;
+ LOG(FATAL) << "input data inconsistent batch size";
break;
}
switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
@@ -81,9 +99,6 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
}
}
- // Might want a different way to inform the user of batch size inconsistency
- if (!valid) LOG(WARNING) << "input data inconsistent batch size";
-
for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
// It is unfortunate that we have to reallocate the output buffer on every run.
// Create an output tensor
@@ -126,9 +141,11 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
->implementation()
->CudaStreamMemberHack()));
- // execution handled by TF since we are getting stream from TF.
- // it is safe for CPU pointer array (buffers) to go out of scope after enqueue
- trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
+ // TODO(jie): trt enqueue does not return error
+ auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
+ *stream, nullptr);
+ VLOG(2) << "enqueue returns: " << ret;
+ // sync should be done by TF.
}
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
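
For reference, the device-pinning check added to the constructor can be exercised on its own with the CUDA runtime API. This is a minimal sketch under assumptions: expected_gpu_id stands in for the value read from context->device()->tensorflow_gpu_device_info()->gpu_id, and the error handling is simplified to stderr plus abort instead of LOG(FATAL).

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime_api.h>

// Sketch of the constructor's device check: pin the calling thread to the
// GPU that owns the op's inputs/outputs, then verify the switch actually
// took effect before deserializing the TensorRT engine.
static void SetAndCheckDevice(int expected_gpu_id) {
  cudaSetDevice(expected_gpu_id);
  int current = -1;
  cudaGetDevice(&current);
  if (current != expected_gpu_id) {
    std::fprintf(stderr, "set device failed! wanted %d, got %d\n",
                 expected_gpu_id, current);
    std::abort();
  }
}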
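
Similarly, the batch-size checks introduced in Compute() amount to the helper sketched below. It is an illustration only: the nested-vector shapes and the max_batch_size parameter are hypothetical stand-ins for TensorShape::dim_size(0) and ICudaEngine::getMaxBatchSize(), and it returns false instead of calling LOG(FATAL) so it can be tested in isolation.

#include <vector>

// Sketch of the validation logic: every input must share the leading
// (batch) dimension, and that batch size must not exceed what the TRT
// engine was built for.
static bool ValidateBatchSizes(const std::vector<std::vector<int>>& shapes,
                               int max_batch_size, int* num_batch) {
  if (shapes.empty()) return false;
  *num_batch = shapes[0][0];
  if (*num_batch > max_batch_size) return false;  // exceeds engine limit
  for (const auto& shape : shapes) {
    if (shape[0] != *num_batch) return false;     // inconsistent batch dim
  }
  return true;
}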